nvflare.recipe.fedavg module
- class FedAvgRecipe(*, name: str = 'fedavg', model: Any | Dict[str, Any] | None = None, initial_ckpt: str | None = None, min_clients: int, num_rounds: int = 2, train_script: str, train_args: str = '', aggregator: Aggregator | None = None, aggregator_data_kind: DataKind | None = DataKind.WEIGHTS, launch_external_process: bool = False, command: str = 'python3 -u', framework: FrameworkType = FrameworkType.PYTORCH, server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, model_persistor: ModelPersistor | None = None, per_site_config: Dict[str, Dict] | None = None, launch_once: bool = True, shutdown_timeout: float = 0.0, key_metric: str = 'accuracy', stop_cond: str | None = None, patience: int | None = None, best_model_filename: str | None = None, save_filename: str | None = None, exclude_vars: str | None = None, aggregation_weights: Dict[str, float] | None = None, server_memory_gc_rounds: int = 0, enable_tensor_disk_offload: bool = False, client_memory_gc_rounds: int = 0, cuda_empty_cache: bool = False)
Bases: Recipe
Unified FedAvg recipe for PyTorch, TensorFlow, and Scikit-learn.
FedAvg is a fundamental federated learning algorithm that aggregates model updates from multiple clients by computing a weighted average based on the amount of local training data. This recipe sets up a complete federated learning workflow with memory-efficient InTime aggregation.
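In the standard FedAvg formulation, with K clients where client k trained on n_k local samples and returned weights w_k, the new global weights are the sample-weighted mean below. This is the textbook formula; exactly how the recipe obtains each client's weight (for example, a count reported by the client) is not specified here.

```latex
w^{(t+1)} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_k^{(t+1)}, \qquad n = \sum_{k=1}^{K} n_k
```

For two clients with n_1 = 100 and n_2 = 300 reporting the scalar weights 1.0 and 3.0, the result is 0.25 × 1.0 + 0.75 × 3.0 = 2.5.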
The recipe configures:
- A federated job with an initial model (optional)
- A FedAvg controller with InTime aggregation for memory efficiency
- Optional early stopping and model selection
- Script runners for client-side training execution
- Parameters:
name – Name of the federated learning job. Defaults to “fedavg”.
model – Initial model to start federated training with. Can be:
  - A model instance (nn.Module, tf.keras.Model, etc.)
  - A dict config: {“class_path”: “module.ClassName”, “args”: {“param”: value}}
  - None: no initial model
For framework-specific types (nn.Module, tf.keras.Model), use the corresponding framework recipe (e.g., nvflare.app_opt.pt.recipes.FedAvgRecipe).
initial_ckpt – Absolute path to a pre-trained checkpoint file used to load the initial weights. The file does not need to exist locally, since it may reside on the server.
min_clients – Minimum number of clients required to start a training round.
num_rounds – Number of federated training rounds to execute. Defaults to 2.
train_script – Path to the training script that will be executed on each client.
train_args – Command line arguments to pass to the training script.
aggregator – Custom aggregator (ModelAggregator) for combining client model updates. Must implement accept_model(), aggregate_model(), reset_stats() methods. If None, uses built-in memory-efficient weighted averaging. Defaults to None.
aggregator_data_kind – Data kind for aggregation (DataKind.WEIGHTS or DataKind.WEIGHT_DIFF). Kept for backward compatibility. Defaults to DataKind.WEIGHTS.
launch_external_process – Whether to launch the training script in an external process. Defaults to False.
command – Command used to run the script when launch_external_process=True; it is prepended to the script path. Defaults to “python3 -u”.
framework – The framework type. One of:
  - FrameworkType.PYTORCH (default)
  - FrameworkType.TENSORFLOW
  - FrameworkType.NUMPY
  - FrameworkType.RAW (for custom frameworks, e.g., sklearn, XGBoost)
server_expected_format – Format in which parameters are exchanged between server and client. Defaults to ExchangeFormat.NUMPY.
params_transfer_type – How parameters are transferred. FULL sends the complete model parameters; DIFF sends only the difference relative to the received global model. Defaults to TransferType.FULL.
model_persistor – Custom model persistor for any framework. If None, uses the framework’s default persistor when one is available.
per_site_config – Per-site configuration for the federated learning job: a dictionary mapping site names to configuration dicts. Each config dict can contain the following optional overrides:
  - train_script (str): Training script path
  - train_args (str): Script arguments
  - launch_external_process (bool): Whether to launch an external process
  - command (str): Command prefix for the external process
  - framework (FrameworkType): Framework type
  - server_expected_format (ExchangeFormat): Exchange format
  - params_transfer_type (TransferType): Parameter transfer type
  - launch_once (bool): Whether to launch the external process once or per task
  - shutdown_timeout (float): Shutdown timeout in seconds
If not provided, the same configuration is used for all clients.
launch_once – Whether the external process will be launched only once at the beginning or on each task. Only used if launch_external_process is True. Defaults to True.
shutdown_timeout – Number of seconds to wait before shutdown. Only used if launch_external_process is True. Defaults to 0.0.
key_metric – Metric used to determine if the model is globally best. If validation metrics are a dict, key_metric selects the metric used for global model selection by the IntimeModelSelector. Defaults to “accuracy”.
stop_cond – Early stopping condition based on metric. String literal in the format of ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). If None, early stopping is disabled.
patience – Number of rounds with no improvement after which FL will be stopped. Only applies if stop_cond is set. Defaults to None.
best_model_filename – Filename for saving the best model. If unset, framework persistors that expose a separate best-model filename use their own default, such as DefaultCheckpointFileName.BEST_GLOBAL_MODEL for the default PyTorch persistor.
save_filename – Deprecated alias for best_model_filename. If both are specified, they must match.
exclude_vars – Regex pattern for variables to exclude from aggregation.
aggregation_weights – Per-client aggregation weights dict. Defaults to equal weights.
server_memory_gc_rounds – Run memory cleanup (gc.collect + malloc_trim) every N rounds on server. Set to 0 to disable. Defaults to 0.
enable_tensor_disk_offload – Enable disk-backed tensor offload for incoming streamed payloads. When True, the server receives tensor payloads via temporary files and materializes them lazily. Defaults to False.
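The dict form of the model parameter names a class by its import path plus constructor arguments. A minimal sketch of resolving such a config with importlib, assuming "args" is passed as keyword arguments; instantiate_from_config is a hypothetical helper, not the recipe's loader, and a standard-library class stands in for a model class so the example runs without any framework installed.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Build an object from {"class_path": "module.ClassName", "args": {...}}."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    if not module_name:
        raise ValueError("class_path must look like 'module.ClassName'")
    cls = getattr(importlib.import_module(module_name), class_name)
    # "args" is optional; its entries become keyword arguments (an assumption).
    return cls(**config.get("args", {}))


# datetime.timedelta stands in for a model class such as a network definition.
delay = instantiate_from_config({"class_path": "datetime.timedelta", "args": {"minutes": 5}})
print(delay)  # 0:05:00
```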
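The two params_transfer_type modes can be illustrated with plain lists: under DIFF the client sends its locally trained weights minus the global weights it received, and the receiver adds the difference back. make_diff and apply_diff are hypothetical helpers; this sketch does not show how the recipe combines DIFF with aggregator_data_kind.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def make_diff(local: Params, global_model: Params) -> Params:
    """Client side (DIFF): send local - global instead of the full weights."""
    return {k: [lo - g for lo, g in zip(local[k], global_model[k])] for k in local}


def apply_diff(global_model: Params, diff: Params) -> Params:
    """Receiver side: recover the client's weights as global + diff."""
    return {k: [g + d for g, d in zip(global_model[k], diff[k])] for k in global_model}


global_model = {"w": [1.0, 2.0]}
local = {"w": [1.5, 1.0]}
diff = make_diff(local, global_model)
print(diff)                            # {'w': [0.5, -1.0]}
print(apply_diff(global_model, diff))  # {'w': [1.5, 1.0]}
```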
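To illustrate how per_site_config overrides interact with command, train_script, and train_args when launch_external_process=True, the sketch below merges one site's overrides over recipe-level values and builds a command line with the command prepended to the script. RECIPE_DEFAULTS and site_command are hypothetical, and the argument order is an assumption. For simplicity, a site without an entry falls back to the defaults here; this says nothing about whether the recipe deploys to unlisted sites at all.

```python
import shlex
from typing import Any, Dict, List, Optional

# Hypothetical recipe-level values that per-site entries may override.
RECIPE_DEFAULTS: Dict[str, Any] = {
    "train_script": "client.py",
    "train_args": "",
    "command": "python3 -u",
}


def site_command(site: str, per_site_config: Optional[Dict[str, Dict[str, Any]]] = None) -> List[str]:
    """Merge a site's overrides over the defaults and build its command line."""
    cfg = dict(RECIPE_DEFAULTS)
    cfg.update((per_site_config or {}).get(site, {}))
    # Assumed layout: command, then the script path, then the script arguments.
    return shlex.split(cfg["command"]) + [cfg["train_script"]] + shlex.split(cfg["train_args"])


per_site = {"site-2": {"train_args": "--batch_size 64", "command": "python3"}}
print(site_command("site-1", per_site))  # ['python3', '-u', 'client.py']
print(site_command("site-2", per_site))  # ['python3', 'client.py', '--batch_size', '64']
```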
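key_metric, stop_cond, and patience together drive early stopping. The sketch below parses a '<key> <op> <value>' condition with the operator module and counts rounds without a strictly better key_metric. EarlyStopSketch is hypothetical; the supported operators, the "higher is better" rule, and checking both criteria every round are assumptions, and NVFlare's own evaluation may differ.

```python
import operator
from typing import Callable, Dict, Optional

# Assumed operator set for '<key> <op> <value>' conditions.
_OPS: Dict[str, Callable[[float, float], bool]] = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


class EarlyStopSketch:
    """Stop when stop_cond holds, or after `patience` rounds without improvement."""

    def __init__(
        self, stop_cond: str, patience: Optional[int] = None, key_metric: str = "accuracy"
    ) -> None:
        key, op, value = stop_cond.split()  # e.g. "accuracy >= 80"
        if op not in _OPS:
            raise ValueError(f"unsupported operator: {op}")
        self.cond_key, self.cond_op, self.cond_value = key, _OPS[op], float(value)
        self.patience = patience
        self.key_metric = key_metric
        self.best: Optional[float] = None
        self.stale_rounds = 0

    def should_stop(self, metrics: Dict[str, float]) -> bool:
        """Feed one round's validation metrics; return True to stop training."""
        value = metrics[self.key_metric]
        if self.best is None or value > self.best:  # assumes higher is better
            self.best, self.stale_rounds = value, 0
        else:
            self.stale_rounds += 1
        if self.cond_key in metrics and self.cond_op(metrics[self.cond_key], self.cond_value):
            return True
        return self.patience is not None and self.stale_rounds >= self.patience


stopper = EarlyStopSketch("accuracy >= 80", patience=2)
history = [{"accuracy": 70.0}, {"accuracy": 75.0}, {"accuracy": 74.0}, {"accuracy": 75.0}]
print([stopper.should_stop(m) for m in history])  # [False, False, False, True]
```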
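aggregation_weights and exclude_vars both change how client parameters are combined. In this sketch, parameters whose names match the exclude_vars regex (via re.search) are left out of the result, and a client's effective weight is its entry in aggregation_weights (default 1.0) times its sample count. aggregate is a hypothetical helper; the search semantics, the weight composition, and simply dropping excluded names (rather than, say, keeping previous global values) are all assumptions.

```python
import re
from typing import Dict, List, Optional

Params = Dict[str, List[float]]


def aggregate(
    updates: Dict[str, Params],  # client name -> parameters
    num_samples: Dict[str, int],  # client name -> local sample count
    aggregation_weights: Optional[Dict[str, float]] = None,
    exclude_vars: Optional[str] = None,
) -> Params:
    """Weighted average that skips parameters whose names match exclude_vars."""
    per_client = aggregation_weights or {}
    # Assumption: effective weight = per-client weight (default 1.0) * sample count.
    eff = {c: per_client.get(c, 1.0) * num_samples[c] for c in updates}
    total = sum(eff.values())
    pattern = re.compile(exclude_vars) if exclude_vars else None
    first = next(iter(updates.values()))
    result: Params = {}
    for name, values in first.items():
        if pattern is not None and pattern.search(name):
            continue  # excluded: not part of the aggregated result
        result[name] = [
            sum(eff[c] * updates[c][name][i] for c in updates) / total
            for i in range(len(values))
        ]
    return result


updates = {
    "site-1": {"fc.weight": [0.0, 2.0], "bn.running_mean": [5.0]},
    "site-2": {"fc.weight": [4.0, 2.0], "bn.running_mean": [9.0]},
}
num_samples = {"site-1": 10, "site-2": 10}
print(aggregate(updates, num_samples))
# {'fc.weight': [2.0, 2.0], 'bn.running_mean': [7.0]}
print(aggregate(updates, num_samples, {"site-2": 3.0}, exclude_vars=r"running_"))
# {'fc.weight': [3.0, 2.0]}
```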
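server_memory_gc_rounds triggers periodic cleanup with gc.collect and malloc_trim. A sketch of that schedule, assuming 0-based round indices and cleanup after every Nth completed round; maybe_cleanup is hypothetical. malloc_trim is a glibc function, so the sketch calls it through ctypes and skips it on other C libraries.

```python
import ctypes
import gc


def maybe_cleanup(round_idx: int, every_n_rounds: int) -> bool:
    """Run gc.collect() and, on glibc, malloc_trim(0) after every Nth round."""
    if every_n_rounds <= 0 or (round_idx + 1) % every_n_rounds != 0:
        return False  # 0 disables cleanup; otherwise wait for the Nth round
    gc.collect()
    try:
        # malloc_trim asks glibc to return free heap memory to the OS.
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        pass  # not glibc (e.g. macOS, Windows, musl): skip the trim
    return True


print([maybe_cleanup(r, every_n_rounds=2) for r in range(4)])  # [False, True, False, True]
```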
Note
This recipe uses InTime (streaming) aggregation for memory efficiency: each client result is aggregated immediately upon receipt rather than after all results have been collected, so memory usage stays constant regardless of the number of clients.
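The streaming idea can be sketched as a running weighted sum: each result is folded into one accumulator when it arrives, so the server holds a single model-sized buffer plus a total weight instead of every client's update. StreamingWeightedMean is hypothetical and takes the per-client weight as an input; it is not the recipe's aggregator.

```python
from typing import Dict, List, Optional

Params = Dict[str, List[float]]


class StreamingWeightedMean:
    """Fold each client result into a running weighted sum as it arrives."""

    def __init__(self) -> None:
        self._sum: Optional[Params] = None
        self._total_weight = 0.0

    def accept(self, params: Params, weight: float) -> None:
        if self._sum is None:
            # First result: initialize the accumulator with new lists (params is not kept).
            self._sum = {k: [v * weight for v in vals] for k, vals in params.items()}
        else:
            for k, vals in params.items():
                acc = self._sum[k]
                for i, v in enumerate(vals):
                    acc[i] += v * weight
        self._total_weight += weight

    def result(self) -> Params:
        if self._sum is None or self._total_weight == 0:
            raise RuntimeError("no weighted results received")
        return {k: [v / self._total_weight for v in acc] for k, acc in self._sum.items()}


agg = StreamingWeightedMean()
agg.accept({"w": [1.0, 2.0]}, weight=100)  # client A
agg.accept({"w": [3.0, 4.0]}, weight=300)  # client B
print(agg.result())  # {'w': [2.5, 3.5]}
```

The result matches collecting both updates first and averaging them, up to floating-point rounding; only the order of work changes.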
If you want to use a custom aggregator, you can pass it in the aggregator parameter. The custom aggregator must be a subclass of the Aggregator class.
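As an example of what a custom aggregator might do, the sketch below computes a coordinate-wise median, which limits the influence of a single outlying client. It is a framework-free stand-in that does not subclass NVFlare's Aggregator; the method names accept_model, aggregate_model, and reset_stats come from the aggregator parameter description, but their arguments and return values here are assumptions.

```python
import statistics
from typing import Dict, List

Params = Dict[str, List[float]]


class MedianAggregatorSketch:
    """Coordinate-wise median of client updates (hypothetical, framework-free)."""

    def __init__(self) -> None:
        self._received: List[Params] = []

    def accept_model(self, params: Params) -> bool:
        # Unlike streaming averaging, every update must be kept until aggregation.
        self._received.append(params)
        return True

    def aggregate_model(self) -> Params:
        if not self._received:
            raise RuntimeError("no client updates received")
        names = self._received[0].keys()
        return {
            name: [statistics.median(column) for column in zip(*(p[name] for p in self._received))]
            for name in names
        }

    def reset_stats(self) -> None:
        # Clear stored updates before the next round.
        self._received.clear()


agg = MedianAggregatorSketch()
for update in ({"w": [1.0, 10.0]}, {"w": [2.0, 20.0]}, {"w": [100.0, 30.0]}):
    agg.accept_model(update)
print(agg.aggregate_model())  # {'w': [2.0, 20.0]}
agg.reset_stats()
```

Because a median needs every update before it can be computed, this aggregator's memory grows with the number of clients, unlike the default InTime aggregation.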
From the Recipe base class: recipes are implemented by jobs, and a concrete recipe must provide the job that implements it.
- Parameters:
job – the job that implements the recipe.