Source code for nvflare.app_common.ccwf.recipes.swarm

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.dxo import DataKind
from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator
from nvflare.app_common.ccwf.ccwf_job import CCWFJob, CrossSiteEvalConfig, SwarmClientConfig, SwarmServerConfig
from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.recipe.spec import Recipe


class BaseSwarmLearningRecipe(Recipe):
    """Recipe that wraps a ``CCWFJob`` configured for swarm learning.

    The constructor creates a ``CCWFJob`` with the given name, registers the
    swarm workflow on it through ``CCWFJob.add_swarm`` and hands the resulting
    job to the ``Recipe`` base class.
    """

    def __init__(
        self,
        name: str,
        server_config: SwarmServerConfig,
        client_config: SwarmClientConfig,
        cse_config: CrossSiteEvalConfig = None,
    ):
        """Build the swarm job and initialize the recipe with it.

        Args:
            name: Name given to the underlying ``CCWFJob``.
            server_config: Server-side swarm configuration, forwarded to
                ``add_swarm``.
            client_config: Client-side swarm configuration, forwarded to
                ``add_swarm``.
            cse_config: Cross-site evaluation configuration, forwarded to
                ``add_swarm`` unchanged; defaults to ``None``.
        """
        swarm_job = CCWFJob(name=name)
        swarm_job.add_swarm(
            server_config=server_config,
            client_config=client_config,
            cse_config=cse_config,
        )

        # Explicit base-class call (rather than super()) keeps the exact
        # initializer that is invoked independent of any subclass MRO.
        Recipe.__init__(self, swarm_job)
class SimpleSwarmLearningRecipe(BaseSwarmLearningRecipe):
    """Swarm learning recipe assembled from a training script and a model.

    Builds the server and client swarm configurations from a handful of
    simple arguments and delegates job creation to
    ``BaseSwarmLearningRecipe``. The client side is wired with a
    ``ScriptRunner`` executor, an ``InTimeAccumulateWeightedAggregator``
    expecting ``DataKind.WEIGHTS``, a ``PTFileModelPersistor`` and a
    ``SimpleModelShareableGenerator``.
    """

    def __init__(
        self,
        name: str,
        initial_model,
        num_rounds: int,
        train_script: str,
        train_args: dict = None,
        do_cross_site_eval: bool = False,
        cross_site_eval_timeout: float = 300,
    ):
        """Create the swarm configurations and initialize the base recipe.

        Args:
            name: Job name, passed through to ``BaseSwarmLearningRecipe``.
            initial_model: Model object given to ``PTFileModelPersistor`` as
                its ``model`` argument.
            num_rounds: Number of rounds set on ``SwarmServerConfig``.
            train_script: Script passed to ``ScriptRunner`` as ``script``.
            train_args: Optional dict unpacked as additional keyword
                arguments to ``ScriptRunner``. ``None`` or an empty dict
                means no extra arguments.
            do_cross_site_eval: When true, a ``CrossSiteEvalConfig`` is
                created and passed on; otherwise no cross-site evaluation
                config is supplied.
            cross_site_eval_timeout: Value used for ``eval_task_timeout`` of
                the ``CrossSiteEvalConfig``; only read when
                ``do_cross_site_eval`` is true.
        """
        weighted_aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)

        # Cross-site evaluation is opt-in; without it the base recipe gets None.
        cse_config = (
            CrossSiteEvalConfig(eval_task_timeout=cross_site_eval_timeout) if do_cross_site_eval else None
        )

        # Any falsy train_args (None, empty dict) means "no extra keyword arguments".
        runner_kwargs = train_args or {}

        server_config = SwarmServerConfig(num_rounds=num_rounds)

        script_executor = ScriptRunner(script=train_script, **runner_kwargs)
        model_persistor = PTFileModelPersistor(model=initial_model)
        model_shareable_generator = SimpleModelShareableGenerator()
        client_config = SwarmClientConfig(
            executor=script_executor,
            aggregator=weighted_aggregator,
            persistor=model_persistor,
            shareable_generator=model_shareable_generator,
        )

        BaseSwarmLearningRecipe.__init__(self, name, server_config, client_config, cse_config)