# Source code for nvflare.app_common.ccwf.ccwf_job

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

from nvflare.apis.executor import Executor
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.metric_comparator import MetricComparator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.ccwf.common import Constant, CyclicOrder
from nvflare.job_config.api import FedJob, validate_object_for_job
from nvflare.widgets.widget import Widget

from .cse_client_ctl import CrossSiteEvalClientController
from .cse_server_ctl import CrossSiteEvalServerController
from .cyclic_client_ctl import CyclicClientController
from .cyclic_server_ctl import CyclicServerController
from .swarm_client_ctl import SwarmClientController
from .swarm_server_ctl import SwarmServerController

# Default task names assigned to the training executor; CCWFJob falls back to
# this list when it is constructed without an explicit `executor_tasks`.
_EXECUTOR_TASKS = ["train", "validate", "submit_model"]


class SwarmServerConfig:
    """Server-side settings for a swarm learning workflow.

    This is a plain value holder: every constructor argument is stored
    unchanged on an attribute of the same name. ``CCWFJob.add_swarm`` reads
    those attributes and hands them to ``SwarmServerController`` as identically
    named keyword arguments, so the meaning of each setting is defined by that
    controller rather than by this class.
    """

    def __init__(
        self,
        num_rounds: int,
        start_round: int = 0,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        participating_clients=None,
        result_clients=None,
        starting_client: str = "",
        max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        private_p2p: bool = True,
        aggr_clients=None,
        train_clients=None,
    ):
        """Record the swarm server settings.

        Args:
            num_rounds: required; no default.
            start_round: defaults to 0.
            start_task_timeout: defaults to ``Constant.START_TASK_TIMEOUT``.
            configure_task_timeout: defaults to ``Constant.CONFIG_TASK_TIMEOUT``.
            participating_clients: defaults to None.
            result_clients: defaults to None.
            starting_client: defaults to the empty string.
            max_status_report_interval: defaults to
                ``Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT``.
            progress_timeout: defaults to ``Constant.WORKFLOW_PROGRESS_TIMEOUT``.
            private_p2p: defaults to True.
            aggr_clients: defaults to None.
            train_clients: defaults to None.
        """
        # Round bookkeeping.
        self.num_rounds = num_rounds
        self.start_round = start_round

        # Timeouts for the start/configure tasks.
        self.start_task_timeout = start_task_timeout
        self.configure_task_timeout = configure_task_timeout

        # Client selection.
        self.participating_clients = participating_clients
        self.result_clients = result_clients
        self.starting_client = starting_client

        # Progress monitoring and communication mode.
        self.max_status_report_interval = max_status_report_interval
        self.progress_timeout = progress_timeout
        self.private_p2p = private_p2p

        # Role-specific client lists.
        self.aggr_clients = aggr_clients
        self.train_clients = train_clients
class SwarmClientConfig:
    """Client-side settings and components for a swarm learning workflow.

    The constructor checks each supplied component with
    ``validate_object_for_job`` and then stores every argument unchanged on an
    attribute of the same name. ``CCWFJob.add_swarm`` consumes the resulting
    object.
    """

    def __init__(
        self,
        executor: Any,
        persistor: Any,
        shareable_generator: Any,
        aggregator: Any,
        metric_comparator: Any = None,
        model_selector: Any = None,
        learn_task_check_interval=Constant.LEARN_TASK_CHECK_INTERVAL,
        learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
        learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
        learn_task_timeout=None,
        final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
        min_responses_required: int = 1,
        wait_time_after_min_resps_received: float = 10.0,
    ):
        """Validate the components and record the swarm client settings.

        Args:
            executor: checked against ``Executor``.
            persistor: checked against ``ModelPersistor``.
            shareable_generator: checked against ``ShareableGenerator``.
            aggregator: checked against ``Aggregator``.
            metric_comparator: optional; checked against ``MetricComparator``
                only when truthy.
            model_selector: optional; checked against ``Widget`` only when truthy.
            learn_task_check_interval: defaults to
                ``Constant.LEARN_TASK_CHECK_INTERVAL``.
            learn_task_abort_timeout: defaults to
                ``Constant.LEARN_TASK_ABORT_TIMEOUT``.
            learn_task_ack_timeout: defaults to ``Constant.LEARN_TASK_ACK_TIMEOUT``.
            learn_task_timeout: defaults to None.
            final_result_ack_timeout: defaults to
                ``Constant.FINAL_RESULT_ACK_TIMEOUT``.
            min_responses_required: defaults to 1.
            wait_time_after_min_resps_received: defaults to 10.0.

        Note:
            How a failed check is reported is decided by
            ``validate_object_for_job`` (presumably by raising - confirm against
            its definition).
        """
        # The executor could be a wrapper object that adds the real Executor
        # only when it is added to the job, hence the job-aware validator.
        mandatory_components = (
            ("executor", executor, Executor),
            ("persistor", persistor, ModelPersistor),
            ("shareable_generator", shareable_generator, ShareableGenerator),
            ("aggregator", aggregator, Aggregator),
        )
        for component_name, component, expected_type in mandatory_components:
            validate_object_for_job(component_name, component, expected_type)

        # Optional components are only validated when supplied.
        if model_selector:
            validate_object_for_job("model_selector", model_selector, Widget)

        if metric_comparator:
            validate_object_for_job("metric_comparator", metric_comparator, MetricComparator)

        # Components.
        self.executor = executor
        self.persistor = persistor
        self.shareable_generator = shareable_generator
        self.aggregator = aggregator
        self.metric_comparator = metric_comparator
        self.model_selector = model_selector

        # Learn-task timing.
        self.learn_task_check_interval = learn_task_check_interval
        self.learn_task_abort_timeout = learn_task_abort_timeout
        self.learn_task_ack_timeout = learn_task_ack_timeout
        self.learn_task_timeout = learn_task_timeout
        self.final_result_ack_timeout = final_result_ack_timeout

        # Response gathering.
        self.min_responses_required = min_responses_required
        self.wait_time_after_min_resps_received = wait_time_after_min_resps_received
class CyclicServerConfig:
    """Server-side settings for a cyclic workflow.

    This is a plain value holder: every constructor argument is stored
    unchanged on an attribute of the same name. ``CCWFJob.add_cyclic`` reads
    those attributes and hands them to ``CyclicServerController`` as identically
    named keyword arguments, so the meaning of each setting is defined by that
    controller rather than by this class.
    """

    def __init__(
        self,
        num_rounds: int,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        participating_clients=None,
        result_clients=None,
        starting_client: str = "",
        max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        private_p2p: bool = True,
        cyclic_order: str = CyclicOrder.FIXED,
    ):
        """Record the cyclic server settings.

        Args:
            num_rounds: required; no default.
            start_task_timeout: defaults to ``Constant.START_TASK_TIMEOUT``.
            configure_task_timeout: defaults to ``Constant.CONFIG_TASK_TIMEOUT``.
            participating_clients: defaults to None.
            result_clients: defaults to None.
            starting_client: defaults to the empty string.
            max_status_report_interval: defaults to
                ``Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT``.
            progress_timeout: defaults to ``Constant.WORKFLOW_PROGRESS_TIMEOUT``.
            private_p2p: defaults to True.
            cyclic_order: defaults to ``CyclicOrder.FIXED``.
        """
        self.num_rounds = num_rounds

        # Timeouts for the start/configure tasks.
        self.start_task_timeout = start_task_timeout
        self.configure_task_timeout = configure_task_timeout

        # Client selection.
        self.participating_clients = participating_clients
        self.result_clients = result_clients
        self.starting_client = starting_client

        # Progress monitoring and communication mode.
        self.max_status_report_interval = max_status_report_interval
        self.progress_timeout = progress_timeout
        self.private_p2p = private_p2p

        self.cyclic_order = cyclic_order
class CyclicClientConfig:
    """Client-side settings and components for a cyclic workflow.

    The constructor checks each supplied component with
    ``validate_object_for_job`` and then stores every argument unchanged on an
    attribute of the same name. ``CCWFJob.add_cyclic`` consumes the resulting
    object.
    """

    def __init__(
        self,
        executor: Any,
        persistor: Any,
        shareable_generator: Any,
        learn_task_abort_timeout=Constant.LEARN_TASK_ABORT_TIMEOUT,
        learn_task_ack_timeout=Constant.LEARN_TASK_ACK_TIMEOUT,
        final_result_ack_timeout=Constant.FINAL_RESULT_ACK_TIMEOUT,
    ):
        """Validate the components and record the cyclic client settings.

        Args:
            executor: checked against ``Executor``.
            persistor: checked against ``ModelPersistor``.
            shareable_generator: checked against ``ShareableGenerator``.
            learn_task_abort_timeout: defaults to
                ``Constant.LEARN_TASK_ABORT_TIMEOUT``.
            learn_task_ack_timeout: defaults to ``Constant.LEARN_TASK_ACK_TIMEOUT``.
            final_result_ack_timeout: defaults to
                ``Constant.FINAL_RESULT_ACK_TIMEOUT``.

        Note:
            How a failed check is reported is decided by
            ``validate_object_for_job`` (presumably by raising - confirm against
            its definition).
        """
        # Validate in a fixed order so the first offending component is the
        # one reported.
        components_to_check = (
            ("executor", executor, Executor),
            ("persistor", persistor, ModelPersistor),
            ("shareable_generator", shareable_generator, ShareableGenerator),
        )
        for component_name, component, expected_type in components_to_check:
            validate_object_for_job(component_name, component, expected_type)

        # Components.
        self.executor = executor
        self.persistor = persistor
        self.shareable_generator = shareable_generator

        # Timeouts.
        self.learn_task_abort_timeout = learn_task_abort_timeout
        self.learn_task_ack_timeout = learn_task_ack_timeout
        self.final_result_ack_timeout = final_result_ack_timeout
class CrossSiteEvalConfig:
    """Settings for a cross-site evaluation workflow.

    This is a plain value holder: every constructor argument is stored
    unchanged on an attribute of the same name. ``CCWFJob.add_cross_site_eval``
    passes ``get_model_timeout`` to ``CrossSiteEvalClientController`` and all
    other attributes to ``CrossSiteEvalServerController``; the meaning of each
    setting is defined by those controllers rather than by this class.
    """

    def __init__(
        self,
        start_task_timeout=Constant.START_TASK_TIMEOUT,
        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
        eval_task_timeout=30,
        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
        private_p2p: bool = True,
        participating_clients=None,
        evaluators=None,
        evaluatees=None,
        global_model_client=None,
        max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
        eval_result_dir=AppConstants.CROSS_VAL_DIR,
        get_model_timeout=Constant.GET_MODEL_TIMEOUT,
    ):
        """Record the cross-site evaluation settings.

        Args:
            start_task_timeout: defaults to ``Constant.START_TASK_TIMEOUT``.
            configure_task_timeout: defaults to ``Constant.CONFIG_TASK_TIMEOUT``.
            eval_task_timeout: defaults to 30 (presumably seconds - confirm
                against the server controller).
            progress_timeout: defaults to ``Constant.WORKFLOW_PROGRESS_TIMEOUT``.
            private_p2p: defaults to True.
            participating_clients: defaults to None.
            evaluators: defaults to None.
            evaluatees: defaults to None.
            global_model_client: defaults to None.
            max_status_report_interval: defaults to
                ``Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT``.
            eval_result_dir: defaults to ``AppConstants.CROSS_VAL_DIR``.
            get_model_timeout: defaults to ``Constant.GET_MODEL_TIMEOUT``.
        """
        # Task and workflow timeouts.
        self.start_task_timeout = start_task_timeout
        self.configure_task_timeout = configure_task_timeout
        self.eval_task_timeout = eval_task_timeout
        self.progress_timeout = progress_timeout

        self.private_p2p = private_p2p

        # Client roles.
        self.participating_clients = participating_clients
        self.evaluators = evaluators
        self.evaluatees = evaluatees
        self.global_model_client = global_model_client

        self.max_status_report_interval = max_status_report_interval
        self.eval_result_dir = eval_result_dir

        # Used on the client side only.
        self.get_model_timeout = get_model_timeout
class CCWFJob(FedJob):
    def __init__(
        self,
        name: str = "fed_job",
        min_clients: int = 1,
        mandatory_clients: Optional[List[str]] = None,
        executor_tasks: Optional[List[str]] = None,
        external_resources: Optional[str] = None,
    ):
        """Client-Controlled Workflow Job.

        Provides methods for adding client-controlled swarm learning, cyclic, and
        cross-site evaluation workflows.

        Args:
            name (str, optional): name of the job. Defaults to "fed_job"
            min_clients (int, optional): the minimum number of clients for the job. Defaults to 1.
            mandatory_clients (List[str], optional): mandatory clients to run the job. Default None.
            executor_tasks (List[str], optional): tasks for the executor. When omitted or
                empty, "train", "validate" and "submit_model" are used.
            external_resources (str, optional): External resources directory or filename.
                When given, it is added to both the server and the clients. Defaults to None.
        """
        super().__init__(name, min_clients, mandatory_clients)

        # A CCWF job can have multiple workflows (swarm, cyclic, etc.), but can only
        # have one executor for training! This executor can be added by any workflow.
        self.executor = None

        # Copy the module-level default so that mutating one job's task list
        # cannot leak into other jobs (or into the default itself).
        self.executor_tasks = executor_tasks if executor_tasks else list(_EXECUTOR_TASKS)

        if external_resources:
            self.to_server(external_resources)
            self.to_clients(external_resources)

    def _add_executor_once(self, executor: Any):
        """Add the training executor to the clients unless one was already added.

        Args:
            executor: executor (or executor wrapper) taken from a client config.
        """
        if not self.executor:
            # We add the executor only if it's not added yet.
            self.to_clients(executor, tasks=self.executor_tasks)
            self.executor = executor

    def add_swarm(
        self,
        server_config: SwarmServerConfig,
        client_config: SwarmClientConfig,
        cse_config: Optional[CrossSiteEvalConfig] = None,
    ):
        """Add a swarm learning workflow to the job.

        Adds a ``SwarmServerController`` to the server, and the client components
        plus a ``SwarmClientController`` (for tasks matching "swarm_*") to the
        clients. The training executor is added only if no workflow has added
        one yet.

        Args:
            server_config: settings passed to ``SwarmServerController``.
            client_config: components and settings for the client side.
            cse_config: when given, a cross-site evaluation workflow that shares
                this workflow's persistor is added as well.
        """
        controller = SwarmServerController(
            num_rounds=server_config.num_rounds,
            start_round=server_config.start_round,
            start_task_timeout=server_config.start_task_timeout,
            configure_task_timeout=server_config.configure_task_timeout,
            participating_clients=server_config.participating_clients,
            result_clients=server_config.result_clients,
            starting_client=server_config.starting_client,
            max_status_report_interval=server_config.max_status_report_interval,
            progress_timeout=server_config.progress_timeout,
            private_p2p=server_config.private_p2p,
            aggr_clients=server_config.aggr_clients,
            train_clients=server_config.train_clients,
        )
        self.to_server(controller)

        # The metric comparator is optional; the controller gets None when absent.
        metric_comparator_id = None
        if client_config.metric_comparator:
            metric_comparator_id = self.to_clients(client_config.metric_comparator, id="metric_comparator")

        persistor_id = self.to_clients(client_config.persistor, id="persistor")
        shareable_generator_id = self.to_clients(client_config.shareable_generator, id="shareable_generator")
        aggregator_id = self.to_clients(client_config.aggregator, id="aggregator")

        client_controller = SwarmClientController(
            aggregator_id=aggregator_id,
            persistor_id=persistor_id,
            shareable_generator_id=shareable_generator_id,
            metric_comparator_id=metric_comparator_id,
            # Forward the configured interval; without this the value stored on
            # SwarmClientConfig was silently ignored.
            learn_task_check_interval=client_config.learn_task_check_interval,
            learn_task_abort_timeout=client_config.learn_task_abort_timeout,
            learn_task_ack_timeout=client_config.learn_task_ack_timeout,
            learn_task_timeout=client_config.learn_task_timeout,
            final_result_ack_timeout=client_config.final_result_ack_timeout,
            min_responses_required=client_config.min_responses_required,
            wait_time_after_min_resps_received=client_config.wait_time_after_min_resps_received,
        )
        self.to_clients(client_controller, tasks=["swarm_*"])

        self._add_executor_once(client_config.executor)

        if client_config.model_selector:
            self.to_clients(client_config.model_selector, id="model_selector")

        if cse_config:
            self.add_cross_site_eval(cse_config, persistor_id)

    def add_cyclic(
        self,
        server_config: CyclicServerConfig,
        client_config: CyclicClientConfig,
        cse_config: Optional[CrossSiteEvalConfig] = None,
    ):
        """Add a cyclic workflow to the job.

        Adds a ``CyclicServerController`` to the server, and the client components
        plus a ``CyclicClientController`` (for tasks matching "cyclic_*") to the
        clients. The training executor is added only if no workflow has added
        one yet.

        Args:
            server_config: settings passed to ``CyclicServerController``.
            client_config: components and settings for the client side.
            cse_config: when given, a cross-site evaluation workflow that shares
                this workflow's persistor is added as well.
        """
        controller = CyclicServerController(
            num_rounds=server_config.num_rounds,
            start_task_timeout=server_config.start_task_timeout,
            configure_task_timeout=server_config.configure_task_timeout,
            participating_clients=server_config.participating_clients,
            result_clients=server_config.result_clients,
            starting_client=server_config.starting_client,
            max_status_report_interval=server_config.max_status_report_interval,
            progress_timeout=server_config.progress_timeout,
            private_p2p=server_config.private_p2p,
            cyclic_order=server_config.cyclic_order,
        )
        self.to_server(controller)

        persistor_id = self.to_clients(client_config.persistor, id="persistor")
        shareable_generator_id = self.to_clients(client_config.shareable_generator, id="shareable_generator")

        client_controller = CyclicClientController(
            persistor_id=persistor_id,
            shareable_generator_id=shareable_generator_id,
            learn_task_abort_timeout=client_config.learn_task_abort_timeout,
            learn_task_ack_timeout=client_config.learn_task_ack_timeout,
            final_result_ack_timeout=client_config.final_result_ack_timeout,
        )
        self.to_clients(client_controller, tasks=["cyclic_*"])

        self._add_executor_once(client_config.executor)

        if cse_config:
            self.add_cross_site_eval(cse_config, persistor_id)

    def add_cross_site_eval(
        self,
        cse_config: CrossSiteEvalConfig,
        persistor_id: str,
    ):
        """Add a cross-site evaluation workflow to the job.

        Adds a ``CrossSiteEvalServerController`` to the server and a
        ``CrossSiteEvalClientController`` (for tasks matching "cse_*") to the
        clients.

        Args:
            cse_config: cross-site evaluation settings.
            persistor_id: component id of the persistor already added to the
                clients; passed to the client controller.
        """
        controller = CrossSiteEvalServerController(
            start_task_timeout=cse_config.start_task_timeout,
            configure_task_timeout=cse_config.configure_task_timeout,
            eval_task_timeout=cse_config.eval_task_timeout,
            progress_timeout=cse_config.progress_timeout,
            private_p2p=cse_config.private_p2p,
            participating_clients=cse_config.participating_clients,
            evaluators=cse_config.evaluators,
            evaluatees=cse_config.evaluatees,
            global_model_client=cse_config.global_model_client,
            max_status_report_interval=cse_config.max_status_report_interval,
            eval_result_dir=cse_config.eval_result_dir,
        )
        self.to_server(controller)

        client_controller = CrossSiteEvalClientController(
            persistor_id=persistor_id,
            get_model_timeout=cse_config.get_model_timeout,
        )
        self.to_clients(client_controller, tasks=["cse_*"])