nvflare.app_opt.pt.recipes.fedavg module
- class FedAvgRecipe(*, name: str = 'fedavg', model: Any | dict[str, Any] | None = None, initial_ckpt: str | None = None, min_clients: int, num_rounds: int = 2, train_script: str, train_args: str = '', aggregator: Aggregator | None = None, aggregator_data_kind: DataKind | None = DataKind.WEIGHTS, launch_external_process: bool = False, command: str = 'python3 -u', server_expected_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, model_persistor: ModelPersistor | None = None, model_locator: ModelLocator | None = None, per_site_config: dict[str, dict] | None = None, launch_once: bool = True, shutdown_timeout: float = 0.0, key_metric: str = 'accuracy', stop_cond: str | None = None, patience: int | None = None, best_model_filename: str | None = None, save_filename: str | None = None, exclude_vars: str | None = None, aggregation_weights: dict[str, float] | None = None, server_memory_gc_rounds: int = 0, enable_tensor_disk_offload: bool = False, client_memory_gc_rounds: int = 0, cuda_empty_cache: bool = False)[source]
Bases:
FedAvgRecipe
A recipe for implementing Federated Averaging (FedAvg) for PyTorch.
FedAvg is a fundamental federated learning algorithm that aggregates model updates from multiple clients by computing a weighted average based on the amount of local training data. This recipe sets up a complete federated learning workflow with memory-efficient InTime aggregation.
The recipe configures:
- A federated job with initial model (optional)
- FedAvg controller with InTime aggregation for memory efficiency
- Optional early stopping and model selection
- Script runners for client-side training execution
- Parameters:
name – Name of the federated learning job. Defaults to “fedavg”.
model – Initial model to start federated training with. Can be:
- nn.Module instance
- Dict config: {"class_path": "module.ClassName", "args": {"param": value}}
- None: no initial model
initial_ckpt – Absolute path to a pre-trained checkpoint file, used to load the initial weights. The file need not exist locally, since it may reside on the server. Note: for PyTorch, model must also be provided when initial_ckpt is used, so that the architecture is known.
min_clients – Minimum number of clients required to start a training round.
num_rounds – Number of federated training rounds to execute. Defaults to 2.
train_script – Path to the training script that will be executed on each client.
train_args – Command line arguments to pass to the training script.
aggregator – Custom aggregator (ModelAggregator) for combining client model updates. Must implement accept_model(), aggregate_model(), reset_stats() methods. If None, uses built-in memory-efficient weighted averaging.
aggregator_data_kind – Data kind to use for the aggregator. Defaults to DataKind.WEIGHTS. Kept for backward compatibility.
launch_external_process (bool) – Whether to launch the script in an external process. Defaults to False.
command (str) – If launch_external_process=True, command to run script (prepended to script). Defaults to “python3 -u”.
server_expected_format (str) – Format in which parameters are exchanged between server and client. Defaults to ExchangeFormat.NUMPY.
params_transfer_type (str) – How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL.
model_persistor – Custom model persistor. If None, PTFileModelPersistor will be used.
model_locator – Custom model locator. If None, PTFileModelLocator will be used.
per_site_config – Per-site configuration for the federated learning job.
launch_once – Whether the external process is launched once for the whole job (True) or once per task (False). Defaults to True.
shutdown_timeout – Seconds to wait before shutdown. Defaults to 0.0.
key_metric – Metric used to determine if the model is globally best. Defaults to “accuracy”.
stop_cond – Early stopping condition based on metric. String literal in the format of ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). If None, early stopping is disabled.
patience – Number of rounds with no improvement after which FL will be stopped.
best_model_filename – Filename for saving the best model. If unset, the default PyTorch persistor uses DefaultCheckpointFileName.BEST_GLOBAL_MODEL.
save_filename – Deprecated alias for best_model_filename. If both are specified, they must match.
exclude_vars – Regex pattern for variables to exclude from aggregation.
aggregation_weights – Per-client aggregation weights dict. Defaults to equal weights.
enable_tensor_disk_offload – Enable disk-backed tensor offload for incoming streamed payloads.
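The '<key> <op> <value>' format of stop_cond can be illustrated with a small parser. This is only a sketch of how such a string might be interpreted, not NVFlare's own implementation: the helper names parse_stop_cond and should_stop are hypothetical, and the set of supported operators is an assumption.

```python
import operator

# Hypothetical operator table; the operators NVFlare actually accepts may differ.
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_stop_cond(stop_cond):
    """Split '<key> <op> <value>' into (key, comparison function, float value)."""
    key, op, value = stop_cond.split()
    if op not in _OPS:
        raise ValueError(f"unsupported operator: {op!r}")
    return key, _OPS[op], float(value)


def should_stop(metrics, stop_cond):
    """Return True when the named metric is present and satisfies the condition."""
    key, compare, target = parse_stop_cond(stop_cond)
    return key in metrics and compare(metrics[key], target)
```

With this sketch, should_stop({"accuracy": 81.0}, "accuracy >= 80") evaluates to True, while a round reporting an accuracy of 79.0 would let training continue.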
Example
Basic usage with early stopping:
```python
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe

recipe = FedAvgRecipe(
    name="my_fedavg_job",
    model=pretrained_model,
    min_clients=2,
    num_rounds=10,
    train_script="client.py",
    train_args="--epochs 5 --batch_size 32",
    stop_cond="accuracy >= 95",
    patience=3,
)
```
Note
This recipe uses InTime (streaming) aggregation for memory efficiency - each client result is aggregated immediately upon receipt rather than collecting all results first. Memory usage is constant regardless of the number of clients.
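The idea behind streaming aggregation can be sketched in plain Python. This is an illustration of a running weighted average under simplifying assumptions (scalar parameters, sample counts used as weights), not the aggregator NVFlare ships; the class name StreamingWeightedAverage and its methods are hypothetical.

```python
class StreamingWeightedAverage:
    """Illustrative running weighted average; not NVFlare's implementation."""

    def __init__(self):
        self.weighted_sum = {}  # parameter name -> running sum of weight * value
        self.total_weight = 0.0

    def accept(self, params, weight):
        """Fold one client's parameters into the running sums."""
        for name, value in params.items():
            self.weighted_sum[name] = self.weighted_sum.get(name, 0.0) + weight * value
        self.total_weight += weight

    def result(self):
        """Return the weighted average of everything accepted so far."""
        if self.total_weight == 0:
            raise ValueError("no client results accepted")
        return {
            name: total / self.total_weight
            for name, total in self.weighted_sum.items()
        }


agg = StreamingWeightedAverage()
agg.accept({"w": 1.0, "b": 0.0}, weight=100)  # client with 100 samples
agg.accept({"w": 3.0, "b": 2.0}, weight=300)  # client with 300 samples
global_params = agg.result()
```

Only the running sums and the total weight are retained between calls, so the state held by the aggregator has the size of one model regardless of how many clients report.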
This is the base class of a recipe. Recipes are implemented by jobs, so a concrete recipe must provide the job that implements it.
- Parameters:
job – The job that implements the recipe.