nvflare.app_common.workflows.base_fedavg module

class BaseFedAvg(*args, num_clients: int = 3, num_rounds: int = 5, start_round: int = 0, **kwargs)

Bases: ModelController

The base controller for the FedAvg workflow. Note: this class is based on ModelController.

Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
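For reference, the server-side averaging step in the linked paper weights each client's returned model by that client's share of the training examples. Here S_t is the set of clients selected in round t, n_k is the number of training examples on client k, and w^k_{t+1} is the model client k returns after local training. This reference does not state whether BaseFedAvg's default aggregation uses exactly this weighting.

$$
w_{t+1} = \sum_{k \in S_t} \frac{n_k}{m_t}\, w_{t+1}^{k}, \qquad m_t = \sum_{k \in S_t} n_k
$$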

A model persistor can be configured via the persistor_id argument of the ModelController. The persistor loads the initial global model, which is sent to a list of clients. After local training, each client sends back its updated weights; these are aggregated and used to update the global model. The persistor also saves the model after training.

Provides default implementations for the following routines:
  • def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel

  • def update_model(self, model, aggr_result)

The run routine needs to be implemented by the derived class:

  • def run(self)
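One plausible shape for a derived class's run() ties the documented pieces together: load the initial model, collect client results each round, call aggregate() and update_model(), then save. The sketch below is framework-free so it runs without NVFlare. FLModel is a minimal hypothetical stand-in, and load_model, save_model and train_clients are hypothetical placeholders for the persistor and client-communication machinery that ModelController provides. The sketch also assumes rounds are numbered from start_round through start_round + num_rounds - 1, and it saves after every round.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class FLModel:
    """Hypothetical minimal stand-in for nvflare's FLModel (params only)."""
    params: Dict[str, float] = field(default_factory=dict)


class SketchFedAvg:
    """Framework-free sketch of what a BaseFedAvg subclass's run() might do.

    load_model, save_model and train_clients are hypothetical stand-ins for the
    persistor and client-communication machinery provided by ModelController.
    """

    def __init__(self, train_clients: Callable[[FLModel, int], List[FLModel]],
                 num_clients: int = 3, num_rounds: int = 5, start_round: int = 0):
        self.train_clients = train_clients
        self.num_clients = num_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.saved_rounds: List[int] = []

    def load_model(self) -> FLModel:
        # Stand-in for the persistor loading the initial global model.
        return FLModel(params={"w": 0.0})

    def save_model(self, model: FLModel, rnd: int) -> None:
        # Stand-in for the persistor saving the global model; only records the round.
        self.saved_rounds.append(rnd)

    def aggregate(self, results: List[FLModel]) -> FLModel:
        # Plain (unweighted) mean of each parameter across the client results.
        keys = results[0].params.keys()
        return FLModel(params={k: sum(r.params[k] for r in results) / len(results) for k in keys})

    def update_model(self, model: FLModel, aggr_result: FLModel) -> None:
        # Overwrite the global parameters in place with the aggregated ones.
        model.params.update(aggr_result.params)

    def run(self) -> FLModel:
        model = self.load_model()
        # Assumption: rounds run from start_round to start_round + num_rounds - 1.
        for rnd in range(self.start_round, self.start_round + self.num_rounds):
            results = self.train_clients(model, self.num_clients)
            self.update_model(model, self.aggregate(results))
            self.save_model(model, rnd)
        return model


def toy_training(model: FLModel, num_clients: int) -> List[FLModel]:
    # Client i pretends that local training adds (i + 1) to w.
    return [FLModel(params={"w": model.params["w"] + i + 1}) for i in range(num_clients)]


controller = SketchFedAvg(toy_training, num_clients=3, num_rounds=5, start_round=2)
final = controller.run()
print(final.params, controller.saved_rounds)  # {'w': 10.0} [2, 3, 4, 5, 6]
```

With three clients adding 1, 2 and 3, each round moves w by their mean (2), so five rounds end at 10.0. The simplified aggregate() here omits the aggregate_fn parameter that the real method accepts.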

Parameters:
  • num_clients (int, optional) – The number of clients. Defaults to 3.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5.

  • start_round (int, optional) – The starting round number. Defaults to 0.

aggregate(results: List[FLModel], aggregate_fn=None) → FLModel

Called by the run routine to aggregate the training results of clients.

Parameters:
  • results – a list of FLModel objects containing the clients' training results.

  • aggregate_fn – a function that combines the list of FLModel objects into a single aggregated FLModel.

Returns: aggregated FLModel.
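A custom aggregate_fn only needs to map a list of FLModel results to one FLModel. The sketch below computes a sample-weighted mean of each parameter, in the spirit of FederatedAveraging. It uses a minimal hypothetical FLModel stand-in so it runs without NVFlare, and it assumes each client reports its sample count under a hypothetical meta key "num_samples" (defaulting to 1.0 when absent); the key your clients actually send may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FLModel:
    """Hypothetical minimal stand-in for nvflare's FLModel."""
    params: Dict[str, float] = field(default_factory=dict)
    meta: Dict[str, Any] = field(default_factory=dict)


def weighted_average(results: List[FLModel]) -> FLModel:
    """Example aggregate_fn: sample-weighted mean of each parameter.

    Assumes every result has the same parameter keys and reports its weight
    under the hypothetical meta key "num_samples" (default 1.0).
    """
    if not results:
        raise ValueError("no results to aggregate")
    weights = [float(r.meta.get("num_samples", 1.0)) for r in results]
    total = sum(weights)
    params = {
        key: sum(w * r.params[key] for w, r in zip(weights, results)) / total
        for key in results[0].params
    }
    return FLModel(params=params)


results = [
    FLModel(params={"w": 1.0, "b": 0.0}, meta={"num_samples": 10}),
    FLModel(params={"w": 4.0, "b": 3.0}, meta={"num_samples": 20}),
]
aggr = weighted_average(results)
print(aggr.params)  # {'w': 3.0, 'b': 2.0}
```

A function like this would be passed as the aggregate_fn argument of aggregate(). This reference does not say what aggregate() falls back to when aggregate_fn is None, although the class also exposes a static aggregate_fn.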

static aggregate_fn(results: List[FLModel]) → FLModel

update_model(model, aggr_result)

Called by the run routine to update the current global model (self.model) given the aggregated result.

Parameters:
  • model – FLModel to be updated.

  • aggr_result – aggregated FLModel.

Returns: None.
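Because update_model returns None, it is expected to modify the given model in place. The sketch below illustrates that contract under a stated assumption: the aggregated result carries either full weights or weight differences, marked here by a hypothetical params_type string ("FULL" or "DIFF"). FLModel is again a minimal hypothetical stand-in, so the code runs without NVFlare.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class FLModel:
    """Hypothetical minimal stand-in for nvflare's FLModel."""
    params: Dict[str, float] = field(default_factory=dict)
    params_type: str = "FULL"  # assumed: "FULL" weights or "DIFF" (weight deltas)


def update_model(model: FLModel, aggr_result: FLModel) -> None:
    """Update the global model in place; returns None like the documented method."""
    if aggr_result.params_type == "DIFF":
        # Aggregated deltas are added to the current global weights.
        for key, delta in aggr_result.params.items():
            model.params[key] = model.params.get(key, 0.0) + delta
    elif aggr_result.params_type == "FULL":
        # Aggregated full weights replace the current global weights.
        model.params.update(aggr_result.params)
    else:
        raise ValueError(f"unsupported params_type: {aggr_result.params_type}")


global_model = FLModel(params={"w": 1.0, "b": 0.5})
update_model(global_model, FLModel(params={"w": 0.25, "b": -0.5}, params_type="DIFF"))
print(global_model.params)  # {'w': 1.25, 'b': 0.0}
update_model(global_model, FLModel(params={"w": 3.0}, params_type="FULL"))
print(global_model.params)  # {'w': 3.0, 'b': 0.0}
```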