# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_fedavg import BaseFedAvg
class FedAvg(BaseFedAvg):
    """Controller for FedAvg Workflow. *Note*: This class is based on the `ModelController`.

    Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

    Provides the implementations for the `run` routine, controlling the main workflow:
        - def run(self)

    The parent classes provide the default implementations for other routines.

    Args:
        num_clients (int, optional): The number of clients passed to `sample_clients`
            in each round. Defaults to 3.
        num_rounds (int, optional): The total number of training rounds. Defaults to 5.
        start_round (int, optional): The starting round number.
        persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
    """

    def run(self) -> None:
        """Run the FedAvg training loop.

        Loads the initial model, then for each round in
        ``range(start_round, start_round + num_rounds)``:

        1. samples ``num_clients`` clients,
        2. sends them the current model and waits for their results,
        3. aggregates the results with ``self.aggregate_fn``,
        4. updates the global model with the aggregate, and
        5. saves the updated model.

        Returns:
            None
        """
        self.info("Start FedAvg.")

        model = self.load_model()
        # Attach the round schedule to the model before training starts.
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        # The loop target is the instance attribute itself, so `self.current_round`
        # always reflects the active round for any other routine that reads it, and
        # (when at least one round runs) keeps the last round's value afterwards.
        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            self.info(f"Round {self.current_round} started.")
            model.current_round = self.current_round

            clients = self.sample_clients(self.num_clients)

            results = self.send_model_and_wait(targets=clients, data=model)

            # `self.aggregate_fn` is expected to default (via the parent class) to a
            # `WeightedAggregationHelper`-based weighted average; it can be overridden
            # with any callable of type Callable[[List[FLModel]], FLModel].
            aggregate_results = self.aggregate(results, aggregate_fn=self.aggregate_fn)

            model = self.update_model(model, aggregate_results)

            self.save_model(model)

        self.info("Finished FedAvg.")