# Source code for nvflare.app_common.workflows.fedavg

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.fuel.utils.log_utils import center_message

from .base_fedavg import BaseFedAvg


class FedAvg(BaseFedAvg):
    """Controller for FedAvg Workflow.

    *Note*: This class is based on the `ModelController`.
    Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

    Provides the implementations for the `run` routine, controlling the main workflow:
        - def run(self)

    The parent classes provide the default implementations for other routines.

    Args:
        num_clients (int, optional): The number of clients. Defaults to 3.
        num_rounds (int, optional): The total number of training rounds. Defaults to 5.
        start_round (int, optional): The starting round number.
        persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
    """

    def run(self) -> None:
        """Drive the FedAvg training rounds.

        Loads the initial model, stamps it with ``start_round`` and
        ``num_rounds``, and then, for every round in
        ``range(start_round, start_round + num_rounds)``:

        1. samples ``num_clients`` clients,
        2. sends them the current model and waits for their results,
        3. aggregates the results with ``self.aggregate_fn``,
        4. updates the model from the aggregate, and
        5. saves the updated model.

        A banner message is logged at the start, at the beginning of each
        round, and at the end.
        """
        self.info(center_message("Start FedAvg."))

        model = self.load_model()
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        # The loop target is the instance attribute itself, so
        # ``self.current_round`` is rebound on every iteration.
        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            self.info(center_message(message=f"Round {self.current_round} started.", boarder_str="-"))
            model.current_round = self.current_round

            clients = self.sample_clients(self.num_clients)

            results = self.send_model_and_wait(targets=clients, data=model)

            # Using the default aggregate_fn with `WeightedAggregationHelper`.
            # self.aggregate_fn can be overwritten with a callable of signature
            # Callable[[List[FLModel]], FLModel].
            aggregate_results = self.aggregate(results, aggregate_fn=self.aggregate_fn)

            model = self.update_model(model, aggregate_results)

            # NOTE(review): saved once per round; the flattened source this was
            # recovered from does not preserve indentation -- confirm that
            # save_model belongs inside the loop.
            self.save_model(model)

        self.info(center_message("Finished FedAvg."))