nvflare.app_common.workflows.fedavg module
- class FedAvg(*args, num_clients: int = 3, num_rounds: int = 5, start_round: int = 0, **kwargs)
Bases: BaseFedAvg
Controller for FedAvg Workflow. Note: This class is based on the ModelController. Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
- Provides the implementation of the run routine, which controls the main workflow:
def run(self)
The parent classes provide the default implementations for other routines.
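The main loop that the run routine drives can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not NVFlare's implementation: the real routine uses the ModelController's own client-selection and messaging machinery, which is not reproduced here. `run_fedavg`, `weighted_average`, and the client callables are hypothetical names; a model is assumed to be a flat dict of floats; and weighting each client by a reported weight (such as its number of local samples) follows the FedAvg paper rather than anything this page states about NVFlare's weighting.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins: a model is a dict of named float parameters, and a
# client is a function that trains on the global model and returns
# (updated_parameters, weight), e.g. weight = number of local samples.
Params = Dict[str, float]
Client = Callable[[Params], Tuple[Params, float]]


def weighted_average(results: List[Tuple[Params, float]]) -> Params:
    """Average client parameters, weighting each client by its reported weight."""
    total = sum(weight for _, weight in results)
    return {
        key: sum(params[key] * weight for params, weight in results) / total
        for key in results[0][0]
    }


def run_fedavg(initial_model: Params, clients: List[Client], num_clients: int = 3,
               num_rounds: int = 5, start_round: int = 0, seed: int = 0) -> Params:
    """Sketch of the FedAvg main loop: sample clients, train, aggregate, update."""
    rng = random.Random(seed)
    model = dict(initial_model)  # in the real workflow the persistor loads this
    # In this sketch, rounds are numbered start_round .. start_round + num_rounds - 1.
    for _current_round in range(start_round, start_round + num_rounds):
        sampled = rng.sample(clients, num_clients)     # clients for this round
        results = [train(model) for train in sampled]  # local training
        model = weighted_average(results)              # new global model
    return model  # in the real workflow the persistor saves this
```

Passing `num_clients` smaller than the number of available clients makes each round train on a random subset, which mirrors the client-sampling step of the original FedAvg algorithm.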
- Parameters:
num_clients (int, optional) – The number of clients. Defaults to 3.
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
start_round (int, optional) – The starting round number. Defaults to 0.
persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.
The base controller for FedAvg Workflow. Note: This class is based on the ModelController. Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
A model persistor can be configured via the persistor_id argument of the ModelController. The model persistor loads the initial global model, which is sent to a list of clients. After local training, each client sends back its updated weights; these updates are aggregated, and the global model is then updated with the aggregated result. The model persistor also saves the model after training.
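The persistor's role in this lifecycle (load the initial global model, save the final one, looked up by an ID) can be illustrated with a small file-backed stand-in. This is a hypothetical sketch, not NVFlare's ModelPersistor interface: `JsonFilePersistor`, `get_persistor`, and the `components` dictionary are illustrative names, and the model is assumed to be a flat dict of floats that JSON can store.

```python
import json
import os
from typing import Dict

Params = Dict[str, float]


class JsonFilePersistor:
    """Hypothetical persistor: loads the global model from a JSON file,
    falling back to a default initial model, and saves models back to it."""

    def __init__(self, path: str, initial_model: Params):
        self.path = path
        self.initial_model = dict(initial_model)

    def load_model(self) -> Params:
        # Resume from a previously saved model if one exists.
        if os.path.exists(self.path):
            with open(self.path) as f:
                return json.load(f)
        return dict(self.initial_model)

    def save_model(self, model: Params) -> None:
        with open(self.path, "w") as f:
            json.dump(model, f)


def get_persistor(components: Dict[str, object], persistor_id: str = "persistor"):
    """Look up the persistor component by ID, analogous to persistor_id."""
    return components[persistor_id]
```

Resolving the persistor by ID rather than passing the object directly keeps the controller decoupled from any particular storage backend; swapping storage only changes which component is registered under that ID.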
- Provides the default implementations for the following routines:
def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
def update_model(self, aggr_result)
The run routine needs to be implemented by the derived class:
def run(self)
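The split described above (the base class supplies default aggregate and update_model routines, and a derived class must supply run) is the template-method pattern, which Python expresses with an abstract base class. The sketch below is a plain-Python analogue under stated assumptions: `BaseFedAvgSketch` and `OneRoundAvg` are hypothetical names, client results are (parameters, weight) tuples rather than FLModel objects, and the default weighted mean is an assumption, not NVFlare's exact aggregation.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical stand-in for FLModel results: (parameters, weight) pairs.
Params = Dict[str, float]
Result = Tuple[Params, float]


def _weighted_mean(results: List[Result]) -> Params:
    total = sum(w for _, w in results)
    return {k: sum(p[k] * w for p, w in results) / total for k in results[0][0]}


class BaseFedAvgSketch(ABC):
    """Base class: default aggregate/update_model, abstract run."""

    def __init__(self, initial_model: Params):
        self.model: Params = dict(initial_model)

    def aggregate(self, results: List[Result],
                  aggregate_fn: Optional[Callable[[List[Result]], Params]] = None) -> Params:
        # Use the caller-supplied aggregation function if given, else a weighted mean.
        fn = aggregate_fn if aggregate_fn is not None else _weighted_mean
        return fn(results)

    def update_model(self, aggr_result: Params) -> None:
        # Replace the global model with the aggregated parameters.
        self.model = dict(aggr_result)

    @abstractmethod
    def run(self) -> None:
        """Derived classes must define the workflow."""


class OneRoundAvg(BaseFedAvgSketch):
    """Derived class that runs a single round over all given clients."""

    def __init__(self, initial_model: Params, clients: List[Callable[[Params], Result]]):
        super().__init__(initial_model)
        self.clients = clients

    def run(self) -> None:
        results = [client(self.model) for client in self.clients]
        self.update_model(self.aggregate(results))
```

Because run is abstract, instantiating the base class directly raises TypeError, so every concrete workflow is forced to define its own control loop while reusing the shared aggregation and update logic.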
- Parameters:
num_clients (int, optional) – The number of clients. Defaults to 3.
num_rounds (int, optional) – The total number of training rounds. Defaults to 5.
start_round (int, optional) – The starting round number. Defaults to 0.