# Source code for nvflare.app_opt.pt.recipes.fedavg

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union

from nvflare.apis.dxo import DataKind
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe as UnifiedFedAvgRecipe


class FedAvgRecipe(UnifiedFedAvgRecipe):
    """A recipe for implementing Federated Averaging (FedAvg) for PyTorch.

    FedAvg is a fundamental federated learning algorithm that aggregates model
    updates from multiple clients by computing a weighted average based on the
    amount of local training data. This recipe sets up a complete federated
    learning workflow with memory-efficient InTime aggregation.

    The recipe configures:
        - A federated job with initial model (optional)
        - FedAvg controller with InTime aggregation for memory efficiency
        - Optional early stopping and model selection
        - Script runners for client-side training execution

    Args:
        name: Name of the federated learning job. Defaults to "fedavg".
        model: Initial model to start federated training with. Can be:
            - nn.Module instance
            - Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
            - None: no initial model
        initial_ckpt: Absolute path to a pre-trained checkpoint file. The file
            may not exist locally as it could be on the server. Used to load
            initial weights. Note: PyTorch requires model when using
            initial_ckpt (for architecture).
        min_clients: Minimum number of clients required to start a training round.
        num_rounds: Number of federated training rounds to execute. Defaults to 2.
        train_script: Path to the training script that will be executed on each client.
        train_args: Command line arguments to pass to the training script.
        aggregator: Custom aggregator (ModelAggregator) for combining client
            model updates. Must implement accept_model(), aggregate_model(),
            reset_stats() methods. If None, uses built-in memory-efficient
            weighted averaging.
        aggregator_data_kind: Data kind to use for the aggregator. Defaults to
            DataKind.WEIGHTS. Kept for backward compatibility.
        launch_external_process: Whether to launch the script in external
            process. Defaults to False.
        command: If launch_external_process=True, command to run script
            (prepended to script). Defaults to "python3 -u".
        server_expected_format: What format to exchange the parameters between
            server and client. Defaults to ExchangeFormat.NUMPY.
        params_transfer_type: How to transfer the parameters. FULL means the
            whole model parameters are sent. DIFF means that only the
            difference is sent. Defaults to TransferType.FULL.
        model_persistor: Custom model persistor. If None, PTFileModelPersistor
            will be used.
        model_locator: Custom model locator. If None, PTFileModelLocator will
            be used.
        per_site_config: Per-site configuration for the federated learning job.
        launch_once: Whether external process is launched once or per task.
            Defaults to True.
        shutdown_timeout: Seconds to wait before shutdown. Defaults to 0.0.
        key_metric: Metric used to determine if the model is globally best.
            Defaults to "accuracy".
        stop_cond: Early stopping condition based on metric. String literal in
            the format of '<key> <op> <value>' (e.g. "accuracy >= 80"). If
            None, early stopping is disabled.
        patience: Number of rounds with no improvement after which FL will be
            stopped.
        save_filename: Filename for saving the best model. Defaults to
            "FL_global_model.pt".
        exclude_vars: Regex pattern for variables to exclude from aggregation.
        aggregation_weights: Per-client aggregation weights dict. Defaults to
            equal weights.
        server_memory_gc_rounds: Forwarded unchanged to the unified
            FedAvgRecipe. Defaults to 0. NOTE(review): presumably the interval,
            in rounds, between server-side memory clean-ups -- confirm against
            the unified recipe.
        client_memory_gc_rounds: Forwarded unchanged to the unified
            FedAvgRecipe. Defaults to 0. NOTE(review): presumably the
            client-side counterpart of server_memory_gc_rounds -- confirm.
        cuda_empty_cache: Forwarded unchanged to the unified FedAvgRecipe.
            Defaults to False. NOTE(review): presumably toggles emptying the
            CUDA cache during memory clean-up -- confirm.

    Example:
        Basic usage with early stopping:

        ```python
        recipe = FedAvgRecipe(
            name="my_fedavg_job",
            model=pretrained_model,
            min_clients=2,
            num_rounds=10,
            train_script="client.py",
            train_args="--epochs 5 --batch_size 32",
            stop_cond="accuracy >= 95",
            patience=3
        )
        ```

    Note:
        This recipe uses InTime (streaming) aggregation for memory efficiency -
        each client result is aggregated immediately upon receipt rather than
        collecting all results first. Memory usage is constant regardless of
        the number of clients.
    """

    def __init__(
        self,
        *,
        name: str = "fedavg",
        model: Union[Any, dict[str, Any], None] = None,
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        aggregator: Optional[Aggregator] = None,
        aggregator_data_kind: Optional[DataKind] = DataKind.WEIGHTS,
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        params_transfer_type: TransferType = TransferType.FULL,
        model_persistor: Optional[ModelPersistor] = None,
        model_locator: Optional[ModelLocator] = None,
        per_site_config: Optional[dict[str, dict]] = None,
        launch_once: bool = True,
        shutdown_timeout: float = 0.0,
        key_metric: str = "accuracy",
        # New FedAvg features
        stop_cond: Optional[str] = None,
        patience: Optional[int] = None,
        save_filename: str = "FL_global_model.pt",
        exclude_vars: Optional[str] = None,
        aggregation_weights: Optional[dict[str, float]] = None,
        server_memory_gc_rounds: int = 0,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # Store the PyTorch-specific model_locator before calling the parent:
        # it is not forwarded to the unified recipe, and
        # _setup_model_and_persistor() reads it from this attribute.
        self._pt_model_locator = model_locator

        # Call the unified FedAvgRecipe with PyTorch-specific settings
        # (framework is pinned to FrameworkType.PYTORCH).
        super().__init__(
            name=name,
            model=model,
            initial_ckpt=initial_ckpt,
            min_clients=min_clients,
            num_rounds=num_rounds,
            train_script=train_script,
            train_args=train_args,
            aggregator=aggregator,
            aggregator_data_kind=aggregator_data_kind,
            launch_external_process=launch_external_process,
            command=command,
            framework=FrameworkType.PYTORCH,
            server_expected_format=server_expected_format,
            params_transfer_type=params_transfer_type,
            model_persistor=model_persistor,
            per_site_config=per_site_config,
            launch_once=launch_once,
            shutdown_timeout=shutdown_timeout,
            key_metric=key_metric,
            stop_cond=stop_cond,
            patience=patience,
            save_filename=save_filename,
            exclude_vars=exclude_vars,
            aggregation_weights=aggregation_weights,
            server_memory_gc_rounds=server_memory_gc_rounds,
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

    def _setup_model_and_persistor(self, job) -> str:
        """Override to handle PyTorch-specific model setup.

        Two paths are taken depending on what ``setup_custom_persistor``
        returns for ``self.model_persistor``:

        * A persistor id is returned: the id is recorded in ``job.comp_ids``
          (when the job has that attribute), the custom model locator, if one
          was given, is sent to the server under the id ``"locator"``, and the
          persistor id is returned.
        * No persistor id is returned: the initial checkpoint is resolved, the
          model is wrapped in a ``PTModel`` and sent to the server under the id
          ``"persistor"``, and the persistor id extracted from that result is
          returned.

        Args:
            job: The job to add the server-side components to. It must provide
                ``to_server()``; a ``comp_ids`` mapping is optional and is only
                updated when present.

        Returns:
            The persistor id, or an empty string when no model was configured.

        Raises:
            ValueError: If a checkpoint path is resolved but ``self.model`` is
                None.
        """
        # Imported lazily, as in the original implementation.
        from nvflare.app_opt.pt.job_config.model import PTModel
        from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt, setup_custom_persistor

        persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
        if persistor_id:
            if hasattr(job, "comp_ids"):
                job.comp_ids["persistor_id"] = persistor_id
            if self._pt_model_locator is not None:
                locator_id = job.to_server(self._pt_model_locator, id="locator")
                # Only record the locator id when the job tracks component ids,
                # mirroring the guard used for the persistor id above.
                if hasattr(job, "comp_ids") and isinstance(locator_id, str) and locator_id:
                    job.comp_ids["locator_id"] = locator_id
            return persistor_id

        ckpt_path = resolve_initial_ckpt(self.initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
        if self.model is None and ckpt_path:
            raise ValueError("FrameworkType.PYTORCH requires 'model' when using initial_ckpt.")
        if self.model is None:
            # Nothing to persist: no initial model and no checkpoint.
            return ""

        # Disable numpy conversion when using tensor format to keep PyTorch tensors.
        allow_numpy_conversion = self.server_expected_format != ExchangeFormat.PYTORCH
        pt_model = PTModel(
            model=self.model,
            initial_ckpt=ckpt_path,
            locator=self._pt_model_locator,
            allow_numpy_conversion=allow_numpy_conversion,
        )

        result = job.to_server(pt_model, id="persistor")
        # When to_server() returns a dict, merge it into the job's component ids.
        if isinstance(result, dict) and hasattr(job, "comp_ids"):
            job.comp_ids.update(result)

        persistor_id = extract_persistor_id(result)
        # setdefault keeps any "persistor_id" already merged in from `result`.
        if persistor_id and hasattr(job, "comp_ids"):
            job.comp_ids.setdefault("persistor_id", persistor_id)
        return persistor_id