Source code for nvflare.job_config.fed_job_config

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import inspect
import json
import os
import shlex
import shutil
import subprocess
import sys
from enum import Enum
from tempfile import TemporaryDirectory
from typing import Dict

from nvflare.fuel.utils.class_utils import get_component_init_parameters
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.job_config.base_app_config import BaseAppConfig
from nvflare.job_config.fed_app_config import FedAppConfig
from nvflare.private.fed.app.fl_conf import FL_PACKAGES
from nvflare.private.fed.app.utils import kill_child_processes

CONFIG = "config"
CUSTOM = "custom"
FED_SERVER_JSON = "config_fed_server.json"
FED_CLIENT_JSON = "config_fed_client.json"
META_JSON = "meta.json"


[docs] class FedJobConfig: """FedJobConfig represents the job in the NVFlare.""" def __init__(self, job_name, min_clients, mandatory_clients=None) -> None: """FedJobConfig uses the job_name, min_clients and optional mandatory_clients to create the object. It also provides the method to add in the FedApp, the deployment map of the FedApp and participants, and the resource _spec requirements of the participants if needed. Args: job_name: the name of the NVFlare job min_clients: the minimum number of clients for the job mandatory_clients: mandatory clients to run the job (optional) """ super().__init__() self.job_name = job_name self.min_clients = min_clients self.mandatory_clients = mandatory_clients self.fed_apps: Dict[str, FedAppConfig] = {} self.deploy_map: Dict[str, str] = {} self.resource_specs: Dict[str, Dict] = {} self.custom_modules = [] self.logger = get_obj_logger(self)
[docs] def add_fed_app(self, app_name: str, fed_app: FedAppConfig): if not isinstance(fed_app, FedAppConfig): raise RuntimeError(f"server_app must be type of FedAppConfig, but got {fed_app.__class__}") self.fed_apps[app_name] = fed_app
[docs] def set_site_app(self, site_name: str, app_name: str): """assign an app to a certain site. Args: site_name: The target site name. app_name: The app name. Returns: """ if app_name not in self.fed_apps.keys(): raise RuntimeError(f"fed_app {app_name} does not exist.") self.deploy_map[site_name] = app_name
[docs] def add_resource_spec(self, site_name: str, resource_spec: Dict): if site_name in self.resource_specs.keys(): raise RuntimeError(f"{site_name} resource specs already exist.") if not isinstance(resource_spec, dict): raise RuntimeError(f"resource_spec must be a dict. But got: {resource_spec.__class__}") self.resource_specs[site_name] = resource_spec
def _generate_meta(self, job_dir): """generate the job meta.json Returns: """ meta_file = os.path.join(job_dir, META_JSON) meta_json = { "name": self.job_name, "resource_spec": self.resource_specs, "min_clients": self.min_clients, "deploy_map": self._get_deploy_map(), } if self.mandatory_clients: meta_json["mandatory_clients"] = self.mandatory_clients with open(meta_file, "w") as outfile: json_dump = json.dumps(meta_json, indent=4) outfile.write(json_dump)
[docs] def generate_job_config(self, job_root): """generate the job config Returns: """ job_dir = os.path.join(job_root, self.job_name) if os.path.exists(job_dir): if self._is_valid_job_folder(job_dir): shutil.rmtree(job_dir, ignore_errors=True) else: raise RuntimeError(f"Job folder {job_dir} already exists and is not a valid job folder.") for app_name, fed_app in self.fed_apps.items(): self.custom_modules = [] config_dir = os.path.join(job_dir, app_name, CONFIG) custom_dir = os.path.join(job_dir, app_name, CUSTOM) os.makedirs(config_dir, exist_ok=True) if fed_app.server_app: self._get_server_app(config_dir, custom_dir, fed_app) if fed_app.client_app: self._get_client_app(config_dir, custom_dir, fed_app) self._generate_meta(job_dir)
[docs] def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, gpu=None, log_config=None): with TemporaryDirectory() as job_root: self.generate_job_config(job_root) try: command = ( f"{sys.executable} -m nvflare.private.fed.app.simulator.simulator " + os.path.join(job_root, self.job_name) + " -w " + workspace ) if clients: clients = self._trim_whitespace(clients) command += " -c " + str(clients) if n_clients: command += " -n " + str(n_clients) if threads: command += " -t " + str(threads) if gpu: gpu = self._trim_whitespace(gpu) command += " -gpu " + str(gpu) if log_config: command += " -l" + str(log_config) new_env = os.environ.copy() process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) process.wait() except KeyboardInterrupt: self.logger.info("KeyboardInterrupt, terminate all the child processes.") kill_child_processes(os.getpid()) return -9
def _get_server_app(self, config_dir, custom_dir, fed_app): server_app = {"format_version": 2, "workflows": []} for workflow in fed_app.server_app.workflows: server_app["workflows"].append( { "id": workflow.id, "path": self._get_class_path(workflow.controller, custom_dir), "args": self._get_args(workflow.controller, custom_dir), } ) self._get_base_app(custom_dir, fed_app.server_app, server_app) server_config = os.path.join(config_dir, FED_SERVER_JSON) with open(server_config, "w") as outfile: json_dump = json.dumps(server_app, indent=4) outfile.write(json_dump) self._copy_ext_scripts(custom_dir, fed_app.server_app.ext_scripts) self._copy_ext_dirs(custom_dir, fed_app.server_app) self._copy_file_sources(custom_dir, fed_app.server_app.file_sources) def _copy_file_sources(self, custom_dir, file_sources): for s in file_sources: # s is a tuple of (src_path, dest_dir) src_path, dest_dir = s if dest_dir: dest_path = os.path.join(custom_dir, dest_dir) if not os.path.exists(dest_path): os.makedirs(dest_path, exist_ok=True) else: dest_path = custom_dir if os.path.isfile(src_path): base_name = os.path.basename(src_path) dest_file = os.path.join(dest_path, base_name) shutil.copy(src_path, dest_file) else: # this is a dir shutil.copytree(src_path, dest_path, dirs_exist_ok=True) def _copy_ext_scripts(self, custom_dir, ext_scripts): for script in ext_scripts: if os.path.exists(script): if os.path.isabs(script): relative_script = self._get_relative_script(script) else: relative_script = script dest_file = os.path.join(custom_dir, relative_script) module = "".join(relative_script.rsplit(".py", 1)).replace(os.sep, ".") self._copy_source_file(custom_dir, module, script, dest_file) def _copy_ext_dirs(self, custom_dir, app_config: BaseAppConfig): for dir in app_config.ext_dirs: shutil.copytree(dir, custom_dir, dirs_exist_ok=True) def _get_relative_script(self, script): package_path = "" for path in sys.path: if script.startswith(path): if len(path) > len(package_path): package_path = path return script[len(package_path) + 1 :] def _get_class_path(self, obj, custom_dir): module = obj.__module__ source_file = inspect.getsourcefile(obj.__class__) if module == "__main__": module = os.path.basename(source_file).strip(".py") self._get_custom_file(custom_dir, module, source_file) return module + "." + obj.__class__.__name__ def _get_custom_file(self, custom_dir, module, source_file): package = module.split(".")[0] if os.path.exists(source_file): if package not in FL_PACKAGES and module not in self.custom_modules: module_path = module.replace(".", os.sep) if module_path in source_file: index = source_file.rindex(module_path) dest = source_file[index:] self.custom_modules.append(module) os.makedirs(custom_dir, exist_ok=True) # dest_file = os.path.join(custom_dir, module.replace(".", os.sep) + ".py") dest_file = os.path.join(custom_dir, dest) self._copy_source_file(custom_dir, module, source_file, dest_file) def _copy_source_file(self, custom_dir, module, source_file, dest_file): os.makedirs(custom_dir, exist_ok=True) source_dir = os.path.dirname(source_file) with open(source_file, "r") as sf: import_lines = list(self.locate_imports(sf, dest_file)) for line in import_lines: import_module = line.split(" ")[1] import_source = import_module if import_module.startswith("."): import_source = import_source[1:] new_module = module.split(".")[0:-1] new_module.append(import_source) import_module = ".".join(new_module) import_source_file = os.path.join(source_dir, import_source.replace(".", os.sep) + ".py") if os.path.exists(import_source_file): # Handle the import from within the same module self._get_custom_file(custom_dir, import_module, import_source_file) else: # Handle the import from outside the module size = len(module.split(".")) - 1 source_root = os.sep.join(source_dir.split(os.sep)[0:-size]) import_source_file = os.path.join(source_root, import_source.replace(".", os.sep) + ".py") if os.path.exists(import_source_file): self._get_custom_file(custom_dir, import_module, import_source_file) def _get_client_app(self, config_dir, custom_dir, fed_app): client_app = {"format_version": 2, "executors": []} for e in fed_app.client_app.executors: client_app["executors"].append( { "tasks": e.tasks, "executor": { "path": self._get_class_path(e.executor, custom_dir), "args": self._get_args(e.executor, custom_dir), }, } ) self._get_base_app(custom_dir, fed_app.client_app, client_app) client_config = os.path.join(config_dir, FED_CLIENT_JSON) with open(client_config, "w") as outfile: json_dump = json.dumps(client_app, indent=4) outfile.write(json_dump) self._copy_ext_scripts(custom_dir, fed_app.client_app.ext_scripts) self._copy_ext_dirs(custom_dir, fed_app.client_app) self._copy_file_sources(custom_dir, fed_app.client_app.file_sources) def _get_base_app(self, custom_dir, app, app_config): app_config["components"] = [] for cid, component in app.components.items(): app_config["components"].append( { "id": cid, "path": self._get_class_path(component, custom_dir), "args": self._get_args(component, custom_dir), } ) app_config["task_data_filters"] = self._process_filters(app.task_data_filters, custom_dir) app_config["task_result_filters"] = self._process_filters(app.task_result_filters, custom_dir) def _process_filters(self, taskset_filters: list, custom_dir): """Process taskset_filters into app filter configuration Args: taskset_filters: the list of tuples that contain taskset/filters association. custom_dir: custom dir of the app. Returns: app filter configuration that is a list of dicts, each dict represents a taskset/filters association. """ app_config_filters = [] for task_set, filter_list in taskset_filters: filters = [] for f in filter_list: filters.append( { "path": self._get_class_path(f, custom_dir), "args": self._get_args(f, custom_dir), } ) app_config_filters.append({"tasks": list(task_set), "filters": filters}) return app_config_filters def _get_args(self, component, custom_dir): args = {} if hasattr(component, "__dict__"): parameters = get_component_init_parameters(component) attrs = component.__dict__ for param in parameters: attr_key = param if param in attrs.keys() else "_" + param if attr_key in ["args", "kwargs"]: continue if attr_key in attrs.keys() and parameters[param].default != attrs[attr_key]: if type(attrs[attr_key]).__name__ in dir(builtins): args[param] = attrs[attr_key] elif issubclass(attrs[attr_key].__class__, Enum): args[param] = attrs[attr_key].value else: args[param] = { "path": self._get_class_path(attrs[attr_key], custom_dir), "args": self._get_args(attrs[attr_key], custom_dir), } return args def _get_filters(self, filters, custom_dir): r = [] for f in filters: r.append({"path": self._get_class_path(f, custom_dir), "args": self._get_args(f, custom_dir)}) return r
[docs] def locate_imports(self, sf, dest_file): """Locate all the import statements from the python script, including the imports across multiple lines, using the line break continuing. Args: sf: source file dest_file: copy to destination file Returns: yield all the imports within the source file """ os.makedirs(os.path.dirname(dest_file), exist_ok=True) with open(dest_file, "w") as df: trimmed = "" for line in sf: df.write(line) trimmed += line.strip() if trimmed.endswith("\\"): trimmed = trimmed[0:-1] trimmed = trimmed.strip() + " " else: if trimmed.startswith("from ") and ("import " in trimmed): yield trimmed elif trimmed.startswith("import "): yield trimmed trimmed = ""
def _get_deploy_map(self): deploy_map = {} for site, app_name in self.deploy_map.items(): deploy_map[app_name] = deploy_map.get(app_name, []) deploy_map[app_name].append(site) return deploy_map def _trim_whitespace(self, string: str): strings = string.split(",") for i in range(len(strings)): strings[i] = strings[i].strip() return ",".join(strings) @staticmethod def _is_valid_job_folder(job_folder: str) -> bool: meta_file = os.path.join(job_folder, META_JSON) return os.path.exists(meta_file)