# Source code for nvflare.recipe.fedavg

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel

from nvflare.apis.dxo import DataKind
from nvflare.apis.job_def import ALL_SITES, SERVER_SITE_NAME
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.app_constant import DefaultCheckpointFileName
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.recipe.spec import Recipe


# Internal — not part of the public API
class _FedAvgValidator(BaseModel):
    """Pydantic model that type-checks the keyword arguments of ``FedAvgRecipe.__init__``.

    ``FedAvgRecipe.__init__`` builds one instance from its arguments and then copies
    the validated values onto the recipe. Fields without a default here are always
    passed explicitly by the recipe constructor.
    """

    # Lets pydantic accept fields typed with non-pydantic classes
    # (e.g. Aggregator, ModelPersistor) instead of rejecting them.
    model_config = {"arbitrary_types_allowed": True}

    name: str
    model: Any
    initial_ckpt: Optional[str] = None
    min_clients: int
    num_rounds: int
    train_script: str
    train_args: str
    # Legacy parameters kept for backward compatibility.
    # NOTE: within this file, `aggregator` is still forwarded to the FedAvg controller when it
    # is a ModelAggregator; `aggregator_data_kind` is only stored on the recipe.
    aggregator: Optional[Aggregator] = None
    aggregator_data_kind: Optional[DataKind] = DataKind.WEIGHTS
    # Core parameters
    launch_external_process: bool
    command: str
    framework: FrameworkType
    server_expected_format: ExchangeFormat
    params_transfer_type: TransferType
    model_persistor: Optional[ModelPersistor] = None
    per_site_config: Optional[Dict[str, Dict]] = None
    launch_once: bool = True
    shutdown_timeout: float = 0.0
    key_metric: str = "accuracy"
    # New FedAvg features
    stop_cond: Optional[str] = None
    patience: Optional[int] = None
    # The recipe always passes an already-resolved (non-None) best_model_filename.
    best_model_filename: str = DefaultCheckpointFileName.BEST_GLOBAL_MODEL
    # Deprecated alias of best_model_filename; validated here, but the recipe stores the
    # controller filename resolved by FedAvgRecipe._resolve_model_filenames instead.
    save_filename: Optional[str] = None
    exclude_vars: Optional[str] = None
    aggregation_weights: Optional[Dict[str, float]] = None
    # Memory management
    server_memory_gc_rounds: int = 0
    enable_tensor_disk_offload: bool = False
    client_memory_gc_rounds: int = 0
    cuda_empty_cache: bool = False


class FedAvgRecipe(Recipe):
    """Unified FedAvg recipe for PyTorch, TensorFlow, and Scikit-learn.

    FedAvg is a fundamental federated learning algorithm that aggregates model updates
    from multiple clients by computing a weighted average based on the amount of local
    training data. This recipe sets up a complete federated learning workflow with
    memory-efficient InTime aggregation.

    The recipe configures:
    - A federated job with initial model (optional)
    - FedAvg controller with InTime aggregation for memory efficiency
    - Optional early stopping and model selection
    - Script runners for client-side training execution

    Args:
        name: Name of the federated learning job. Defaults to "fedavg".
        model: Initial model to start federated training with. Can be:
            - Model instance (nn.Module, tf.keras.Model, etc.)
            - Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
            - None: no initial model
            For framework-specific types (nn.Module, tf.keras.Model), use the corresponding
            framework recipe (e.g., nvflare.app_opt.pt.recipes.FedAvgRecipe).
        initial_ckpt: Absolute path to a pre-trained checkpoint file. The file may not exist
            locally as it could be on the server. Used to load initial weights.
        min_clients: Minimum number of clients required to start a training round.
        num_rounds: Number of federated training rounds to execute. Defaults to 2.
        train_script: Path to the training script that will be executed on each client.
        train_args: Command line arguments to pass to the training script.
        aggregator: Custom aggregator (ModelAggregator) for combining client model updates.
            Must implement accept_model(), aggregate_model(), reset_stats() methods.
            If None, uses built-in memory-efficient weighted averaging. An aggregator that is
            not a ModelAggregator is ignored (a warning is logged) and the built-in averaging
            is used instead. Defaults to None.
        aggregator_data_kind: Data kind for aggregation (DataKind.WEIGHTS or DataKind.WEIGHT_DIFF).
            Kept for backward compatibility. Defaults to DataKind.WEIGHTS.
        launch_external_process: Whether to launch the script in external process. Defaults to False.
        command: If launch_external_process=True, command to run script (prepended to script).
            Defaults to "python3 -u".
        framework: The framework type. One of:
            - FrameworkType.PYTORCH (default)
            - FrameworkType.TENSORFLOW
            - FrameworkType.NUMPY
            - FrameworkType.RAW (for custom frameworks, e.g., sklearn, XGBoost)
        server_expected_format: What format to exchange the parameters between server and client.
            Defaults to ExchangeFormat.NUMPY.
        params_transfer_type: How to transfer the parameters. FULL means the whole model parameters
            are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL.
        model_persistor: Custom model persistor for any framework. If None, uses the framework's
            default persistor when one is available.
        per_site_config: Per-site configuration for the federated learning job. Dictionary mapping
            site names to configuration dicts. Each config dict can contain optional overrides:
            - train_script (str): Training script path
            - train_args (str): Script arguments
            - launch_external_process (bool): Whether to launch external process
            - command (str): Command prefix for external process
            - framework (FrameworkType): Framework type
            - server_expected_format (ExchangeFormat): Exchange format
            - params_transfer_type (TransferType): Parameter transfer type
            - launch_once (bool): Whether to launch external process once or per task
            - shutdown_timeout (float): Shutdown timeout in seconds
            If not provided (None or an empty dict), the same configuration will be used for
            all clients.
        launch_once: Whether the external process will be launched only once at the beginning
            or on each task. Only used if `launch_external_process` is True. Defaults to True.
        shutdown_timeout: If provided, will wait for this number of seconds before shutdown.
            Only used if `launch_external_process` is True. Defaults to 0.0.
        key_metric: Metric used to determine if the model is globally best. If validation metrics
            are a dict, key_metric selects the metric used for global model selection by the
            IntimeModelSelector. Defaults to "accuracy".
        stop_cond: Early stopping condition based on metric. String literal in the format of
            '<key> <op> <value>' (e.g. "accuracy >= 80"). If None, early stopping is disabled.
        patience: Number of rounds with no improvement after which FL will be stopped.
            Only applies if stop_cond is set. Defaults to None.
        best_model_filename: Filename for saving the best model. If unset, framework persistors
            that expose a separate best-model filename use their own default, such as
            DefaultCheckpointFileName.BEST_GLOBAL_MODEL for the default PyTorch persistor.
        save_filename: Deprecated alias for best_model_filename. If both are specified, they must match.
        exclude_vars: Regex pattern for variables to exclude from aggregation.
        aggregation_weights: Per-client aggregation weights dict. Defaults to equal weights.
        server_memory_gc_rounds: Run memory cleanup (gc.collect + malloc_trim) every N rounds on server.
            Set to 0 to disable. Defaults to 0.
        enable_tensor_disk_offload: Enable disk-backed tensor offload for incoming streamed payloads.
            When True, server receives tensor payloads via temp files and materializes lazily.
        client_memory_gc_rounds: Forwarded to every client ScriptRunner as ``memory_gc_rounds``
            (presumably the client-side counterpart of server_memory_gc_rounds). Defaults to 0.
        cuda_empty_cache: Forwarded to every client ScriptRunner as ``cuda_empty_cache``.
            Defaults to False.

    Raises:
        ValueError: If none of model, initial_ckpt, or model_persistor is given; if no persistor and
            no model parameters can be configured; if best_model_filename and save_filename conflict;
            or if per_site_config has a non-str key, a reserved site name, or a non-dict value.
        TypeError: If the base recipe is given a model that is neither a dict nor None and no
            persistor handles it.

    Note:
        This recipe uses InTime (streaming) aggregation for memory efficiency - each client result
        is aggregated immediately upon receipt rather than collecting all results first.
        Memory usage is constant regardless of the number of clients.

        If you want to use a custom aggregator, you can pass it in the aggregator parameter.
        The custom aggregator must be a subclass of the Aggregator class.
    """

    def __init__(
        self,
        *,
        name: str = "fedavg",
        model: Union[Any, Dict[str, Any], None] = None,
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        # Legacy parameters for backward compatibility
        aggregator: Optional[Aggregator] = None,
        aggregator_data_kind: Optional[DataKind] = DataKind.WEIGHTS,
        # Core parameters
        launch_external_process: bool = False,
        command: str = "python3 -u",
        framework: FrameworkType = FrameworkType.PYTORCH,
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        params_transfer_type: TransferType = TransferType.FULL,
        model_persistor: Optional[ModelPersistor] = None,
        per_site_config: Optional[Dict[str, Dict]] = None,
        launch_once: bool = True,
        shutdown_timeout: float = 0.0,
        key_metric: str = "accuracy",
        # New FedAvg features
        stop_cond: Optional[str] = None,
        patience: Optional[int] = None,
        best_model_filename: Optional[str] = None,
        save_filename: Optional[str] = None,
        exclude_vars: Optional[str] = None,
        aggregation_weights: Optional[Dict[str, float]] = None,
        server_memory_gc_rounds: int = 0,
        enable_tensor_disk_offload: bool = False,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # Remember what the caller explicitly passed *before* defaults are resolved,
        # so the TF/NumPy compatibility warning only fires for explicit values.
        explicit_best_model_filename = best_model_filename is not None
        explicit_save_filename = save_filename is not None
        best_model_filename, controller_save_filename = self._resolve_model_filenames(
            best_model_filename, save_filename
        )
        if framework in (FrameworkType.TENSORFLOW, FrameworkType.NUMPY) and (
            explicit_best_model_filename or explicit_save_filename
        ):
            warnings.warn(
                "best_model_filename/save_filename is accepted for API compatibility by TensorFlow and NumPy "
                "FedAvg recipes, but their default persistors do not currently create a separate best-model artifact "
                "because adding new best-model event save paths would increase model memory use.",
                UserWarning,
                stacklevel=3,
            )

        # Validate inputs internally
        v = _FedAvgValidator(
            name=name,
            model=model,
            initial_ckpt=initial_ckpt,
            min_clients=min_clients,
            num_rounds=num_rounds,
            train_script=train_script,
            train_args=train_args,
            aggregator=aggregator,
            aggregator_data_kind=aggregator_data_kind,
            launch_external_process=launch_external_process,
            command=command,
            framework=framework,
            server_expected_format=server_expected_format,
            params_transfer_type=params_transfer_type,
            model_persistor=model_persistor,
            per_site_config=per_site_config,
            launch_once=launch_once,
            shutdown_timeout=shutdown_timeout,
            key_metric=key_metric,
            stop_cond=stop_cond,
            patience=patience,
            best_model_filename=best_model_filename,
            save_filename=save_filename,
            exclude_vars=exclude_vars,
            aggregation_weights=aggregation_weights,
            server_memory_gc_rounds=server_memory_gc_rounds,
            enable_tensor_disk_offload=enable_tensor_disk_offload,
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

        self.name = v.name
        self.model = v.model
        self.initial_ckpt = v.initial_ckpt

        # Validate inputs using shared utilities
        from nvflare.recipe.utils import recipe_model_to_job_model, validate_ckpt

        validate_ckpt(self.initial_ckpt)
        if isinstance(self.model, dict):
            self.model = recipe_model_to_job_model(self.model)

        self.min_clients = v.min_clients
        self.num_rounds = v.num_rounds
        self.train_script = v.train_script
        self.train_args = v.train_args
        self.aggregator = v.aggregator
        self.aggregator_data_kind = v.aggregator_data_kind
        self.launch_external_process = v.launch_external_process
        self.command = v.command
        self.framework = v.framework
        self.server_expected_format = v.server_expected_format
        self.params_transfer_type = v.params_transfer_type
        self.model_persistor = v.model_persistor
        self.per_site_config = v.per_site_config
        self._validate_per_site_config(self.per_site_config)
        self.launch_once = v.launch_once
        self.shutdown_timeout = v.shutdown_timeout
        self.key_metric = v.key_metric
        self.stop_cond = v.stop_cond
        self.patience = v.patience
        self.best_model_filename = v.best_model_filename
        # The controller gets the filename resolved above, not the raw deprecated argument.
        self.save_filename = controller_save_filename
        self.exclude_vars = v.exclude_vars
        self.aggregation_weights = v.aggregation_weights
        self.server_memory_gc_rounds = v.server_memory_gc_rounds
        self.enable_tensor_disk_offload = v.enable_tensor_disk_offload
        self.client_memory_gc_rounds = v.client_memory_gc_rounds
        self.cuda_empty_cache = v.cuda_empty_cache

        # Validate that we have at least one model source
        # Note: Subclasses (e.g., sklearn) that manage models differently should pass
        # a model or model_persistor to satisfy this check.
        if self.model is None and self.model_persistor is None and self.initial_ckpt is None:
            raise ValueError(
                "Must provide either model, initial_ckpt, or model_persistor. "
                "Cannot create a job without a model source."
            )

        # Create BaseFedJob - all frameworks use it for consistency
        job = BaseFedJob(
            name=self.name,
            min_clients=self.min_clients,
            key_metric=self.key_metric,
        )

        # Setup framework-specific model components and persistor
        # Child classes (PT/TF wrappers) override this method for framework-specific logic
        persistor_id = self._setup_model_and_persistor(job)

        # Convert model to dict if needed (e.g., PyTorch nn.Module)
        # Only pass to controller if no persistor is handling the model
        # (persistor already handles initial model via PTModel/TFModel)
        # Note: empty string "" means no persistor, so we need model_params
        has_persistor = persistor_id != ""
        model_params = None if has_persistor else self._get_model_params()
        if not has_persistor and model_params is None:
            raise ValueError(
                "Unable to configure a model source for FedAvgRecipe: no persistor and no model parameters. "
                "Use a framework-specific recipe for checkpoint-only initialization, or provide model/model_persistor."
            )

        # Prepare aggregator for controller - must be ModelAggregator for FLModel-based aggregation
        model_aggregator = self._get_model_aggregator()

        # Add controller with InTime aggregation and all features
        controller = FedAvg(
            num_clients=self.min_clients,
            num_rounds=self.num_rounds,
            persistor_id=persistor_id,
            model=model_params,
            save_filename=self.save_filename,
            aggregator=model_aggregator,
            stop_cond=self.stop_cond,
            patience=self.patience,
            task_name="train",
            exclude_vars=self.exclude_vars,
            aggregation_weights=self.aggregation_weights,
            memory_gc_rounds=self.server_memory_gc_rounds,
            enable_tensor_disk_offload=self.enable_tensor_disk_offload,
        )
        job.to_server(controller)

        # An empty per_site_config carries no per-site information; treat it like None so the
        # job still gets a client executor instead of silently ending up with none at all.
        if self.per_site_config:
            for site_name, site_config in self.per_site_config.items():
                job.to(self._build_script_runner(site_config), site_name)
        else:
            job.to_clients(self._build_script_runner())

        Recipe.__init__(self, job)

    def _build_script_runner(self, site_config: Optional[Dict] = None) -> ScriptRunner:
        """Create the client ScriptRunner, applying optional per-site overrides.

        Args:
            site_config: Per-site override dict, or None to use only the recipe-level settings.
                A key that is missing or mapped to None falls back to the recipe attribute of
                the same name.

        Returns:
            ScriptRunner: executor configured for one site (or for all clients).
        """
        overrides = site_config or {}

        def pick(key: str) -> Any:
            # None means "not overridden"; falsy values such as False, 0.0 or "" are real overrides.
            value = overrides.get(key)
            return getattr(self, key) if value is None else value

        return ScriptRunner(
            script=pick("train_script"),
            script_args=pick("train_args"),
            launch_external_process=pick("launch_external_process"),
            command=pick("command"),
            framework=pick("framework"),
            server_expected_format=pick("server_expected_format"),
            params_transfer_type=pick("params_transfer_type"),
            launch_once=pick("launch_once"),
            shutdown_timeout=pick("shutdown_timeout"),
            # Memory settings are recipe-wide; they are not overridable per site.
            memory_gc_rounds=self.client_memory_gc_rounds,
            cuda_empty_cache=self.cuda_empty_cache,
        )

    @staticmethod
    def _resolve_model_filenames(best_model_filename: Optional[str], save_filename: Optional[str]) -> tuple[str, str]:
        """Reconcile best_model_filename with its deprecated alias save_filename.

        Args:
            best_model_filename: Caller-supplied best-model filename, or None.
            save_filename: Caller-supplied deprecated alias, or None.

        Returns:
            tuple[str, str]: ``(best_model_filename, controller_save_filename)``. When save_filename
            is None, the first falls back to DefaultCheckpointFileName.BEST_GLOBAL_MODEL and the
            second to DefaultCheckpointFileName.GLOBAL_MODEL. When save_filename is given, it is
            used for both.

        Raises:
            ValueError: If both names are given and differ.
        """
        if save_filename is None:
            resolved_best_model_filename = best_model_filename or DefaultCheckpointFileName.BEST_GLOBAL_MODEL
            controller_save_filename = best_model_filename or DefaultCheckpointFileName.GLOBAL_MODEL
            return resolved_best_model_filename, controller_save_filename

        if best_model_filename is not None and best_model_filename != save_filename:
            raise ValueError("Specify either best_model_filename or save_filename, not conflicting values for both.")

        warnings.warn(
            "save_filename is deprecated; use best_model_filename instead. FedAvg recipes treat save_filename as "
            "an alias for the best-model checkpoint filename.",
            FutureWarning,
            stacklevel=3,
        )
        return save_filename, save_filename

    @staticmethod
    def _validate_per_site_config(per_site_config: Optional[Dict[str, Dict]]) -> None:
        """Check the structure of per_site_config; None is accepted as-is.

        Raises:
            ValueError: If a key is not a str, a key is a reserved target name
                (SERVER_SITE_NAME or ALL_SITES), or a value is not a dict.
        """
        if per_site_config is None:
            return

        reserved_targets = {SERVER_SITE_NAME, ALL_SITES}
        for site_name, site_config in per_site_config.items():
            if not isinstance(site_name, str):
                raise ValueError(f"per_site_config key must be str, got {type(site_name).__name__}")
            if site_name in reserved_targets:
                raise ValueError(
                    f"'{site_name}' is a reserved target name and cannot be used in per_site_config. "
                    f"Reserved names: {sorted(reserved_targets)}"
                )
            if not isinstance(site_config, dict):
                raise ValueError(f"per_site_config['{site_name}'] must be a dict, got {type(site_config).__name__}")

    def _get_model_params(self) -> Optional[Dict]:
        """Convert model to dict of params.

        Base implementation handles dict and None. Framework-specific subclasses
        should override this to handle their model types (e.g., nn.Module, tf.keras.Model).

        Returns:
            Optional[Dict]: model parameters as dict, or None

        Raises:
            TypeError: If the model is neither a dict nor None.
        """
        if self.model is None:
            return None

        if isinstance(self.model, dict):
            return self.model

        # Unknown type - subclasses should override for framework-specific handling
        raise TypeError(
            "model must be a dict or None for the base recipe. "
            f"Got {type(self.model).__name__}. "
            "Use a framework-specific recipe (e.g., nvflare.app_opt.pt.recipes.FedAvgRecipe) "
            "for nn.Module or other model types."
        )

    def _get_model_aggregator(self):
        """Get the ModelAggregator for the FedAvg controller.

        The FedAvg controller expects a ModelAggregator (works with FLModel).
        If no aggregator is provided, returns None (uses built-in weighted averaging).
        If a ModelAggregator is provided, returns it directly.
        Any other aggregator is ignored: a warning is logged and None is returned.

        Returns:
            ModelAggregator or None
        """
        if self.aggregator is None:
            return None

        # Import here to avoid circular imports
        from nvflare.app_common.aggregators.model_aggregator import ModelAggregator

        if isinstance(self.aggregator, ModelAggregator):
            return self.aggregator

        # It's a Shareable-based Aggregator - can't use directly with FedAvg
        # Log a warning and fall back to built-in aggregation
        import logging

        logging.getLogger(__name__).warning(
            "Provided aggregator %s is not a ModelAggregator. "
            "Using built-in weighted averaging instead. For custom aggregation with FedAvg, "
            "please use a ModelAggregator subclass (e.g., from model_aggregator.py).",
            type(self.aggregator).__name__,
        )
        return None

    def _setup_numpy_model_and_persistor(self, job: BaseFedJob, *, model: Any, initial_ckpt: Optional[str]) -> str:
        """Configure NPModelPersistor for unified NumPy recipe usage.

        Args:
            job: Job to which the persistor is added on the server side.
            model: Initial model as a numpy array or list, or None.
            initial_ckpt: Optional checkpoint path, resolved through resolve_initial_ckpt.

        Returns:
            str: The persistor id extracted from ``job.to_server``.

        Raises:
            TypeError: If model is neither None, a numpy array, nor a list.
        """
        import numpy as np

        from nvflare.app_common.np.np_model_persistor import NPModelPersistor
        from nvflare.recipe.utils import extract_persistor_id, resolve_initial_ckpt

        model_list = None
        if model is not None:
            if isinstance(model, np.ndarray):
                model_list = model.tolist()
            elif isinstance(model, list):
                model_list = model
            else:
                raise TypeError(
                    f"FrameworkType.NUMPY requires model to be a numpy array or list, got {type(model).__name__}."
                )

        # _prepared_initial_ckpt is optional state; it is not set anywhere in this class.
        ckpt_path = resolve_initial_ckpt(initial_ckpt, getattr(self, "_prepared_initial_ckpt", None), job)
        persistor = NPModelPersistor(
            model=model_list,
            source_ckpt_file_full_name=ckpt_path,
        )
        persistor_id = extract_persistor_id(job.to_server(persistor, id="persistor"))
        if persistor_id and hasattr(job, "comp_ids"):
            job.comp_ids["persistor_id"] = persistor_id
        return persistor_id

    def _setup_model_and_persistor(self, job: BaseFedJob) -> str:
        """Setup generic custom persistor only.

        Framework-specific recipes (PT/TF/NumPy) override this method to build and
        register their model wrappers and default persistors. In this base
        implementation a custom model_persistor takes precedence; otherwise
        FrameworkType.NUMPY with a model or initial_ckpt gets an NPModelPersistor.

        Returns:
            str: The persistor_id to be used by the controller ("" when no persistor is set up).
        """
        from nvflare.recipe.utils import setup_custom_persistor

        persistor_id = setup_custom_persistor(job=job, model_persistor=self.model_persistor)
        if persistor_id:
            if hasattr(job, "comp_ids"):
                job.comp_ids.setdefault("persistor_id", persistor_id)
            return persistor_id

        if self.framework == FrameworkType.NUMPY and (self.model is not None or self.initial_ckpt is not None):
            return self._setup_numpy_model_and_persistor(job, model=self.model, initial_ckpt=self.initial_ckpt)

        return ""