nvflare.app_common.workflows.base_fedavg module

class BaseFedAvg(*args, num_clients: int = 3, num_rounds: int = 5, start_round: int = 0, memory_gc_rounds: int = 0, **kwargs)

Bases: ModelController

The base controller for the FedAvg workflow, built on top of ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
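For reference, the server update in the cited FedAvg paper is a weighted average of the client models. Here $w_{t+1}^{k}$ is client $k$'s model after local training in round $t$, and $n_k$ is its number of local training examples. This section does not state which per-client weights `aggregate_fn` uses, so treat $n_k$ as the paper's choice rather than necessarily this class's.

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k},
\qquad n = \sum_{k=1}^{K} n_k
```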

A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor loads the initial global model, which is sent to a list of clients. After local training, each client sends back its updated weights; these are aggregated and used to update the global model. The model persistor also saves the model after training.

Provides default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, model, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)
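The split between provided routines and a subclass-defined run() can be illustrated with a self-contained toy. Everything here is hypothetical: ToyModel stands in for FLModel, ToyBaseFedAvg stands in for BaseFedAvg, and load_model/send_and_wait fake the persistor and clients. It is a sketch of the control flow only, not NVFlare's API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyModel:
    """Hypothetical stand-in for FLModel: a dict of scalar parameters."""
    params: Dict[str, float] = field(default_factory=dict)


class ToyBaseFedAvg:
    """Hypothetical stand-in for BaseFedAvg (not the NVFlare class).

    Supplies default aggregate()/update_model(); run() is left to subclasses.
    """

    def __init__(self, num_clients: int = 3, num_rounds: int = 5, start_round: int = 0):
        self.num_clients = num_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.current_round = start_round

    def load_model(self) -> ToyModel:
        # Hypothetical persistor: always starts from w = 0.
        return ToyModel({"w": 0.0})

    def send_and_wait(self, model: ToyModel) -> List[ToyModel]:
        # Hypothetical clients: client i returns w + (i + 1) as its "trained" weights.
        return [
            ToyModel({k: v + (i + 1) for k, v in model.params.items()})
            for i in range(self.num_clients)
        ]

    def aggregate(self, results: List[ToyModel]) -> ToyModel:
        # Unweighted per-parameter mean (illustration only).
        keys = results[0].params.keys()
        return ToyModel({k: sum(r.params[k] for r in results) / len(results) for k in keys})

    def update_model(self, model: ToyModel, aggr_result: ToyModel) -> None:
        # Overwrite the global params in place with the aggregated ones.
        model.params.update(aggr_result.params)

    def run(self) -> ToyModel:
        raise NotImplementedError("derived classes must implement run()")


class ToyFedAvg(ToyBaseFedAvg):
    def run(self) -> ToyModel:
        model = self.load_model()
        # Rounds start_round .. start_round + num_rounds - 1.
        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            results = self.send_and_wait(model)
            aggr_result = self.aggregate(results)
            self.update_model(model, aggr_result)
        return model
```

With three toy clients, each round shifts w by the mean of 1, 2 and 3, so `ToyFedAvg(num_rounds=5).run()` ends at w = 10.0.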

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3. NOTE: this argument should not be here; we will remove it in the next release.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number. Defaults to 0.

  • memory_gc_rounds (int, optional) – Run memory cleanup (gc.collect + malloc_trim) every N rounds. Set to 0 to disable. Defaults to 0 (disabled).
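A minimal sketch of what "gc.collect + malloc_trim every N rounds" can look like in plain Python. The function names are hypothetical, and the counting convention (cleanup after every N completed rounds, using current_round + 1) is an assumption; the class may count differently. malloc_trim is a glibc function, so the sketch skips it where the symbol is unavailable.

```python
import ctypes
import ctypes.util
import gc


def should_run_gc(current_round: int, memory_gc_rounds: int) -> bool:
    """True when cleanup is due; memory_gc_rounds <= 0 disables it (assumed convention)."""
    if memory_gc_rounds <= 0:
        return False
    return (current_round + 1) % memory_gc_rounds == 0


def cleanup_memory() -> None:
    """Run Python's garbage collector, then ask glibc to return free heap pages to the OS."""
    gc.collect()
    libc_name = ctypes.util.find_library("c")
    if libc_name is None:
        return
    try:
        libc = ctypes.CDLL(libc_name)
        libc.malloc_trim(0)  # glibc-specific; missing on other C libraries
    except (OSError, AttributeError):
        # Library could not be loaded or has no malloc_trim symbol: skip trimming.
        pass


# Usage: inside a round loop, with memory_gc_rounds = 3.
for current_round in range(6):
    if should_run_gc(current_round, 3):
        cleanup_memory()
```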

aggregate(results: List[FLModel], aggregate_fn=None) -> FLModel

Called by the run routine to aggregate the training results of clients.

Parameters:
  • results – a list of FLModel objects containing the clients' training results.

  • aggregate_fn – a function that reduces the list of FLModel objects to a single aggregated FLModel.

Returns: aggregated FLModel.
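The aggregate_fn parameter lets a caller swap the aggregation rule. This sketch uses a hypothetical ToyFLModel stand-in and assumes, since this section does not say, that a default rule is used when aggregate_fn is None. The coordinate-wise median is a swapped-in technique shown only as an example of a custom function; it is not part of FedAvg.

```python
import statistics
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class ToyFLModel:
    """Hypothetical stand-in for FLModel (params only)."""
    params: Dict[str, float] = field(default_factory=dict)


AggregateFn = Callable[[List[ToyFLModel]], ToyFLModel]


def mean_fn(results: List[ToyFLModel]) -> ToyFLModel:
    """Assumed default: unweighted per-parameter mean."""
    return ToyFLModel({k: statistics.fmean(r.params[k] for r in results) for k in results[0].params})


def median_fn(results: List[ToyFLModel]) -> ToyFLModel:
    """Custom rule: coordinate-wise median, less sensitive to a single outlier client."""
    return ToyFLModel({k: statistics.median(r.params[k] for r in results) for k in results[0].params})


def aggregate(results: List[ToyFLModel], aggregate_fn: Optional[AggregateFn] = None) -> ToyFLModel:
    if not results:
        raise ValueError("no results to aggregate")
    fn = aggregate_fn if aggregate_fn is not None else mean_fn
    return fn(results)


clients = [ToyFLModel({"w": 1.0}), ToyFLModel({"w": 2.0}), ToyFLModel({"w": 30.0})]
default_result = aggregate(clients)             # mean: 11.0
robust_result = aggregate(clients, median_fn)   # median: 2.0
```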

static aggregate_fn(results: List[FLModel]) -> FLModel

Aggregate model params and metrics across results with weighted averaging.

Note

Metric values that do not support weighted arithmetic are skipped during aggregation. If no aggregatable metrics remain after filtering, the aggregated metrics are returned as None.
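The weighted averaging and metric filtering described for aggregate_fn can be sketched in plain Python. Assumptions: each result carries a positive weight in a meta entry whose key, "num_steps", is hypothetical; ToyFLModel is a stand-in for FLModel; and "supports weighted arithmetic" is approximated as "is a real number other than bool".

```python
from dataclasses import dataclass, field
from numbers import Real
from typing import Any, Dict, List, Optional


@dataclass
class ToyFLModel:
    """Hypothetical stand-in for FLModel."""
    params: Dict[str, float] = field(default_factory=dict)
    metrics: Optional[Dict[str, Any]] = None
    meta: Dict[str, Any] = field(default_factory=dict)


WEIGHT_KEY = "num_steps"  # hypothetical meta key holding each client's weight


def weighted_aggregate(results: List[ToyFLModel]) -> ToyFLModel:
    weights = [float(r.meta.get(WEIGHT_KEY, 1.0)) for r in results]
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    # Weighted mean of every parameter.
    params = {
        k: sum(w * r.params[k] for w, r in zip(weights, results)) / total
        for k in results[0].params
    }
    # Weighted mean per metric, skipping values that are not real numbers.
    sums: Dict[str, float] = {}
    wsums: Dict[str, float] = {}
    for w, r in zip(weights, results):
        for name, value in (r.metrics or {}).items():
            if isinstance(value, bool) or not isinstance(value, Real):
                continue
            sums[name] = sums.get(name, 0.0) + w * value
            wsums[name] = wsums.get(name, 0.0) + w
    # An empty dict means nothing was aggregatable, so return None.
    metrics = {name: sums[name] / wsums[name] for name in sums} or None
    return ToyFLModel(params=params, metrics=metrics)
```

A string metric such as a tag is dropped, while a numeric accuracy is averaged with the same weights as the parameters.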

update_model(model, aggr_result)

Called by the run routine to update the current global model (self.model) given the aggregated result.

Parameters:
  • model – FLModel to be updated.

  • aggr_result – aggregated FLModel.

Returns: None.
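One way to apply an aggregated result to the global model is sketched below. It assumes, and this section does not confirm, that an aggregated result may carry either full weights or weight differences. ToyParamsType and ToyFLModel are hypothetical stand-ins, and the function updates the model in place and returns None, matching the documented return value.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict


class ToyParamsType(Enum):
    FULL = "FULL"  # aggregated result carries complete weights
    DIFF = "DIFF"  # aggregated result carries weight deltas


@dataclass
class ToyFLModel:
    """Hypothetical stand-in for FLModel."""
    params: Dict[str, float] = field(default_factory=dict)
    params_type: ToyParamsType = ToyParamsType.FULL


def update_model(model: ToyFLModel, aggr_result: ToyFLModel) -> None:
    """Update the global model in place from the aggregated result."""
    if aggr_result.params_type is ToyParamsType.FULL:
        # Full weights: replace the global parameters outright.
        model.params = dict(aggr_result.params)
    elif aggr_result.params_type is ToyParamsType.DIFF:
        # Deltas: add each difference to the existing global parameter.
        for name, delta in aggr_result.params.items():
            model.params[name] = model.params[name] + delta
    else:
        raise ValueError(f"unsupported params_type: {aggr_result.params_type}")
```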