# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import re
import uuid
from typing import Any, Dict, List, Optional, Union
from nvflare.apis.executor import Executor
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import ConfigVarName
from nvflare.apis.impl.controller import Controller
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.fuel.utils.class_utils import get_component_init_parameters
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int, check_str
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig
from .defs import FilterType, JobTargetType
# Characters that are not allowed in a job target name; FedJob._validate_target
# rejects any target containing one of them (the ALL_SITES target is exempted there).
SPECIAL_CHARACTERS = '"!@#$%^&*()+?=,<>/'
# Name of the optional hook method that FedJob.to() looks up on the object being
# assigned; when the object has it, the hook is called instead of adding the
# object as a plain component.
_ADD_TO_JOB_METHOD_NAME = "add_to_fed_job"
[docs]
class FedApp:
    """Base wrapper around a client or server app config.

    Besides forwarding calls to the wrapped config, it keeps track of the
    component ids it has handed out so that every id is unique within the app,
    and remembers which Python objects were already added so that the same
    object is configured only once.
    """

    def __init__(self, app_config: Union[ClientAppConfig, ServerAppConfig]):
        """FedApp handles `ClientAppConfig` and `ServerAppConfig` and allows setting task result or task data filters.

        Args:
            app_config: the app config that all calls are forwarded to.
        """
        self.app_config = app_config
        self._used_ids = []

        # obj_id => comp_id
        # obj_id is the Python's object ID; comp_id is the component ID for job config
        # _oid_to_cid keeps the mapping between obj_id and comp_id.
        # this is to make sure that when the same object is used, it is configured only once in the job.
        self._oid_to_cid = {}

    def get_app_config(self):
        """Return the wrapped app config."""
        return self.app_config

    def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
        """Register `task_filter` as a task result filter for `tasks` on the wrapped config."""
        self.app_config.add_task_result_filter(tasks, task_filter)

    def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
        """Register `task_filter` as a task data filter for `tasks` on the wrapped config."""
        self.app_config.add_task_data_filter(tasks, task_filter)

    def add_component(self, component, comp_id=None):
        """Add a component to the app, at most once per Python object.

        Args:
            component: the object to be added.
            comp_id: requested component id; "component" is used when None.

        Returns:
            The id under which the component is configured. If the same object was
            added before, the id from that first call is returned and nothing is added.
        """
        # is the component already configured?
        oid = id(component)
        cid = self._oid_to_cid.get(oid)
        if cid:
            # the component is already configured
            return cid

        if comp_id is None:
            comp_id = "component"

        final_id = self.generate_tracked_id(comp_id)
        self.app_config.add_component(final_id, component)
        self._oid_to_cid[oid] = final_id
        return final_id

    def _generate_id(self, id: str = "") -> str:
        """Return `id` if it is unused, otherwise derive an unused variant of it.

        A variant is derived by incrementing the first integer found in the id, or by
        appending "1" when the id has no integer. This repeats until the id is unused.
        """
        candidate = id
        while candidate in self._used_ids:
            counter = re.search(r"\d+", candidate)
            if counter:
                # Rewrite only the matched span. A plain str.replace() would also
                # change every other occurrence of the same digits elsewhere in the id.
                start, end = counter.span()
                candidate = f"{candidate[:start]}{int(counter.group()) + 1}{candidate[end:]}"
            else:
                candidate = candidate + "1"
        return candidate

    def generate_tracked_id(self, id: str = "") -> str:
        """Generate an unused id based on `id` and record it as used.

        Args:
            id: the requested id.

        Returns:
            The unique id that was recorded.
        """
        id = self._generate_id(id)
        self._used_ids.append(id)
        return id

    def add_external_script(self, ext_script: str):
        """Register external script to include them in custom directory.

        Args:
            ext_script: external script that needs to be deployed to the client/server.
        """
        self.app_config.add_ext_script(ext_script)

    def add_external_dir(self, ext_dir: str):
        """Register external folder to include them in custom directory.

        Args:
            ext_dir: external folder that need to be deployed to the client/server.
        """
        self.app_config.add_ext_dir(ext_dir)

    def add_file_source(self, src_path: str, dest_dir=None, app_folder_type=None):
        """Forward a file source (path, destination dir, app folder type) to the wrapped config."""
        self.app_config.add_file_source(src_path, dest_dir, app_folder_type)

    def add_params(self, args: Dict[str, Any]):
        """Add additional system configuration parameters to be included in the generated JSON configs.

        Args:
            args: Dictionary of system configuration parameters (e.g., {"timeout": 600, "max_retries": 3})
        """
        self.app_config.add_params(args)

    def _add_resource(self, resource: str):
        """Register one resource as an external dir or script, depending on what the path is.

        Raises:
            ValueError: if `resource` is not a str, or is a relative path that is
                neither an existing directory nor an existing file.
        """
        if not isinstance(resource, str):
            raise ValueError(f"cannot add resource: resource must be a str but got {type(resource)}")
        elif os.path.isdir(resource):
            self.add_external_dir(resource)
        elif os.path.isfile(resource):
            self.add_external_script(resource)
        elif os.path.isabs(resource):
            # Absolute path that doesn't exist locally - add_external_script accepts absolute paths
            # Validation based on ExecEnv will happen in each env's deploy() method
            self.add_external_script(resource)
        else:
            raise ValueError(f"cannot add resource: invalid resource {resource}: it must be either a directory or file")

    def add_resources(self, resources: List[str]):
        """Add resources to the job. To be used by job component programmer.

        Args:
            resources: paths of files or directories to be added.

        Raises:
            ValueError: if any resource is invalid (see `_add_resource`).
        """
        for r in resources:
            self._add_resource(r)
[docs]
class JobCtx:
    """Contextual info for one `FedJob.to()` call: the object, its target and its requested id."""

    def __init__(self, obj: Any, target: str, comp_id: str):
        self.obj, self.target, self.comp_id = obj, target, comp_id
[docs]
class ClientApp(FedApp):
    def __init__(self):
        """Wrapper around `ClientAppConfig`."""
        client_config = ClientAppConfig()
        super().__init__(client_config)

    def add_executor(self, executor: Executor, tasks=None):
        """Register `executor` for `tasks`; with no tasks given it handles every task ("*")."""
        # Add executor for any task by default
        handled_tasks = tasks or ["*"]
        self.app_config.add_executor(handled_tasks, executor)
[docs]
class ServerApp(FedApp):
    """Wrapper around `ServerAppConfig`."""

    def __init__(self):
        server_config = ServerAppConfig()
        super().__init__(server_config)

    def add_controller(self, controller: Controller, id=None):
        """Add `controller` as a workflow under a unique id derived from `id` ("controller" by default)."""
        workflow_id = self.generate_tracked_id(id or "controller")
        self.app_config.add_workflow(workflow_id, controller)
[docs]
class FedJob:
    """Builds an NVFlare job configuration by assigning objects to server and client targets."""

    def __init__(
        self,
        name: str = "fed_job",
        min_clients: int = 1,
        mandatory_clients: Optional[List[str]] = None,
        meta_props: Optional[Dict[str, Any]] = None,
        fail_fast: bool = False,
    ) -> None:
        """FedJob allows users to generate job configurations in a Pythonic way.

        The `to()` routine allows users to send different components to either the server or clients.

        Args:
            name: the name of the NVFlare job
            min_clients: the minimum number of clients for the job
            mandatory_clients: mandatory clients to run the job (optional)
            meta_props: additional meta properties for the job (optional)
            fail_fast: if True, sets dead_client_grace_period to 0 so that a client already
                reported dead is declared disconnected on the next monitor tick (~0.2 s) rather
                than after the default 60-second grace period. The job then aborts only when the
                normal deployment policy is violated: alive clients drop below min_clients, all
                clients die, or a mandatory client is lost. In the common development scenario
                where min_clients equals the total number of enrolled clients, this means an
                immediate abort on any client failure; when min_clients < total enrolled, the
                disconnect is simply detected faster without necessarily aborting the job.
                When False (the default), the existing dead-client grace period behaviour applies.
        """
        check_str("name", name)
        check_positive_int("min_clients", min_clients)
        if mandatory_clients:
            check_object_type("mandatory_clients", mandatory_clients, list)
        if meta_props:
            check_object_type("meta_props", meta_props, dict)
        check_object_type("fail_fast", fail_fast, bool)

        self.name = name
        self.clients = []
        self._fail_fast = fail_fast
        self.job: FedJobConfig = FedJobConfig(
            job_name=self.name,
            min_clients=min_clients,
            mandatory_clients=mandatory_clients,
            meta_props=meta_props,
        )
        # target name => ServerApp / ClientApp created for that target
        self._deploy_map = {}
        # set once the apps in _deploy_map have been handed to self.job
        self._deployed = False
        # id returned by as_id() => object waiting to be added when that id is referenced
        self._components = {}

    def set_app_packages(self, app_packages: List[str]):
        """Set app packages.

        When generating job config, code from these packages will not be included into "custom" folder.

        Args:
            app_packages: app packages to be set

        Returns: None
        """
        self.job.set_app_packages(app_packages)

    def set_up_client(self, target: str):
        """Setup routine called by FedJob when first sending object to a client target.

        The base implementation does nothing; subclasses may override it.

        Args:
            target: the target to perform setup.
        """
        pass

    def _add_server_app(self, obj: ServerApp, target: str):
        """Record `obj` as the app deployed to the server `target`."""
        self._deploy_map[target] = obj

    def _add_client_app(self, obj: ClientApp, target: str):
        """Record `obj` as the app deployed to client `target` and remember the client name."""
        self._deploy_map[target] = obj
        if target not in self.clients:
            self.clients.append(target)

    def _get_or_create_app(self, target: str, target_type):
        """Return the app registered for `target`, creating and registering one if needed.

        A newly created client app triggers `set_up_client(target)`, so the setup
        hook runs exactly once per client target no matter which public method
        touched the target first.
        """
        app = self._deploy_map.get(target)
        if app:
            return app

        if target_type == JobTargetType.SERVER:
            app = ServerApp()
            self._add_server_app(app, target)
        else:
            app = ClientApp()
            self._add_client_app(app, target)
            self.set_up_client(target)
        return app

    def to(
        self,
        obj: Any,
        target: str,
        id=None,
        **kwargs,
    ) -> Any:
        """Assign an object to the target. For end users.

        A `str` object is treated as an external script (or external directory if the
        path is a directory), and a `dict` object as additional system parameters.

        Args:
            obj: the object to be assigned
            target: the target that the object is assigned to
            id: the id of the object
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        If the obj provides the add_to_fed_job method, it will be called with the kwargs.
        This method must follow this signature:

            add_to_fed_job(job, ctx, ...)

        job: this is the job (self)
        ctx: this is the JobCtx that keeps contextual info of this call.

        The add_to_fed_job function is usually implemented in FL component classes.
        When implementing this function, you should not use anything in the ctx; instead, you should use
        the "add_xxx" methods of the "job" object: add_component, add_resources, add_filter, add_executor, etc.

        Returns:
            result of add_to_job_method if called, or id of added component.
            None when `obj` is a str or a dict.

        Raises:
            ValueError: if `obj` is empty, is a ClientApp/ServerApp, the target name is
                invalid, or the object declares a target type that does not match `target`.
        """
        if not obj:
            raise ValueError("cannot add empty object to job")

        if isinstance(obj, (ClientApp, ServerApp)):
            raise ValueError("adding (ClientApp, ServerApp) is not allowed")

        self._validate_target(target)
        target_type = JobTargetType.get_target_type(target)
        app = self._get_or_create_app(target, target_type)

        if isinstance(obj, str):  # treat the str type object as external script
            if os.path.isdir(obj):
                app.add_external_dir(obj)
            else:
                app.add_external_script(obj)
            return None

        if isinstance(obj, dict):  # treat dict type object as additional system parameters
            app.add_params(obj)
            return None

        get_target_type_method = getattr(obj, "get_job_target_type", None)
        if get_target_type_method is not None:
            expected_target_type = get_target_type_method()
            if expected_target_type != target_type:
                # The message names where the object is allowed to go, which is the
                # opposite side from the (rejected) target it was sent to.
                if target_type == JobTargetType.SERVER:
                    raise ValueError(f"this object can only be assigned to client, but tried to assign to {target}")
                else:
                    raise ValueError(f"this object can only be assigned to server, but tried to assign to {target}")

        add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
        if add_to_job_method is not None:
            ctx = JobCtx(obj, target, id)
            result = add_to_job_method(self, ctx, **kwargs)
        else:
            # basic object
            result = app.add_component(obj, id)

        # add any other components the object might have referenced via id
        if self._components:
            self._add_referenced_components(obj, target)

        return result

    def _add_referenced_components(self, base_component, target):
        """Adds any other components the object might have referenced via id"""
        # Check all arguments for ids referenced with .as_id()
        if not hasattr(base_component, "__dict__"):
            return

        parameters = get_component_init_parameters(base_component)
        attrs = base_component.__dict__
        for param in parameters:
            # an init arg may be stored under the same name or a "_"-prefixed one
            attr_key = param if param in attrs else "_" + param
            if attr_key not in attrs:
                continue

            base_id = attrs[attr_key]
            if not isinstance(base_id, str):  # only a str could be an id
                continue
            if base_id not in self._components:
                continue

            referenced = self._components[base_id]
            self._deploy_map[target].add_component(referenced, base_id)
            # add any components referenced by this component
            self._add_referenced_components(referenced, target)
            # remove already added components from tracked components
            self._components.pop(base_id)

    def _get_app(self, ctx: JobCtx):
        """Return the app registered for `ctx.target`.

        Raises:
            RuntimeError: if no app has been registered for that target.
        """
        app = self._deploy_map.get(ctx.target)
        if not app:
            target_type = JobTargetType.get_target_type(ctx.target)
            if target_type == JobTargetType.CLIENT:
                app_type = "a ClientApp"
            else:
                app_type = "a ServerApp"
            raise RuntimeError(f"No app found for target '{ctx.target}' - missing {app_type}")
        return app

    def add_component(self, comp_id: str, obj: Any, ctx: JobCtx):
        """Add a component to the job. To be used by job component programmer.

        Args:
            comp_id: component id; when empty, `ctx.comp_id` is used instead.
            obj: component to be added to job.
            ctx: JobCtx for contextual information.

        Returns:
            final id assigned to component.
        """
        app = self._get_app(ctx)
        if not comp_id:
            comp_id = ctx.comp_id
        final_id = app.add_component(obj, comp_id)

        # add any other components the object might have referenced via id
        if self._components:
            self._add_referenced_components(obj, ctx.target)
        return final_id

    def add_controller(self, obj: Controller, ctx: JobCtx):
        """Add a Controller object to the job. To be used by controller programmer.

        Args:
            obj: Controller to be added to job.
            ctx: JobCtx for contextual information.
        """
        target_type = JobTargetType.get_target_type(ctx.target)
        app = self._get_app(ctx)
        if target_type != JobTargetType.SERVER:  # add client-side controllers as components
            app.add_component(obj, ctx.comp_id)
        else:
            app.add_controller(obj, ctx.comp_id)

    def add_executor(self, obj: Executor, tasks: List[str], ctx: JobCtx):
        """Add an executor to the job. To be used by executor programmer.

        Args:
            obj: Executor to be added to job.
            tasks: List of tasks that should be handled. If `None`, all tasks will be handled using `[*]`.
            ctx: JobCtx for contextual information.
        """
        app = self._get_app(ctx)
        app.add_executor(obj, tasks=tasks)

    def add_filter(self, obj: Filter, filter_type: str, tasks, ctx: JobCtx):
        """Add a filter to the job. To be used by filter programmer.

        Args:
            obj: Filter to be added to job.
            filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`.
            tasks: List of tasks that Filter applies to.
            ctx: JobCtx for contextual information.

        Raises:
            ValueError: if `filter_type` is neither of the two supported types.
        """
        app = self._get_app(ctx)
        if filter_type == FilterType.TASK_RESULT:
            app.add_task_result_filter(tasks, obj)
        elif filter_type == FilterType.TASK_DATA:
            app.add_task_data_filter(tasks, obj)
        else:
            raise ValueError(
                f"Provided a filter for {ctx.target} without specifying a valid `filter_type`. "
                f"Select from `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`."
            )

    def add_resources(self, resources: List[str], ctx: JobCtx):
        """Add resources to the job. To be used by job component programmer.

        Args:
            resources: List of filenames or directories to be added to job.
            ctx: JobCtx for contextual information.
        """
        app = self._get_app(ctx)
        app.add_resources(resources)

    def add_file_source(self, src_path: str, dest_dir, app_folder_type, ctx: JobCtx):
        """Add a file source to the job. To be used by job component programmer.

        Args:
            src_path: path to the source to be added to job.
            dest_dir: destination path for the source
            app_folder_type: type of app folder to place the files
            ctx: JobCtx for contextual information.
        """
        app = self._get_app(ctx)
        app.add_file_source(src_path, dest_dir, app_folder_type)

    def add_params(self, args: Dict[str, Any], ctx: JobCtx):
        """Add additional system configuration parameters to the job. To be used by job component programmer.

        Args:
            args: Dictionary of configuration parameters (e.g., {"timeout": 600, "max_retries": 3})
            ctx: JobCtx for contextual information.
        """
        app = self._get_app(ctx)
        app.add_params(args)

    def to_server(
        self,
        obj: Any,
        id=None,
        **kwargs,
    ):
        """assign an object to the server. For end users.

        Args:
            obj: The object to be assigned. The obj will be given a default `id` if none is provided based on its type.
            id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        Returns:
            result of add_to_job_method if called, or id of added component

        Raises:
            ValueError: if `obj` is an Executor.
        """
        if isinstance(obj, Executor):
            raise ValueError("Use `job.to(executor, <client_name>)` or `job.to_clients(executor)` for Executors.")

        return self.to(obj=obj, target=SERVER_SITE_NAME, id=id, **kwargs)

    def to_clients(
        self,
        obj: Any,
        id=None,
        **kwargs,
    ):
        """assign an object to all clients. For end users.

        Args:
            obj (Any): Object to be deployed.
            id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        Returns:
            result of add_to_job_method if called, or id of added component

        Raises:
            ValueError: if `obj` is a Controller.
        """
        if isinstance(obj, Controller):
            raise ValueError('Use `job.to(controller, "server")` or `job.to_server(controller)` for Controllers.')

        return self.to(obj=obj, target=ALL_SITES, id=id, **kwargs)

    def add_file_to(self, src_path: str, target: str, dest_dir=None, app_folder_type=None):
        """Add a file to a specific target's app directory.

        If the target has no app yet, one is created the same way `to()` does it,
        including the `set_up_client()` call for a new client target.

        Args:
            src_path: Local path to the file to be bundled into the job.
            target: Target site name (e.g., "server", "site-1", or ALL_SITES for all clients).
            dest_dir: Optional subdirectory within the target folder to place the file.
            app_folder_type: Type of app folder to place the file. Valid values: "custom", "config".
                If not specified, defaults to "custom".

        Raises:
            ValueError: if the target name is invalid.
        """
        self._validate_target(target)
        target_type = JobTargetType.get_target_type(target)
        app = self._get_or_create_app(target, target_type)
        app.add_file_source(src_path, dest_dir, app_folder_type)

    def add_file_to_server(self, src_path: str, dest_dir=None, app_folder_type=None):
        """Add a file to the server app directory.

        Args:
            src_path: Local path to the file to be bundled into the job.
            dest_dir: Optional subdirectory within the target folder to place the file.
            app_folder_type: Type of app folder to place the file. Valid values: "custom", "config".
                If not specified, defaults to "custom".
        """
        self.add_file_to(src_path, SERVER_SITE_NAME, dest_dir, app_folder_type)

    def add_file_to_clients(self, src_path: str, dest_dir=None, app_folder_type=None):
        """Add a file to all client apps' directory.

        Args:
            src_path: Local path to the file to be bundled into the job.
            dest_dir: Optional subdirectory within the target folder to place the file.
            app_folder_type: Type of app folder to place the file. Valid values: "custom", "config".
                If not specified, defaults to "custom".
        """
        self.add_file_to(src_path, ALL_SITES, dest_dir, app_folder_type)

    def _validate_target(self, target):
        """Raise ValueError if `target` is empty or contains a character from SPECIAL_CHARACTERS.

        The ALL_SITES target is exempt from the character check.
        """
        if not target:
            raise ValueError("Must provide a valid target name")

        if any(c in SPECIAL_CHARACTERS for c in target) and target != ALL_SITES:
            raise ValueError(f"target {target} name contains invalid character")

    def _apply_fail_fast(self, server_config: ServerAppConfig):
        """Inject fail_fast configuration into the server app config.

        When fail_fast is enabled, sets dead_client_grace_period to 0 so a client
        already reported dead is declared disconnected on the next monitor tick
        instead of after the default 60-second grace period. The job still aborts
        only when the normal deployment policy is violated (alive < min_clients,
        all dead, or a required client lost) - this only changes how quickly that
        check trips.

        Args:
            server_config: the ServerAppConfig to update.
        """
        if self._fail_fast:
            server_config.add_params({ConfigVarName.DEAD_CLIENT_GRACE_PERIOD: 0})

    def _set_all_app(self, client_app: ClientApp, server_app: ServerApp):
        """Register one combined app (server + client config) and assign it to ALL_SITES.

        Raises:
            ValueError: if either argument is not of the expected app type.
        """
        if not isinstance(client_app, ClientApp):
            raise ValueError(f"`client_app` needs to be of type `ClientApp` but was type {type(client_app)}")
        if not isinstance(server_app, ServerApp):
            raise ValueError(f"`server_app` needs to be of type `ServerApp` but was type {type(server_app)}")

        client_config = client_app.get_app_config()
        server_config = server_app.get_app_config()
        self._apply_fail_fast(server_config)

        app_config = FedAppConfig(server_app=server_config, client_app=client_config)
        app_name = "app"
        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(ALL_SITES, app_name)

    def _set_site_app(self, app: FedApp, target: str):
        """Register `app` as a site-specific app and assign it to `target`.

        Client apps are named "app_<target>"; a server app is named "app_server".

        Raises:
            ValueError: if `app` is not a FedApp or wraps an unsupported config type.
        """
        if not isinstance(app, FedApp):
            raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

        client_server_config = app.get_app_config()
        if isinstance(client_server_config, ClientAppConfig):
            app_config = FedAppConfig(server_app=None, client_app=client_server_config)
            app_name = f"app_{target}"
        elif isinstance(client_server_config, ServerAppConfig):
            self._apply_fail_fast(client_server_config)  # intentionally server-only; clients don't read this key
            app_config = FedAppConfig(server_app=client_server_config, client_app=None)
            app_name = "app_server"
        else:
            raise ValueError(
                f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
            )

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(target, app_name)

    def _set_all_apps(self):
        """Hand the apps collected in the deploy map to the job config (only once per job).

        If anything was assigned to ALL_SITES, a single combined app built from the
        ALL_SITES and server apps is used; otherwise each target gets its own app.

        Raises:
            ValueError: if ALL_SITES was used but nothing was assigned to the server.
        """
        if self._deployed:
            return

        if ALL_SITES in self._deploy_map:
            if SERVER_SITE_NAME not in self._deploy_map:
                raise ValueError('Missing server components! Deploy using `to(obj, "server") or `to_server(obj)`')
            self._set_all_app(client_app=self._deploy_map[ALL_SITES], server_app=self._deploy_map[SERVER_SITE_NAME])
        else:
            for target in self._deploy_map:
                self._set_site_app(self._deploy_map[target], target)

        self._deployed = True

    def export_job(self, job_root: str):
        """Export job config to `job_root` directory with name `self.name`.

        For end users.

        Args:
            job_root: directory to export job configuration.
        """
        self._set_all_apps()
        self.job.generate_job_config(job_root)

    def simulator_run(
        self,
        workspace: str,
        n_clients: Optional[int] = None,
        clients: Optional[List[str]] = None,
        threads: Optional[int] = None,
        gpu: Optional[str] = None,
        log_config: Optional[str] = None,
    ):
        """Run the job with the simulator with the `workspace` using `clients` and `threads`.

        For end users.

        Args:
            workspace: workspace directory for job.
            n_clients: number of clients.
            clients: client names.
            threads: number of threads. Defaults to the number of clients.
            gpu: gpu assignments for simulating clients, comma separated
            log_config: log config mode ('concise', 'msg_only', 'full', 'verbose'), filepath, or level

        Returns:
            The result of the underlying job config's `simulator_run` call.

        Raises:
            ValueError: if objects were sent to all clients but `n_clients` is missing, or if
                specific clients are known and `n_clients` is given as well.
        """
        if clients:
            self.clients = clients

        self._set_all_apps()

        uses_all_sites = ALL_SITES in self.clients
        if uses_all_sites and not n_clients:
            raise ValueError("Clients were not specified using to(). Please provide the number of clients to simulate.")
        elif uses_all_sites and n_clients:
            check_positive_int("n_clients", n_clients)
            # expand the ALL_SITES placeholder into concrete simulated site names
            self.clients = [f"site-{i}" for i in range(1, n_clients + 1)]
        elif self.clients and n_clients:
            raise ValueError("You already specified clients using `to()`. Don't use `n_clients` in simulator_run.")

        n_clients = len(self.clients)

        if threads is None:
            threads = n_clients

        return self.job.simulator_run(
            workspace,
            clients=",".join(self.clients),
            n_clients=n_clients,
            threads=threads,
            gpu=gpu,
            log_config=log_config,
        )

    def as_id(self, obj: Any) -> str:
        """Generate and return uuid for `obj`. For end users.

        If this id is referenced by another added object, this `obj` will also be added as a component.
        """
        cid = str(uuid.uuid4())
        self._components[cid] = obj
        return cid

    @staticmethod
    def check_kwargs(args_to_check: dict, args_expected: dict):
        """Check kwargs for arguments. Raise Error if required arg is missing, or unexpected arg is given.

        Args:
            args_to_check (dict): kwargs dictionary to check.
            args_expected (dict): dictionary of argument name to boolean of whether argument is required (True) or optional (False).

        Raises:
            ValueError: if args are given but none are expected, a required arg is
                missing, or an unexpected arg is given.
        """
        if not args_expected and not args_to_check:
            return

        if args_to_check and not args_expected:
            raise ValueError(f"received args {list(args_to_check.keys())}, but no args expected")

        args_info = {k: "required" if required else "optional" for k, required in args_expected.items()}

        # see whether required args are present
        for k, required in args_expected.items():
            if required and (not args_to_check or k not in args_to_check):
                raise ValueError(f"Missing required arg '{k}'. " f"Supported args: {args_info}")

        # see whether we got unexpected args
        if args_to_check:
            for k in args_to_check:
                if k not in args_expected:
                    raise ValueError(f"Received unexpected arg '{k}'. " f"Supported args: {args_info}")
[docs]
def has_add_to_job_method(obj: Any) -> bool:
    """Tell whether `obj` exposes a callable attribute named by `_ADD_TO_JOB_METHOD_NAME`."""
    # callable(None) is False, so a missing attribute and a non-callable one both yield False
    return callable(getattr(obj, _ADD_TO_JOB_METHOD_NAME, None))
[docs]
def validate_object_for_job(name, obj, obj_type):
    """Check whether the specified object is valid for job.

    An object that has the add_to_fed_job method is accepted as is; any other
    object is type-checked against `obj_type` with `check_object_type`.

    Args:
        name: name of the object
        obj: the object to be checked
        obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method.

    Returns: None
    """
    if not has_add_to_job_method(obj):
        check_object_type(name, obj, obj_type)