nvflare.app_common.workflows.fedavg module
- class FedAvg(*args, model: Dict | FLModel | None = None, save_filename: str | None = 'FL_global_model.pt', aggregator: ModelAggregator | None = None, stop_cond: str | None = None, patience: int | None = None, task_name: str | None = 'train', exclude_vars: str | None = None, aggregation_weights: Dict[str, float] | None = None, enable_tensor_disk_offload: bool = False, **kwargs)[source]
Bases: BaseFedAvg
Controller for FedAvg Workflow with optional Early Stopping and Model Selection.
Note: This class is based on the ModelController. Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
Uses InTime (streaming) aggregation for memory efficiency: each client result is aggregated as soon as it is received, rather than after all results have been collected.
Supports custom aggregators via the ModelAggregator interface.
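The streaming idea described above can be illustrated with a running weighted sum. The sketch below is a toy, not the nvflare implementation: `RunningWeightedAverage` is a hypothetical standalone class that only borrows the method names accept_model(), aggregate_model() and reset_stats(). It does not subclass ModelAggregator, and its accept_model() signature (a plain dict of parameters plus a weight) is an assumption made for illustration.

```python
from typing import Dict, List


class RunningWeightedAverage:
    """Toy streaming aggregator; NOT the nvflare ModelAggregator base class."""

    def __init__(self) -> None:
        self.reset_stats()

    def reset_stats(self) -> None:
        # Clear the running sums before a new round starts.
        self._weighted_sum: Dict[str, List[float]] = {}
        self._total_weight = 0.0

    def accept_model(self, params: Dict[str, List[float]], weight: float) -> None:
        # Fold one client's result into the running sums; the result itself
        # does not need to be kept afterwards.
        for name, values in params.items():
            acc = self._weighted_sum.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += weight * value
        self._total_weight += weight

    def aggregate_model(self) -> Dict[str, List[float]]:
        # Divide the running sums by the total weight to get the average.
        if self._total_weight == 0:
            raise ValueError("no client results were accepted")
        return {
            name: [value / self._total_weight for value in acc]
            for name, acc in self._weighted_sum.items()
        }


agg = RunningWeightedAverage()
agg.accept_model({"w": [1.0, 2.0]}, weight=1.0)
agg.accept_model({"w": [3.0, 6.0]}, weight=3.0)
global_params = agg.aggregate_model()
```

Because only the running sums are stored, memory use in this sketch stays at roughly one model's worth regardless of how many clients report.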
- Provides the implementations for the run routine, controlling the main workflow:
def run(self)
The parent classes provide the default implementations for other routines.
For simple model persistence without a complex ModelPersistor setup, you can either:
1. Pass model (dict of params) and save_filename, or
2. Override save_model() and load_model() for framework-specific serialization.
- Parameters:
num_clients (int, optional) – The number of clients. Defaults to 3.
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
start_round (int, optional) – The starting round number.
persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”. If empty and model is provided, uses simple save_model/load_model methods.
model (dict or FLModel, optional) – Initial model parameters. If provided, this is used instead of loading from persistor. Defaults to None.
save_filename (str, optional) – Filename for saving the best model. Defaults to “FL_global_model.pt”. Only used when persistor_id is empty.
aggregator (ModelAggregator, optional) – Custom aggregator for combining client model updates. Must implement accept_model(), aggregate_model(), reset_stats(). If None, uses built-in weighted averaging (memory-efficient). Defaults to None.
stop_cond (str, optional) – Early stopping condition based on metric. String literal in the format of ‘<key> <op> <value>’ (e.g. “accuracy >= 80”). If None, early stopping is disabled. Defaults to None.
patience (int, optional) – The number of rounds with no improvement after which FL will be stopped. Only applies if stop_cond is set. Defaults to None.
task_name (str, optional) – Task name for training. Defaults to “train”.
exclude_vars (str, optional) – Regex pattern for variables to exclude from aggregation. Defaults to None. Only used when no custom aggregator is provided.
aggregation_weights (dict, optional) – Per-client aggregation weights. Defaults to None (equal weights). Only used when no custom aggregator is provided.
enable_tensor_disk_offload (bool, optional) – Download tensors to disk during FOBS streaming instead of deserializing into memory. Reduces peak server memory from ~N× to ~1× model size during aggregation. When used with a custom aggregator, lazy refs are passed through directly and must be handled by that aggregator. Defaults to False.
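As a rough illustration of how exclude_vars and aggregation_weights might affect the built-in averaging, the sketch below filters parameter names with a regex and looks up a per-client weight. The helper names `filter_params` and `client_weight` are hypothetical, and two behaviours are assumptions rather than documented facts: that the pattern is matched anywhere in the variable name, and that a client missing from aggregation_weights falls back to a weight of 1.0.

```python
import re
from typing import Dict, Optional


def filter_params(params: Dict[str, float], exclude_vars: Optional[str]) -> Dict[str, float]:
    # With no pattern, every variable takes part in aggregation.
    if exclude_vars is None:
        return dict(params)
    pattern = re.compile(exclude_vars)
    # Assumption: a variable is excluded if the pattern matches anywhere in its name.
    return {name: value for name, value in params.items() if not pattern.search(name)}


def client_weight(client: str, aggregation_weights: Optional[Dict[str, float]]) -> float:
    # None means equal weights for all clients.
    if aggregation_weights is None:
        return 1.0
    # Assumption: clients not listed fall back to a weight of 1.0.
    return aggregation_weights.get(client, 1.0)


params = {"conv.weight": 0.1, "bn.running_mean": 0.2}
kept = filter_params(params, r"running_")
weight = client_weight("site-1", {"site-1": 2.0})
```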
The base controller for FedAvg Workflow. Note: This class is based on the ModelController.
Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor is used to load the initial global model, which is sent to a list of clients. Each client sends its updated weights after local training, and these are aggregated. Next, the global model is updated. The model persistor also saves the model after training.
- Provides the default implementations for the following routines:
def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
def update_model(self, aggr_result)
The run routine needs to be implemented by the derived class:
def run(self)
- Parameters:
num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here (we will remove this argument in the next release).
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
start_round (int, optional) – The starting round number.
memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
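The periodic cleanup controlled by memory_gc_rounds can be pictured as follows. This is a minimal sketch, not the library's code: `should_cleanup` and `cleanup_memory` are hypothetical helpers, the zero-based round counter is an assumption, and malloc_trim is reached through ctypes on glibc-based Linux only (other platforms simply skip that step).

```python
import ctypes
import gc


def should_cleanup(current_round: int, memory_gc_rounds: int) -> bool:
    # 0 (or a negative value) disables cleanup entirely.
    if memory_gc_rounds <= 0:
        return False
    # Assumption: rounds are zero-based, so fire after every Nth completed round.
    return (current_round + 1) % memory_gc_rounds == 0


def cleanup_memory() -> int:
    # Collect unreachable Python objects first.
    collected = gc.collect()
    try:
        # malloc_trim is glibc-specific; it asks the allocator to return
        # free heap memory to the operating system.
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # Not glibc (e.g. macOS, Windows, musl): skip the trim step.
        pass
    return collected


for round_num in range(4):
    if should_cleanup(round_num, memory_gc_rounds=2):
        cleanup_memory()
```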
- is_curr_model_better(curr_model: FLModel) bool[source]
Checks if the new model is better than the current best model.
- Parameters:
curr_model (FLModel) – the new model to evaluate.
- Returns:
True if the new model is better than the current best model, False otherwise
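To make the interplay of stop_cond, patience and the "is the current model better" check more concrete, here is a hedged sketch in plain Python. All helper names (`parse_stop_cond`, `is_better`, `should_stop`) are hypothetical, metrics are plain dicts rather than FLModel objects, and the set of supported operators is an assumption; the actual implementation may differ.

```python
import operator
from typing import Callable, Dict, Optional, Tuple

# Assumption: these four comparison operators are enough for illustration.
_OPS: Dict[str, Callable[[float, float], bool]] = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
}


def parse_stop_cond(stop_cond: str) -> Tuple[str, Callable[[float, float], bool], float]:
    # Expected format: '<key> <op> <value>', e.g. "accuracy >= 80".
    parts = stop_cond.split()
    if len(parts) != 3 or parts[1] not in _OPS:
        raise ValueError(f"invalid stop condition: {stop_cond!r}")
    key, op, value = parts
    return key, _OPS[op], float(value)


def is_better(curr: Dict[str, float], best: Optional[Dict[str, float]],
              key: str, op_fn: Callable[[float, float], bool]) -> bool:
    # A model without the target metric cannot be judged better.
    if key not in curr:
        return False
    # The first model that reports the metric becomes the best by default.
    if best is None or key not in best:
        return True
    return op_fn(curr[key], best[key])


def should_stop(metrics: Dict[str, float], key: str,
                op_fn: Callable[[float, float], bool], target: float) -> bool:
    value = metrics.get(key)
    return value is not None and op_fn(value, target)


key, op_fn, target = parse_stop_cond("accuracy >= 80")
best, stale_rounds, patience, stopped_at = None, 0, 2, None
for round_num, acc in enumerate([60.0, 70.0, 69.0, 85.0]):
    metrics = {"accuracy": acc}
    if is_better(metrics, best, key, op_fn):
        best, stale_rounds = metrics, 0
    else:
        stale_rounds += 1  # no improvement this round
    if should_stop(metrics, key, op_fn, target) or stale_rounds >= patience:
        stopped_at = round_num
        break
```

In this run the third round (69.0) counts as one round without improvement, which stays below the patience of 2, and training stops in the fourth round when the accuracy satisfies the stop condition.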
- load_model() FLModel[source]
Load model. Uses persistor if available, otherwise uses load_model_file.
Override load_model_file for framework-specific deserialization (e.g., torch.load).
- Returns:
loaded model, or None if loading fails
- Return type:
FLModel
- load_model_file(filepath: str) FLModel[source]
Load model from file. Override this for framework-specific deserialization.
Default implementation uses FOBS (pickle-compatible). For PyTorch, override with: FLModel(params=torch.load(filepath))
- Parameters:
filepath (str) – path to load the model from
- Returns:
loaded model
- Return type:
FLModel
- save_model(model: FLModel) None[source]
Save model. Uses persistor if available, otherwise uses save_model_file.
Override save_model_file for framework-specific serialization (e.g., torch.save).
- Parameters:
model (FLModel) – model to save
- save_model_file(model: FLModel, filepath: str) None[source]
Save model to file. Override this for framework-specific serialization.
Default implementation uses FOBS (pickle-compatible). For PyTorch, override with: torch.save(model.params, filepath)
- Parameters:
model (FLModel) – model to save
filepath (str) – path to save the model
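The file-level hooks above come in pairs: whatever save_model_file writes, load_model_file must be able to read back. The sketch below shows that round trip with JSON as a framework-neutral stand-in for torch.save / torch.load. `JsonModelFileIO` is a hypothetical standalone class, not a FedAvg subclass, and a plain dict plays the role of FLModel.params; in a real override the methods would take and return FLModel objects as documented above.

```python
import json
import os
import tempfile
from typing import Dict, List

Params = Dict[str, List[float]]


class JsonModelFileIO:
    """Stand-in for a FedAvg subclass that overrides the two file hooks."""

    def save_model_file(self, params: Params, filepath: str) -> None:
        # Serialize the parameters; a PyTorch override would call torch.save here.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(params, f)

    def load_model_file(self, filepath: str) -> Params:
        # Deserialize the parameters; a PyTorch override would call torch.load here.
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)


model_io = JsonModelFileIO()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "FL_global_model.json")
    model_io.save_model_file({"w": [0.5, 1.5]}, path)
    restored = model_io.load_model_file(path)
```

Keeping the two overrides symmetric matters because the saved file is presumably the one that gets loaded when training resumes from a stored model.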