nvflare.app_common.workflows.fedavg module

class FedAvg(*args, model: Dict | FLModel | None = None, save_filename: str | None = 'FL_global_model.pt', aggregator: ModelAggregator | None = None, stop_cond: str | None = None, patience: int | None = None, task_name: str | None = 'train', exclude_vars: str | None = None, aggregation_weights: Dict[str, float] | None = None, enable_tensor_disk_offload: bool = False, **kwargs)

Bases: BaseFedAvg

Controller for FedAvg Workflow with optional Early Stopping and Model Selection.

Note: This class is based on the ModelController. Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

Uses InTime (streaming) aggregation for memory efficiency: each client result is aggregated immediately upon receipt, rather than after all results have been collected.
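The streaming idea can be illustrated with a running weighted sum. This is a minimal sketch, not the class's actual implementation: `RunningWeightedAverage` and its methods are hypothetical names, and plain lists of floats stand in for tensors.

```python
from typing import Dict, List


class RunningWeightedAverage:
    """Hypothetical helper: folds each result into running sums on arrival."""

    def __init__(self) -> None:
        self.weighted_sums: Dict[str, List[float]] = {}
        self.total_weight = 0.0

    def add(self, params: Dict[str, List[float]], weight: float) -> None:
        # Fold one client's parameters into the running sums.
        # The caller can discard `params` immediately afterwards.
        for name, values in params.items():
            acc = self.weighted_sums.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                acc[i] += weight * v
        self.total_weight += weight

    def result(self) -> Dict[str, List[float]]:
        # Divide once at the end to obtain the weighted mean.
        if self.total_weight == 0:
            raise ValueError("no results were added")
        return {
            name: [s / self.total_weight for s in sums]
            for name, sums in self.weighted_sums.items()
        }


avg = RunningWeightedAverage()
avg.add({"w": [1.0, 2.0]}, weight=1.0)  # first client result arrives
avg.add({"w": [3.0, 6.0]}, weight=3.0)  # second client result arrives
global_params = avg.result()
```

Only one accumulator per parameter is held at any time, so peak memory stays near one model's size regardless of the number of clients.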

Supports custom aggregators via the ModelAggregator interface.
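As a rough illustration of the shape such an aggregator might take, the sketch below implements the three methods listed for the aggregator parameter (accept_model(), aggregate_model(), reset_stats()). It is a duck-typed stand-in, not a real ModelAggregator subclass: the class name `MedianAggregator` is hypothetical, and plain dicts of float lists are assumed in place of FLModel objects.

```python
from statistics import median
from typing import Dict, List

Params = Dict[str, List[float]]


class MedianAggregator:
    """Stand-in exposing the three method names; not an NVFlare class."""

    def __init__(self) -> None:
        self.received: List[Params] = []

    def accept_model(self, params: Params) -> None:
        # Called once per client result. A median needs every value,
        # so results are buffered here rather than folded in on arrival.
        self.received.append(params)

    def aggregate_model(self) -> Params:
        # Element-wise median across all buffered client results.
        if not self.received:
            raise ValueError("no client results to aggregate")
        first = self.received[0]
        return {
            name: [
                median(p[name][i] for p in self.received)
                for i in range(len(first[name]))
            ]
            for name in first
        }

    def reset_stats(self) -> None:
        # Clear the buffer so the next round starts empty.
        self.received.clear()


aggregator = MedianAggregator()
for update in ({"w": [1.0, 9.0]}, {"w": [5.0, 3.0]}, {"w": [2.0, 4.0]}):
    aggregator.accept_model(update)
robust = aggregator.aggregate_model()
aggregator.reset_stats()
```

Because this sketch buffers every result, it gives up the memory benefit of streaming aggregation; that trade-off is inherent to order statistics such as the median.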

Provides the implementations for the run routine, controlling the main workflow:
  • def run(self)

The parent classes provide the default implementations for other routines.

For simple model persistence without complex ModelPersistor setup, you can:
  1. Pass model (dict of params) and save_filename
  2. Override save_model() and load_model() for framework-specific serialization

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”. If empty and model is provided, uses simple save_model/load_model methods.

  • model (dict or FLModel, optional) – Initial model parameters. If provided, this is used instead of loading from persistor. Defaults to None.

  • save_filename (str, optional) – Filename for saving the best model. Defaults to “FL_global_model.pt”. Only used when persistor_id is empty.

  • aggregator (ModelAggregator, optional) – Custom aggregator for combining client model updates. Must implement accept_model(), aggregate_model(), reset_stats(). If None, uses built-in weighted averaging (memory-efficient). Defaults to None.

  • stop_cond (str, optional) – Early stopping condition based on metric. String literal in the format of ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). If None, early stopping is disabled. Defaults to None.

  • patience (int, optional) – The number of rounds with no improvement after which FL will be stopped. Only applies if stop_cond is set. Defaults to None.

  • task_name (str, optional) – Task name for training. Defaults to “train”.

  • exclude_vars (str, optional) – Regex pattern for variables to exclude from aggregation. Defaults to None. Only used when no custom aggregator is provided.

  • aggregation_weights (dict, optional) – Per-client aggregation weights. Defaults to None (equal weights). Only used when no custom aggregator is provided.

  • enable_tensor_disk_offload (bool, optional) – Download tensors to disk during FOBS streaming instead of deserializing into memory. Reduces peak server memory from ~N× to ~1× model size during aggregation. When used with a custom aggregator, lazy refs are passed through directly and must be handled by that aggregator. Defaults to False.
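How exclude_vars and aggregation_weights might interact in a weighted average can be sketched as follows. This is an assumption-laden illustration: the function `weighted_average` is hypothetical, scalar floats stand in for tensors, the regex is assumed to be applied with `re.search` against each variable name, and the built-in aggregator may combine these weights with other factors.

```python
import re
from typing import Dict, Optional


def weighted_average(
    results: Dict[str, Dict[str, float]],
    aggregation_weights: Optional[Dict[str, float]] = None,
    exclude_vars: Optional[str] = None,
) -> Dict[str, float]:
    """Hypothetical helper: per-client weighted mean of named scalars."""
    pattern = re.compile(exclude_vars) if exclude_vars else None
    weights = aggregation_weights or {}
    sums: Dict[str, float] = {}
    totals: Dict[str, float] = {}
    for client, params in results.items():
        weight = weights.get(client, 1.0)  # unlisted clients count equally
        for name, value in params.items():
            if pattern is not None and pattern.search(name):
                continue  # variable is excluded from aggregation
            sums[name] = sums.get(name, 0.0) + weight * value
            totals[name] = totals.get(name, 0.0) + weight
    # Skip variables whose total weight is zero to avoid dividing by zero.
    return {n: sums[n] / totals[n] for n in sums if totals[n] > 0}


results = {
    "site-1": {"conv.weight": 1.0, "bn.running_mean": 10.0},
    "site-2": {"conv.weight": 4.0, "bn.running_mean": 20.0},
}
merged = weighted_average(
    results,
    aggregation_weights={"site-1": 1.0, "site-2": 2.0},
    exclude_vars=r"running_",
)
```

Here `bn.running_mean` matches the pattern and is left out, while `conv.weight` is averaged with site-2 counted twice as heavily as site-1.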

The base controller for FedAvg Workflow. Note: This class is based on the ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor is used to load the initial global model, which is sent to a list of clients. Each client sends its updated weights after local training, and these are aggregated. Next, the global model is updated. The model persistor will also save the model after training.

Provides the default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here (we will remove this argument in the next release).

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).

is_curr_model_better(curr_model: FLModel) -> bool

Checks if the new model is better than the current best model.

Parameters:

curr_model (FLModel) – the new model to evaluate.

Returns:

True if the new model is better than the current best model, False otherwise
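The comparison can be pictured as tracking the best value seen so far for one metric. The sketch below is illustrative only: `BestModelTracker` is a hypothetical name, metrics are passed as a plain dict rather than through an FLModel, and the comparison function is supplied by the caller, whereas the class presumably derives the metric and direction from stop_cond.

```python
import operator
from typing import Callable, Dict, Optional


class BestModelTracker:
    """Hypothetical tracker for the best value of a single metric."""

    def __init__(self, key: str, better: Callable[[float, float], bool]) -> None:
        self.key = key
        self.better = better
        self.best_value: Optional[float] = None

    def is_better(self, metrics: Optional[Dict[str, float]]) -> bool:
        # A model without the tracked metric cannot be judged better.
        if not metrics or self.key not in metrics:
            return False
        value = metrics[self.key]
        # The first reported value always becomes the best so far.
        if self.best_value is None or self.better(value, self.best_value):
            self.best_value = value
            return True
        return False


tracker = BestModelTracker("accuracy", operator.gt)
history = [tracker.is_better({"accuracy": a}) for a in (0.70, 0.82, 0.80)]
```

With `operator.gt`, a tie does not count as an improvement; passing `operator.ge` instead would treat an equal value as better.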

load_model() -> FLModel

Load model. Uses persistor if available, otherwise uses load_model_file.

Override load_model_file for framework-specific deserialization (e.g., torch.load).

Returns:

loaded model, or None if loading fails

Return type:

FLModel

load_model_file(filepath: str) -> FLModel

Load model from file. Override this for framework-specific deserialization.

Default implementation uses FOBS (pickle-compatible). For PyTorch, override with: FLModel(params=torch.load(filepath))

Parameters:

filepath (str) – path to load the model from

Returns:

loaded model

Return type:

FLModel

run() -> None

Main run routine for the controller workflow.

save_model(model: FLModel) -> None

Save model. Uses persistor if available, otherwise uses save_model_file.

Override save_model_file for framework-specific serialization (e.g., torch.save).

Parameters:

model (FLModel) – model to save

save_model_file(model: FLModel, filepath: str) -> None

Save model to file. Override this for framework-specific serialization.

Default implementation uses FOBS (pickle-compatible). For PyTorch, override with: torch.save(model.params, filepath)

Parameters:
  • model (FLModel) – model to save

  • filepath (str) – path to save the model
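The override pattern for save_model_file and load_model_file can be sketched without any framework dependency. In this illustration `Model` and `BaseController` are hypothetical stand-ins for FLModel and the controller, and JSON is used in place of a framework serializer; for PyTorch, the two method bodies would instead use the torch.save and torch.load calls quoted in the method descriptions.

```python
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Model:
    """Stand-in for FLModel; only the params attribute is modelled."""
    params: Dict[str, List[float]] = field(default_factory=dict)


class BaseController:
    """Stand-in for the controller whose file hooks get overridden."""

    def save_model_file(self, model: Model, filepath: str) -> None:
        raise NotImplementedError

    def load_model_file(self, filepath: str) -> Model:
        raise NotImplementedError


class JsonController(BaseController):
    def save_model_file(self, model: Model, filepath: str) -> None:
        # Serialize only the parameters, mirroring torch.save(model.params, ...).
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model.params, f)

    def load_model_file(self, filepath: str) -> Model:
        # Wrap the loaded parameters back into a model object.
        with open(filepath, "r", encoding="utf-8") as f:
            return Model(params=json.load(f))


controller = JsonController()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "FL_global_model.json")
    controller.save_model_file(Model(params={"w": [0.5, 1.5]}), path)
    restored = controller.load_model_file(path)
```

Keeping the two overrides symmetric (what one writes, the other reads) is the main design point; the file format itself is interchangeable.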

should_stop(metrics: Dict | None = None) -> bool

Checks whether the current FL experiment should stop.

Parameters:

metrics (Dict, optional) – experiment metrics.

Returns:

True if the experiment should stop, False otherwise.
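A stop check built on the '<key> <op> <value>' format of stop_cond might look like the sketch below. The helper names `parse_stop_cond` and `should_stop`, the set of supported operators, the handling of a missing metric, and the way patience is counted are all assumptions made for illustration, not the library's actual logic.

```python
import operator
from typing import Callable, Dict, Optional, Tuple

# Assumed operator set; the real parser may accept more or fewer.
OPS: Dict[str, Callable[[float, float], bool]] = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_stop_cond(stop_cond: str) -> Tuple[str, Callable[[float, float], bool], float]:
    """Split '<key> <op> <value>' into its three parts."""
    parts = stop_cond.split()
    if len(parts) != 3 or parts[1] not in OPS:
        raise ValueError(f"expected '<key> <op> <value>', got {stop_cond!r}")
    key, op, value = parts
    return key, OPS[op], float(value)


def should_stop(
    metrics: Optional[Dict[str, float]],
    stop_cond: Optional[str],
    rounds_without_improvement: int = 0,
    patience: Optional[int] = None,
) -> bool:
    # Without a stop condition, early stopping (and patience) is disabled.
    if stop_cond is None or not metrics:
        return False
    key, op_fn, target = parse_stop_cond(stop_cond)
    if key in metrics and op_fn(metrics[key], target):
        return True  # target metric reached
    # Otherwise stop only once patience has been exhausted.
    return patience is not None and rounds_without_improvement >= patience


reached = should_stop({"accuracy": 85.0}, "accuracy >= 80")
not_yet = should_stop({"accuracy": 75.0}, "accuracy >= 80")
stalled = should_stop({"accuracy": 75.0}, "accuracy >= 80",
                      rounds_without_improvement=3, patience=3)
```

In this sketch a round stops the experiment either because the metric satisfies the condition or because no improvement has been seen for `patience` rounds.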