# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from nvflare.apis.filter import Filter
from nvflare.app_common.widgets.decomposer_reg import DecomposerRegister
from nvflare.fuel.utils.fobs import Decomposer
from nvflare.job_config.api import FedJob
from nvflare.job_config.defs import FilterType
[docs]
class ExecEnv(ABC):
def __init__(self, extra: dict = None):
"""Constructor of ExecEnv
Args:
extra: a dict of extra properties
"""
if not extra:
extra = {}
if not isinstance(extra, dict):
raise ValueError(f"extra must be dict but got {type(extra)}")
self.extra = extra
[docs]
@abstractmethod
def deploy(self, job: FedJob) -> str:
"""Deploy a FedJob and return an execution response.
Args:
job: The FedJob to deploy.
Returns:
str: The job ID.
"""
pass
[docs]
@abstractmethod
def get_job_status(self, job_id: str) -> Optional[str]:
"""Get the status of a job.
Args:
job_id: The job ID to check status for.
Returns:
Optional[str]: The status of the job, or None if not supported.
"""
pass
[docs]
@abstractmethod
def abort_job(self, job_id: str) -> None:
"""Abort a running job.
Args:
job_id: The job ID to abort.
"""
pass
[docs]
@abstractmethod
def get_job_result(self, job_id: str, timeout: float = 0.0) -> Optional[str]:
"""Get the result workspace of a job.
Args:
job_id: The job ID to get results for.
timeout: The timeout for the job to complete. Defaults to 0.0 (no timeout).
Returns:
Optional[str]: The result workspace path if job completed, None if still running or stopped early.
"""
pass
[docs]
class Recipe(ABC):
def __init__(self, job: FedJob):
"""This is base class of a recipe. Recipes are implemented by jobs.
A concrete recipe must provide the job for recipe implementation.
Args:
job: the job that implements the recipe.
"""
self.job = job
[docs]
def process_env(self, env: ExecEnv):
pass
[docs]
def add_client_output_filter(
self, filter: Filter, tasks: Optional[List[str]] = None, clients: Optional[List[str]] = None
):
"""Add a filter to clients for outgoing result to server.
Args:
filter: the filter to be added
tasks: tasks that the filter applies to
clients: client names to add, if None, all clients will be added.
Returns: None
"""
if clients is None:
self.job.to_clients(filter, filter_type=FilterType.TASK_RESULT, tasks=tasks)
else:
for client in clients:
self.job.to(filter, client, filter_type=FilterType.TASK_RESULT, tasks=tasks)
[docs]
def add_server_output_filter(self, filter: Filter, tasks: Optional[List[str]] = None):
"""Add a filter to the server for outgoing tasks to clients.
Args:
filter: the filter to be added
tasks: tasks that the filter applies to
Returns: None
"""
self.job.to_server(filter, filter_type=FilterType.TASK_DATA, tasks=tasks)
@staticmethod
def _get_full_class_name(obj):
"""
Returns the fully qualified name of an object.
"""
cls = type(obj)
module = cls.__module__
qualname = cls.__qualname__
if module == "builtins": # For built-in types like int, str, etc.
return qualname
return f"{module}.{qualname}"
[docs]
def add_decomposers(self, decomposers: List[Union[str, Decomposer]]):
"""Add decomposers to the job
Args:
decomposers: spec of decomposers. Can be class names or Decomposer objects
Returns: None
"""
if not decomposers:
return
class_names = []
for d in decomposers:
if isinstance(d, str):
# class name
class_names.append(d)
elif isinstance(d, Decomposer):
class_names.append(self._get_full_class_name(d))
reg = DecomposerRegister(class_names)
self.job.to_server(reg, id="decomposer_reg")
self.job.to_clients(reg, id="decomposer_reg")
[docs]
def export(
self,
job_dir: str,
server_exec_params: dict = None,
client_exec_params: dict = None,
env: ExecEnv = None,
):
"""Export the recipe to a job definition.
Args:
job_dir: directory where the job will be exported to.
server_exec_params: execution params for the server
client_exec_params: execution params for clients
env: the environment that the exported job will be running in
Returns: None
"""
if server_exec_params:
self.job.to_server(server_exec_params)
if client_exec_params:
self.job.to_clients(client_exec_params)
if env:
self.process_env(env)
self.job.export_job(job_dir)
[docs]
def execute(self, env: ExecEnv, server_exec_params: dict = None, client_exec_params: dict = None) -> "Run":
"""Execute the recipe in a specified execution environment.
Args:
env: the execution environment
server_exec_params: execution params for the server
client_exec_params: execution params for clients
Returns: Run to get job ID and execution results
"""
if server_exec_params:
self.job.to_server(server_exec_params)
if client_exec_params:
self.job.to_clients(client_exec_params)
self.process_env(env)
job_id = env.deploy(self.job)
from nvflare.recipe.run import Run
run = Run(env, job_id)
return run