Source code for nvflare.recipe.spec

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List, Optional, Union

from nvflare.apis.filter import Filter
from nvflare.app_common.widgets.decomposer_reg import DecomposerRegister
from nvflare.fuel.utils.fobs import Decomposer
from nvflare.job_config.api import FedJob
from nvflare.job_config.defs import FilterType


[docs] class ExecEnv(ABC): def __init__(self, extra: dict = None): """Constructor of ExecEnv Args: extra: a dict of extra properties """ if not extra: extra = {} if not isinstance(extra, dict): raise ValueError(f"extra must be dict but got {type(extra)}") self.extra = extra
[docs] def get_extra_prop(self, prop_name: str, default=None): """Get the specified extra property. Args: prop_name: name of the property default: the default value to return if the named property does not exist. Returns: value of the property or the default """ return self.extra.get(prop_name, default)
[docs] @abstractmethod def deploy(self, job: FedJob) -> str: """Deploy a FedJob and return an execution response. Args: job: The FedJob to deploy. Returns: str: The job ID. """ pass
[docs] @abstractmethod def get_job_status(self, job_id: str) -> Optional[str]: """Get the status of a job. Args: job_id: The job ID to check status for. Returns: Optional[str]: The status of the job, or None if not supported. """ pass
[docs] @abstractmethod def abort_job(self, job_id: str) -> None: """Abort a running job. Args: job_id: The job ID to abort. """ pass
[docs] @abstractmethod def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]: """Get the result workspace of a job. Args: job_id: The job ID to get results for. timeout: The timeout for the job to complete. Defaults to 0.0 (no timeout). Returns: Optional[str]: The result workspace path if job completed, None if still running or stopped early. """ pass
[docs] class Recipe(ABC): def __init__(self, job: FedJob): """This is base class of a recipe. Recipes are implemented by jobs. A concrete recipe must provide the job for recipe implementation. Args: job: the job that implements the recipe. """ self.job = job
[docs] def process_env(self, env: ExecEnv): pass
[docs] def add_client_input_filter( self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None ): """Add a filter to clients for incoming tasks from the server. Args: filter: the filter to be added tasks: tasks that the filter applies to clients: client names to add, if None, all clients will be added. Returns: None """ if clients is None: self.job.to_clients(filter, filter_type=FilterType.TASK_DATA, tasks=tasks) else: for client in clients: self.job.to(filter, client, filter_type=FilterType.TASK_DATA, tasks=tasks)
[docs] def add_client_output_filter( self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None ): """Add a filter to clients for outgoing result to server. Args: filter: the filter to be added tasks: tasks that the filter applies to clients: client names to add, if None, all clients will be added. Returns: None """ if clients is None: self.job.to_clients(filter, filter_type=FilterType.TASK_RESULT, tasks=tasks) else: for client in clients: self.job.to(filter, client, filter_type=FilterType.TASK_RESULT, tasks=tasks)
[docs] def add_server_output_filter(self, filter: Filter, tasks: Optional[List[str]] = None): """Add a filter to the server for outgoing tasks to clients. Args: filter: the filter to be added tasks: tasks that the filter applies to Returns: None """ self.job.to_server(filter, filter_type=FilterType.TASK_DATA, tasks=tasks)
[docs] def add_server_input_filter(self, filter: Filter, tasks: Optional[List[str]] = None): """Add a filter to server for incoming task result from clients. . Args: filter: the filter to be added tasks: tasks that the filter applies to Returns: None """ self.job.to_server(filter, filter_type=FilterType.TASK_RESULT, tasks=tasks)
@staticmethod def _get_full_class_name(obj): """ Returns the fully qualified name of an object. """ cls = type(obj) module = cls.__module__ qualname = cls.__qualname__ if module == "builtins": # For built-in types like int, str, etc. return qualname return f"{module}.{qualname}"
[docs] def add_decomposers(self, decomposers: List[Union[str, Decomposer]]): """Add decomposers to the job Args: decomposers: spec of decomposers. Can be class names or Decomposer objects Returns: None """ if not decomposers: return class_names = [] for d in decomposers: if isinstance(d, str): # class name class_names.append(d) elif isinstance(d, Decomposer): class_names.append(self._get_full_class_name(d)) reg = DecomposerRegister(class_names) self.job.to_server(reg, id="decomposer_reg") self.job.to_clients(reg, id="decomposer_reg")
[docs] def export( self, job_dir: str, server_exec_params: dict = None, client_exec_params: dict = None, env: ExecEnv = None, ): """Export the recipe to a job definition. Args: job_dir: directory where the job will be exported to. server_exec_params: execution params for the server client_exec_params: execution params for clients env: the environment that the exported job will be running in Returns: None """ if server_exec_params: self.job.to_server(server_exec_params) if client_exec_params: self.job.to_clients(client_exec_params) if env: self.process_env(env) self.job.export_job(job_dir)
[docs] def execute(self, env: ExecEnv, server_exec_params: dict = None, client_exec_params: dict = None) -> "Run": """Execute the recipe in a specified execution environment. Args: env: the execution environment server_exec_params: execution params for the server client_exec_params: execution params for clients Returns: Run to get job ID and execution results """ if server_exec_params: self.job.to_server(server_exec_params) if client_exec_params: self.job.to_clients(client_exec_params) self.process_env(env) job_id = env.deploy(self.job) from nvflare.recipe.run import Run run = Run(env, job_id) return run