# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import re
import uuid
from typing import Any, List, Optional, Union
from nvflare.apis.executor import Executor
from nvflare.apis.filter import Filter
from nvflare.apis.impl.controller import Controller
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.fuel.utils.class_utils import get_component_init_parameters
from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig
from .defs import FilterType, JobTargetType
# Characters that FedJob._validate_target rejects in a target name (ALL_SITES is exempted there).
SPECIAL_CHARACTERS = '"!@#$%^&*()+?=,<>/'
# Name of the optional hook method that FedJob.to() looks up on the object being assigned.
_ADD_TO_JOB_METHOD_NAME = "add_to_fed_job"
class FedApp:
    """Thin wrapper around an app config that also hands out unique component ids."""

    def __init__(self, app_config: Union[ClientAppConfig, ServerAppConfig]):
        """FedApp handles `ClientAppConfig` and `ServerAppConfig` and allows setting task result or task data filters.

        Args:
            app_config: the app config that all ``add_*`` calls are forwarded to.
        """
        self.app_config = app_config
        # ids already handed out by generate_tracked_id()
        self._used_ids = []
        # obj_id => comp_id
        # obj_id is the Python's object ID; comp_id is the component ID for job config
        # _oid_to_cid keeps the mapping between obj_id and comp_id.
        # this is to make sure that when the same object is used, it is configured only once in the job.
        self._oid_to_cid = {}

    def get_app_config(self):
        """Return the wrapped app config object."""
        return self.app_config

    def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
        """Forward a task result filter for `tasks` to the wrapped app config."""
        self.app_config.add_task_result_filter(tasks, task_filter)

    def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
        """Forward a task data filter for `tasks` to the wrapped app config."""
        self.app_config.add_task_data_filter(tasks, task_filter)

    def add_component(self, component, comp_id=None):
        """Add a component to the app config, at most once per Python object.

        Args:
            component: the object to be configured.
            comp_id: requested component id; "component" is used when `None`.

        Returns:
            the id the component is configured under. If the same object was added before,
            the id assigned at that time is returned and nothing is added.
        """
        # is the component already configured?
        oid = id(component)
        cid = self._oid_to_cid.get(oid)
        if cid:
            # the component is already configured
            return cid

        if comp_id is None:
            comp_id = "component"

        final_id = self.generate_tracked_id(comp_id)
        self.app_config.add_component(final_id, component)
        self._oid_to_cid[oid] = final_id
        return final_id

    def _generate_id(self, id: str = "") -> str:
        """Return `id`, modified if necessary so that it is not among the ids already used.

        While the candidate is taken, the first run of digits in it is incremented by one;
        a candidate without digits gets "1" appended. The result is NOT recorded as used.
        """
        while id in self._used_ids:
            number = re.search(r"\d+", id)
            if number:
                # Splice by position so only the matched number changes; str.replace()
                # would also rewrite every other occurrence of the same digits.
                bumped = str(int(number.group()) + 1)
                id = id[: number.start()] + bumped + id[number.end() :]
            else:
                id = id + "1"
        return id

    def generate_tracked_id(self, id: str = "") -> str:
        """Return a unique id derived from `id` and record it as used."""
        id = self._generate_id(id)
        self._used_ids.append(id)
        return id

    def add_external_script(self, ext_script: str):
        """Register external script to include them in custom directory.

        Args:
            ext_script: path of the external script that needs to be deployed to the client/server.
        """
        self.app_config.add_ext_script(ext_script)

    def add_external_dir(self, ext_dir: str):
        """Register external folder to include them in custom directory.

        Args:
            ext_dir: external folder that need to be deployed to the client/server.
        """
        self.app_config.add_ext_dir(ext_dir)

    def add_file_source(self, src_path: str, dest_dir=None):
        """Forward a file source (`src_path`, optional `dest_dir`) to the wrapped app config."""
        self.app_config.add_file_source(src_path, dest_dir)

    def _add_resource(self, resource: str):
        """Register one resource as an external dir or an external script.

        Raises:
            ValueError: if `resource` is not a str, or is neither an existing directory nor an existing file.
        """
        if not isinstance(resource, str):
            raise ValueError(f"cannot add resource: resource must be a str but got {type(resource)}")
        elif os.path.isdir(resource):
            self.add_external_dir(resource)
        elif os.path.isfile(resource):
            self.add_external_script(resource)
        else:
            raise ValueError(f"cannot add resource: invalid resource {resource}: it must be either a directory or file")

    def add_resources(self, resources: List[str]):
        """Add resources to the job. To be used by job component programmer.

        Args:
            resources: paths of existing files or directories; each one is registered in order.

        Returns:

        """
        for r in resources:
            self._add_resource(r)
class JobCtx:
    """Contextual info of one `FedJob.to()` call, handed to the object's `add_to_fed_job` method."""

    def __init__(self, obj: Any, target: str, comp_id: str):
        # obj: the object being assigned
        # target: the target the object is being assigned to
        # comp_id: the id passed to the call
        self.obj, self.target, self.comp_id = obj, target, comp_id
class ClientApp(FedApp):
    def __init__(self):
        """Wrapper around `ClientAppConfig`."""
        super().__init__(ClientAppConfig())

    def add_executor(self, executor: Executor, tasks=None):
        """Register `executor` for `tasks`; when no tasks are given it handles any task ("*")."""
        task_names = tasks if tasks else ["*"]  # Add executor for any task by default
        self.app_config.add_executor(task_names, executor)
class ServerApp(FedApp):
    """Wrapper around `ServerAppConfig`."""

    def __init__(self):
        super().__init__(ServerAppConfig())

    def add_controller(self, controller: Controller, id=None):
        """Register `controller` as a workflow under a unique id derived from `id` ("controller" if not given)."""
        workflow_id = self.generate_tracked_id(id if id else "controller")
        self.app_config.add_workflow(workflow_id, controller)
class FedJob:
    """Builder that collects objects per target and turns them into a `FedJobConfig`."""

    def __init__(
        self,
        name: str = "fed_job",
        min_clients: int = 1,
        mandatory_clients: Optional[List[str]] = None,
    ) -> None:
        """FedJob allows users to generate job configurations in a Pythonic way.
        The `to()` routine allows users to send different components to either the server or clients.

        Args:
            name: the name of the NVFlare job
            min_clients: the minimum number of clients for the job
            mandatory_clients: mandatory clients to run the job (optional)
        """
        self.name = name
        # client targets in the order they were first used in to()
        self.clients = []
        self.job: FedJobConfig = FedJobConfig(
            job_name=self.name, min_clients=min_clients, mandatory_clients=mandatory_clients
        )
        # target name => ClientApp / ServerApp
        self._deploy_map = {}
        # set once the apps in _deploy_map have been pushed into self.job
        self._deployed = False
        # id handed out by as_id() => object, until the object is added to a target
        self._components = {}

    def set_up_client(self, target: str):
        """Setup routine called by FedJob when first sending object to a client target.

        Args:
            target: the target to perform setup.

        Returns:

        """
        pass

    def _add_server_app(self, obj: ServerApp, target: str):
        """Record `obj` as the app of `target`."""
        self._deploy_map[target] = obj

    def _add_client_app(self, obj: ClientApp, target: str):
        """Record `obj` as the app of `target` and remember `target` as a client."""
        self._deploy_map[target] = obj
        if target not in self.clients:
            self.clients.append(target)

    def to(
        self,
        obj: Any,
        target: str,
        id=None,
        **kwargs,
    ) -> Any:
        """Assign an object to the target. For end users.

        Args:
            obj: the object to be assigned
            target: the target that the object is assigned to
            id: the id of the object
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        If the obj provides the add_to_fed_job method, it will be called with the kwargs.
        This method must follow this signature:

            add_to_fed_job(job, ctx, ...)

        job: this is the job (self)
        ctx: this is the JobCtx that keeps contextual info of this call.

        The add_to_fed_job function is usually implemented in FL component classes.
        When implementing this function, you should not use anything in the ctx; instead, you should use
        the "add_xxx" methods of the "job" object: add_component, add_resources, add_filter, add_executor, etc.

        A `str` object is treated as the path of an external directory or script and is
        registered with the target's app; `None` is returned in that case.

        Returns:
            result of add_to_job_method if called, or id of added component

        Raises:
            ValueError: if `obj` is empty, is a ClientApp/ServerApp, the target name is invalid,
                or the object declares (via `get_job_target_type`) a target type other than that of `target`.
        """
        if not obj:
            raise ValueError("cannot add empty object to job")

        if isinstance(obj, (ClientApp, ServerApp)):
            raise ValueError("adding (ClientApp, ServerApp) is not allowed")

        self._validate_target(target)
        target_type = JobTargetType.get_target_type(target)

        # the app of a target is created lazily, on the first object sent to it
        app = self._deploy_map.get(target)
        if not app:
            if target_type == JobTargetType.SERVER:
                app = ServerApp()
                self._add_server_app(app, target)
            else:
                app = ClientApp()
                self._add_client_app(app, target)
                self.set_up_client(target)

        if isinstance(obj, str):  # treat the str type object as external script
            if os.path.isdir(obj):
                app.add_external_dir(obj)
            else:
                app.add_external_script(obj)
            return None

        get_target_type_method = getattr(obj, "get_job_target_type", None)
        if get_target_type_method is not None:
            expected_target_type = get_target_type_method()
            if expected_target_type != target_type:
                # Report the side the object is restricted to (its expected type),
                # not the side the caller attempted.
                if expected_target_type == JobTargetType.SERVER:
                    raise ValueError(f"this object can only be assigned to server, but tried to assign to {target}")
                else:
                    raise ValueError(f"this object can only be assigned to client, but tried to assign to {target}")

        add_to_job_method = getattr(obj, _ADD_TO_JOB_METHOD_NAME, None)
        if add_to_job_method is not None:
            ctx = JobCtx(obj, target, id)
            result = add_to_job_method(self, ctx, **kwargs)
        else:
            # basic object
            result = app.add_component(obj, id)

        # add any other components the object might have referenced via id
        if self._components:
            self._add_referenced_components(obj, target)

        return result

    def _add_referenced_components(self, base_component, target):
        """Adds any other components the object might have referenced via id"""
        if not hasattr(base_component, "__dict__"):
            return

        # Check all arguments for ids referenced with .as_id()
        parameters = get_component_init_parameters(base_component)
        attrs = base_component.__dict__
        for param in parameters:
            # an init arg may be stored under its own name or with a leading underscore
            attr_key = param if param in attrs else "_" + param
            if attr_key not in attrs:
                continue

            base_id = attrs[attr_key]
            if not isinstance(base_id, str) or base_id not in self._components:
                continue  # not an id handed out by as_id(), or already added

            # Stop tracking the component BEFORE recursing, so that components which
            # reference each other (directly or through a chain) cannot recurse forever.
            referenced = self._components.pop(base_id)
            self._deploy_map[target].add_component(referenced, base_id)
            # add any components referenced by this component
            self._add_referenced_components(referenced, target)

    def _get_app(self, ctx: JobCtx):
        """Return the app registered for `ctx.target`.

        Raises:
            RuntimeError: if no app has been created for the target.
        """
        app = self._deploy_map.get(ctx.target)
        if app:
            return app

        target_type = JobTargetType.get_target_type(ctx.target)
        app_type = "a ClientApp" if target_type == JobTargetType.CLIENT else "a ServerApp"
        raise RuntimeError(f"No app found for target '{ctx.target}' - missing {app_type}")

    def add_component(self, comp_id: str, obj: Any, ctx: JobCtx):
        """Add a component to the job. To be used by job component programmer.

        Args:
            comp_id: component id; when empty, the id kept in `ctx` is used.
            obj: component to be added to job.
            ctx: JobCtx for contextual information.

        Returns:
            final id assigned to component.
        """
        app = self._get_app(ctx)
        if not comp_id:
            comp_id = ctx.comp_id

        final_id = app.add_component(obj, comp_id)

        # also add components that obj references through ids created by as_id()
        if self._components:
            self._add_referenced_components(obj, ctx.target)
        return final_id

    def add_controller(self, obj: Controller, ctx: JobCtx):
        """Add a Controller object to the job. To be used by controller programmer.

        Args:
            obj: Controller to be added to job.
            ctx: JobCtx for contextual information.

        Returns:

        """
        target_type = JobTargetType.get_target_type(ctx.target)
        app = self._get_app(ctx)
        if target_type == JobTargetType.SERVER:
            app.add_controller(obj, ctx.comp_id)
        else:
            # add client-side controllers as components
            app.add_component(obj, ctx.comp_id)

    def add_executor(self, obj: Executor, tasks: List[str], ctx: JobCtx):
        """Add an executor to the job. To be used by executor programmer.

        Args:
            obj: Executor to be added to job.
            tasks: List of tasks that should be handled. If `None`, all tasks will be handled using `[*]`.
            ctx: JobCtx for contextual information.

        Returns:

        """
        app = self._get_app(ctx)
        app.add_executor(obj, tasks=tasks)

    def add_filter(self, obj: Filter, filter_type: str, tasks, ctx: JobCtx):
        """Add a filter to the job. To be used by filter programmer.

        Args:
            obj: Filter to be added to job.
            filter_type: The type of filter used. Either `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`.
            tasks: List of tasks that Filter applies to.
            ctx: JobCtx for contextual information.

        Returns:

        Raises:
            ValueError: if `filter_type` is neither of the two supported filter types.
        """
        app = self._get_app(ctx)
        if filter_type == FilterType.TASK_RESULT:
            app.add_task_result_filter(tasks, obj)
        elif filter_type == FilterType.TASK_DATA:
            app.add_task_data_filter(tasks, obj)
        else:
            raise ValueError(
                f"Provided a filter for {ctx.target} without specifying a valid `filter_type`. "
                f"Select from `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`."
            )

    def add_resources(self, resources: List[str], ctx: JobCtx):
        """Add resources to the job. To be used by job component programmer.

        Args:
            resources: List of filenames or directories to be added to job.
            ctx: JobCtx for contextual information.

        Returns:

        """
        app = self._get_app(ctx)
        app.add_resources(resources)

    def add_file_source(self, src_path: str, dest_dir, ctx: JobCtx):
        """Add a file source to the job. To be used by job component programmer.

        Args:
            src_path: path to the source to be added to job.
            dest_dir: destination path for the source
            ctx: JobCtx for contextual information.

        Returns:

        """
        app = self._get_app(ctx)
        app.add_file_source(src_path, dest_dir)

    def to_server(
        self,
        obj: Any,
        id=None,
        **kwargs,
    ):
        """assign an object to the server. For end users.

        Args:
            obj: The object to be assigned. The obj will be given a default `id` if none is provided based on its type.
            id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        Returns:
            result of add_to_job_method if called, or id of added component

        Raises:
            ValueError: if `obj` is an Executor.
        """
        if isinstance(obj, Executor):
            raise ValueError("Use `job.to(executor, <client_name>)` or `job.to_clients(executor)` for Executors.")

        return self.to(obj=obj, target=SERVER_SITE_NAME, id=id, **kwargs)

    def to_clients(
        self,
        obj: Any,
        id=None,
        **kwargs,
    ):
        """assign an object to all clients. For end users.

        Args:
            obj (Any): Object to be deployed.
            id: Optional user-defined id for the object. Defaults to `None` and ID will automatically be assigned.
            **kwargs: additional args to be passed to the object's add_to_fed_job method.

        Returns:
            result of add_to_job_method if called, or id of added component

        Raises:
            ValueError: if `obj` is a Controller.
        """
        if isinstance(obj, Controller):
            raise ValueError('Use `job.to(controller, "server")` or `job.to_server(controller)` for Controllers.')

        return self.to(obj=obj, target=ALL_SITES, id=id, **kwargs)

    def _validate_target(self, target):
        """Reject an empty target name or one containing special characters (ALL_SITES is exempt).

        Raises:
            ValueError: if the target name is not acceptable.
        """
        if not target:
            raise ValueError("Must provide a valid target name")

        if any(c in SPECIAL_CHARACTERS for c in target) and target != ALL_SITES:
            raise ValueError(f"target {target} name contains invalid character")

    def _set_all_app(self, client_app: ClientApp, server_app: ServerApp):
        """Register one combined app named "app" (client + server config) for ALL_SITES.

        Raises:
            ValueError: if either argument is of the wrong type.
        """
        if not isinstance(client_app, ClientApp):
            raise ValueError(f"`client_app` needs to be of type `ClientApp` but was type {type(client_app)}")
        if not isinstance(server_app, ServerApp):
            raise ValueError(f"`server_app` needs to be of type `ServerApp` but was type {type(server_app)}")

        client_config = client_app.get_app_config()
        server_config = server_app.get_app_config()

        app_config = FedAppConfig(server_app=server_config, client_app=client_config)
        app_name = "app"

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(ALL_SITES, app_name)

    def _set_site_app(self, app: FedApp, target: str):
        """Register the app of a single target: "app_<target>" for a client app, "app_server" for a server app.

        Raises:
            ValueError: if `app` is not a FedApp or wraps an unsupported config type.
        """
        if not isinstance(app, FedApp):
            raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

        client_server_config = app.get_app_config()
        if isinstance(client_server_config, ClientAppConfig):
            app_config = FedAppConfig(server_app=None, client_app=client_server_config)
            app_name = f"app_{target}"
        elif isinstance(client_server_config, ServerAppConfig):
            app_config = FedAppConfig(server_app=client_server_config, client_app=None)
            app_name = "app_server"
        else:
            raise ValueError(
                f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
            )

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(target, app_name)

    def _set_all_apps(self):
        """Push the collected apps into the job config; does nothing after the first successful call.

        Raises:
            ValueError: if objects were sent to ALL_SITES but nothing was sent to the server.
        """
        if self._deployed:
            return

        if ALL_SITES in self._deploy_map:
            # a single app is shared by the server and all clients
            if SERVER_SITE_NAME not in self._deploy_map:
                raise ValueError('Missing server components! Deploy using `to(obj, "server") or `to_server(obj)`')
            self._set_all_app(client_app=self._deploy_map[ALL_SITES], server_app=self._deploy_map[SERVER_SITE_NAME])
        else:
            for target in self._deploy_map:
                self._set_site_app(self._deploy_map[target], target)

        self._deployed = True

    def export_job(self, job_root: str):
        """Export job config to `job_root` directory with name `self.name`.
        For end users.

        Args:
            job_root: directory to export job configuration.

        Returns:

        """
        self._set_all_apps()
        self.job.generate_job_config(job_root)

    def simulator_run(self, workspace: str, n_clients: int = None, threads: int = None, gpu: str = None):
        """Run the job with the simulator with the `workspace` using `clients` and `threads`.
        For end users.

        Args:
            workspace: workspace directory for job.
            n_clients: number of clients. Required when objects were sent to all clients, and
                not allowed when clients were named explicitly with `to()`.
            threads: number of threads. Defaults to the number of clients.
            gpu: gpu assignments for simulating clients, comma separated

        Returns:

        Raises:
            ValueError: if `n_clients` is missing or given when it must not be (see above).
        """
        self._set_all_apps()

        if ALL_SITES in self.clients:
            if not n_clients:
                raise ValueError(
                    "Clients were not specified using to(). Please provide the number of clients to simulate."
                )
            check_positive_int("n_clients", n_clients)
            # replace the ALL_SITES placeholder with concrete simulated site names
            self.clients = [f"site-{i}" for i in range(1, n_clients + 1)]
        elif self.clients and n_clients:
            raise ValueError("You already specified clients using `to()`. Don't use `n_clients` in simulator_run.")

        n_clients = len(self.clients)

        if threads is None:
            threads = n_clients

        self.job.simulator_run(
            workspace,
            clients=",".join(self.clients),
            n_clients=n_clients,
            threads=threads,
            gpu=gpu,
        )

    def as_id(self, obj: Any) -> str:
        """Generate and return uuid for `obj`. For end users.
        If this id is referenced by another added object, this `obj` will also be added as a component.
        """
        cid = str(uuid.uuid4())
        self._components[cid] = obj
        return cid

    @staticmethod
    def check_kwargs(args_to_check: dict, args_expected: dict):
        """Check kwargs for arguments. Raise Error if required arg is missing, or unexpected arg is given.

        Args:
            args_to_check (dict): kwargs dictionary to check.
            args_expected (dict): dictionary of argument name to boolean of whether argument is required (True) or optional (False).

        Raises:
            ValueError: if a required arg is missing or an unexpected arg is present.
        """
        if not args_expected:
            if not args_to_check:
                return
            raise ValueError(f"received args {list(args_to_check.keys())}, but no args expected")

        # only used to describe the supported args in error messages
        args_info = {k: "required" if required else "optional" for k, required in args_expected.items()}

        # see whether required args are present
        for k, required in args_expected.items():
            if required and (not args_to_check or k not in args_to_check):
                raise ValueError(f"Missing required arg '{k}'. " f"Supported args: {args_info}")

        # see whether we got unexpected args
        if args_to_check:
            for k in args_to_check:
                if k not in args_expected:
                    raise ValueError(f"Received unexpected arg '{k}'. " f"Supported args: {args_info}")
def has_add_to_job_method(obj: Any) -> bool:
    """Tell whether `obj` has a callable attribute named by `_ADD_TO_JOB_METHOD_NAME`."""
    # a missing attribute yields None, which is not callable
    return callable(getattr(obj, _ADD_TO_JOB_METHOD_NAME, None))
def validate_object_for_job(name, obj, obj_type):
    """Check whether the specified object is valid for job.

    The object must either have the add_to_fed_job method or is valid object type.

    Args:
        name: name of the object
        obj: the object to be checked
        obj_type: the object type that the object should be, if it doesn't have the add_to_fed_job method.

    Returns: None

    """
    # the type check only applies to objects that do not provide the add_to_fed_job hook
    if not has_add_to_job_method(obj):
        check_object_type(name, obj, obj_type)