Source code for nvflare.app_common.np.recipes.fedavg

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Union

import numpy as np

from nvflare.apis.dxo import DataKind
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.np.np_model_persistor import NPModelPersistor
from nvflare.client.config import ExchangeFormat, TransferType
from nvflare.fuel.utils.constants import FrameworkType
from nvflare.recipe.fedavg import FedAvgRecipe as UnifiedFedAvgRecipe


class NumpyFedAvgRecipe(UnifiedFedAvgRecipe):
    """Federated Averaging (FedAvg) recipe preconfigured for NumPy models.

    FedAvg combines the model updates sent by participating clients into a
    weighted average, where each client's weight reflects how much local
    training data it used. This subclass wires the unified FedAvgRecipe up
    for NumPy: it forwards every setting to the base recipe with
    ``framework=FrameworkType.NUMPY``, installs an ``NPModelPersistor`` for
    the initial model/checkpoint, and exposes ``FrameworkType.RAW`` as its
    public framework so cross-site evaluation helpers treat it as a raw
    NumPy recipe.

    What the resulting job contains:
        - An initial model on the server (needed for cross-site evaluation;
          without it clients may fail with ``KeyError: 'numpy_key'``).
        - A FedAvg controller using InTime (streaming) aggregation.
        - Optional early stopping and best-model selection.
        - Script runners that execute the client training script.

    Args:
        name: Job name. Defaults to "fedavg".
        model: Initial model weights as a Python list or ``numpy.ndarray``
            (arrays are converted with ``tolist()`` so they serialize to
            JSON). Any other type raises ``TypeError`` when the persistor is
            set up. Required unless ``initial_ckpt`` is given: the base
            FedAvgRecipe raises ``ValueError`` when model, initial_ckpt and
            model_persistor are all None.
        initial_model: Deprecated alias of ``model`` kept for 2.7-era job
            code. Ignored when ``model`` is also provided.
        initial_ckpt: Absolute path to a pre-trained ``.npy``/``.npz``
            checkpoint used to seed the initial parameters. The file need
            not exist locally; it may only be present on the server.
        min_clients: Minimum number of clients needed to start a round.
        num_rounds: Number of federated rounds. Defaults to 2.
        train_script: Path of the training script run on every client.
        train_args: Command-line arguments for the training script.
        aggregator: Optional custom aggregator. It is intended to be a
            ModelAggregator implementing ``accept_model()``,
            ``aggregate_model()`` and ``reset_stats()``. When None, the
            built-in memory-efficient weighted averaging is used.
        aggregator_data_kind: Data kind handed to the aggregator. Defaults
            to ``DataKind.WEIGHTS``; kept for backward compatibility.
        launch_external_process: Run the training script in a separate
            process. Defaults to False.
        command: Command prefixed to the script when
            ``launch_external_process`` is True. Defaults to "python3 -u".
        server_expected_format: Exchange format for parameters between
            server and clients. Defaults to ``ExchangeFormat.NUMPY``.
        params_transfer_type: ``TransferType.FULL`` sends whole parameter
            sets, ``TransferType.DIFF`` sends only the differences. Defaults
            to FULL.
        per_site_config: Optional per-site configuration overrides.
        launch_once: Launch the external process once (True) or once per
            task (False). Defaults to True.
        shutdown_timeout: Seconds to wait before shutting down. Defaults to
            0.0.
        key_metric: Metric that decides whether a global model is the best
            so far. Defaults to "accuracy".
        stop_cond: Early-stopping condition of the form
            ``"<key> <op> <value>"`` (e.g. ``"accuracy >= 80"``). None
            disables early stopping.
        patience: Rounds without improvement tolerated before stopping.
        save_filename: File name used for the best model. Defaults to
            "FL_global_model.pt".
        exclude_vars: Regex of variable names excluded from aggregation.
        aggregation_weights: Per-client aggregation weights; equal weights
            are used when None.
        client_memory_gc_rounds: Forwarded unchanged to the base
            FedAvgRecipe. Defaults to 0.
        cuda_empty_cache: Forwarded unchanged to the base FedAvgRecipe.
            Defaults to False.

    Example:
        ```python
        recipe = NumpyFedAvgRecipe(
            name="my_fedavg_job",
            model=numpy_model,
            min_clients=2,
            num_rounds=10,
            train_script="client.py",
            train_args="--learning_rate 0.01",
            stop_cond="accuracy >= 95",
            patience=3,
        )
        ```

    Note:
        InTime (streaming) aggregation folds each client result into the
        running aggregate as soon as it arrives instead of buffering all
        results first, so aggregation memory stays constant as the number
        of clients grows. By default, updates are averaged with weights
        based on each client's number of training samples; pass a custom
        ``aggregator`` to change that.
    """

    def __init__(
        self,
        *,
        name: str = "fedavg",
        model: Union[Any, Dict[str, Any], None] = None,
        initial_model: Union[Any, Dict[str, Any], None] = None,  # backward compat (2.7 / old job.py)
        initial_ckpt: Optional[str] = None,
        min_clients: int,
        num_rounds: int = 2,
        train_script: str,
        train_args: str = "",
        aggregator: Optional[Aggregator] = None,
        aggregator_data_kind: Optional[DataKind] = DataKind.WEIGHTS,
        launch_external_process: bool = False,
        command: str = "python3 -u",
        server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY,
        params_transfer_type: TransferType = TransferType.FULL,
        per_site_config: Optional[Dict[str, Dict]] = None,
        launch_once: bool = True,
        shutdown_timeout: float = 0.0,
        key_metric: str = "accuracy",
        # New FedAvg features
        stop_cond: Optional[str] = None,
        patience: Optional[int] = None,
        save_filename: str = "FL_global_model.pt",
        exclude_vars: Optional[str] = None,
        aggregation_weights: Optional[Dict[str, float]] = None,
        client_memory_gc_rounds: int = 0,
        cuda_empty_cache: bool = False,
    ):
        # `model` takes precedence; `initial_model` is only consulted when
        # `model` is None (2.7-era spelling of the same argument).
        chosen_model = initial_model if model is None else model

        # Stash these before calling the base constructor so the NumPy
        # override of _setup_model_and_persistor can read them.
        self._np_model = chosen_model
        self._np_initial_ckpt = initial_ckpt

        super().__init__(
            # Job identity and schedule
            name=name,
            min_clients=min_clients,
            num_rounds=num_rounds,
            # Initial model: the NPModelPersistor is installed by
            # _setup_model_and_persistor, so no persistor is passed here.
            model=chosen_model,
            initial_ckpt=initial_ckpt,
            model_persistor=None,
            # Client-side execution
            train_script=train_script,
            train_args=train_args,
            launch_external_process=launch_external_process,
            command=command,
            launch_once=launch_once,
            shutdown_timeout=shutdown_timeout,
            per_site_config=per_site_config,
            # Parameter exchange
            framework=FrameworkType.NUMPY,
            server_expected_format=server_expected_format,
            params_transfer_type=params_transfer_type,
            # Aggregation
            aggregator=aggregator,
            aggregator_data_kind=aggregator_data_kind,
            exclude_vars=exclude_vars,
            aggregation_weights=aggregation_weights,
            # Model selection / early stopping
            key_metric=key_metric,
            stop_cond=stop_cond,
            patience=patience,
            save_filename=save_filename,
            # Client memory management
            client_memory_gc_rounds=client_memory_gc_rounds,
            cuda_empty_cache=cuda_empty_cache,
        )

        # The base class processes the job as NUMPY internally, but external
        # APIs such as add_cross_site_evaluation expect NumPy recipes to
        # report RAW, so the public attribute is overridden afterwards.
        self.framework = FrameworkType.RAW

    def _setup_model_and_persistor(self, job) -> str:
        """Register an NPModelPersistor on the server when there is something to load.

        Args:
            job: The federated job being assembled; ``job.to_server`` is used
                to register the persistor, and ``job.comp_ids`` (if present)
                records its id.

        Returns:
            The persistor component id as a non-empty string, or "" when no
            model/checkpoint was supplied or when ``job.to_server`` returned
            something other than a non-blank string. Normalizing the return
            keeps the parent's ``persistor_id != ""`` check reliable.

        Raises:
            TypeError: If the stored model is neither a list nor a
                ``numpy.ndarray``.
        """
        # Nothing to persist: let the parent know no persistor exists.
        if self._np_model is None and self._np_initial_ckpt is None:
            return ""

        from nvflare.recipe.utils import prepare_initial_ckpt

        # NPModelPersistor expects a plain list (JSON-serializable), so
        # ndarray inputs are converted; lists pass through untouched.
        source_model = self._np_model
        if source_model is None:
            model_list = None
        elif isinstance(source_model, np.ndarray):
            model_list = source_model.tolist()
        elif isinstance(source_model, list):
            model_list = source_model
        else:
            raise TypeError(f"model must be a numpy array or list, got {type(source_model).__name__}")

        persistor = NPModelPersistor(
            model=model_list,
            source_ckpt_file_full_name=prepare_initial_ckpt(self._np_initial_ckpt, job),
        )

        component_id = job.to_server(persistor, id="persistor")

        # job.to_server may hand back None or a non-string; treat anything
        # that is not a non-blank string as "no persistor".
        if not isinstance(component_id, str) or not component_id.strip():
            return ""

        if hasattr(job, "comp_ids"):
            job.comp_ids["persistor_id"] = component_id
        return component_id

    def add_cse_validator_if_needed(self):
        """Attach an NPValidator to the clients unless validation is already explicit.

        Cross-site evaluation needs a component that really handles the
        validation task. NumPy training scripts usually only train, and a
        wildcard executor (``tasks=["*"]``) does not imply a validation
        implementation. So this scans every non-server app in the job's
        deploy map for an executor whose ``tasks`` explicitly lists
        ``AppConstants.TASK_VALIDATION``. If none is found, an NPValidator
        is registered on the clients for that task.
        """
        from nvflare.app_common.app_constant import AppConstants
        from nvflare.app_common.np.np_validator import NPValidator

        def _client_executor_defs():
            # Yield executor definitions of every non-server app that has them.
            for target, app in self.job._deploy_map.items():
                if target == "server":
                    continue
                if not (hasattr(app, "app_config") and hasattr(app.app_config, "executors")):
                    continue
                yield from app.app_config.executors

        def _names_validation(executor_def) -> bool:
            # Only an explicit listing counts; "*" does not contain the task name.
            if not hasattr(executor_def, "tasks"):
                return False
            try:
                return AppConstants.TASK_VALIDATION in executor_def.tasks
            except (TypeError, AttributeError):
                # tasks is absent or not a container: skip this executor.
                return False

        already_validated = hasattr(self.job, "_deploy_map") and any(
            _names_validation(executor_def) for executor_def in _client_executor_defs()
        )
        if already_validated:
            return

        # No explicit validator found - add NPValidator for cross-site evaluation.
        self.job.to_clients(NPValidator(), tasks=[AppConstants.TASK_VALIDATION])