# Source code for nvflare.app_common.workflows.base_fedavg

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import List

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.security.logging import secure_format_exception

from .wf_controller import WFController


class BaseFedAvg(WFController):
    def __init__(
        self,
        *args,
        min_clients: int = 1000,
        num_rounds: int = 5,
        start_round: int = 0,
        persist_every_n_rounds: int = 1,
        **kwargs,
    ):
        """The base controller for FedAvg Workflow. *Note*: This class is based on the `WFController`.

        Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

        A model persistor can be configured via the `persistor_id` argument of the `WFController`.
        The model persistor is used to load the initial global model which is sent to a list of clients.
        Each client sends its updated weights after local training, and these are aggregated.
        Next, the global model is updated.
        The model_persistor will also save the model after training.

        Provides the default implementations for the following routines:
            - def sample_clients(self, num_clients)
            - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
            - def update_model(self, model, aggr_result) -> FLModel
            - def save_model(self, model: FLModel)

        The `run` routine needs to be implemented by the derived class:

            - def run(self)

        Args:
            min_clients (int, optional): The minimum number of client responses before the Workflow
                starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward
                when all available clients have responded regardless of this value. Defaults to 1000.
            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
            start_round (int, optional): The starting round number. Defaults to 0.
            persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1.
                If n is 0 then no periodic persist is done (the final round is still persisted,
                see `save_model`).

        All remaining positional and keyword arguments are forwarded to `WFController.__init__`.
        """
        super().__init__(*args, **kwargs)

        self.min_clients = min_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.persist_every_n_rounds = persist_every_n_rounds
        # Index of the round currently being executed; None until the derived class's `run`
        # sets it (presumably once per round -- confirm against the subclass). Read by
        # `aggregate` for logging and by `save_model` for the persist schedule.
        self.current_round = None

    def sample_clients(self, num_clients):
        """Called by the `run` routine to get a list of available clients.

        Args:
            num_clients: number of clients to return.

        Returns:
            list of clients. When `num_clients` does not exceed the number of available clients,
            a randomly chosen subset of size `num_clients` is returned; otherwise all available
            clients are returned (and an info message is logged).
        """
        # Work on a copy so the sequence handed out by the engine is never reordered in place.
        clients = list(self.engine.get_clients())
        if num_clients <= len(clients):
            random.shuffle(clients)
            clients = clients[0:num_clients]
        else:
            self.info(
                f"num_clients ({num_clients}) is greater than the number of available clients. Returning all clients."
            )
        return clients

    @staticmethod
    def _check_results(results: List[FLModel]):
        """Validate that every client result carries params.

        Args:
            results: the FLModel results returned by the clients.

        Raises:
            ValueError: listing the names of all clients whose result has empty params.
        """
        empty_clients = [
            _result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN) for _result in results if not _result.params
        ]
        if empty_clients:
            raise ValueError(f"Result from client(s) {empty_clients} is empty!")

    @staticmethod
    def aggregate_fn(results: List[FLModel]) -> FLModel:
        """Default aggregation function, based on `WeightedAggregationHelper`.

        Each result's params are added with a weight equal to its
        `FLMetaKey.NUM_STEPS_CURRENT_ROUND` meta value (1.0 if absent).

        Args:
            results: non-empty list of client FLModel results.

        Returns:
            FLModel holding the aggregated params. `params_type` and `current_round` are taken
            from the first result, and `nr_aggregated` records how many results were combined.

        Raises:
            ValueError: if `results` is empty.
        """
        # Fail with a clear message instead of an IndexError on `results[0]` below.
        if not results:
            raise ValueError("No results to aggregate!")

        aggregation_helper = WeightedAggregationHelper()
        for _result in results:
            aggregation_helper.add(
                data=_result.params,
                weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
                contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
                contribution_round=_result.current_round,
            )

        aggregated_dict = aggregation_helper.get_result()

        return FLModel(
            params=aggregated_dict,
            params_type=results[0].params_type,
            meta={"nr_aggregated": len(results), "current_round": results[0].current_round},
        )

    def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
        """Called by the `run` routine to aggregate the training results of clients.

        Fires `AppEventType.BEFORE_AGGREGATION` / `AFTER_AGGREGATION` around the aggregation and
        stores the result in the FL context under `AppConstants.AGGREGATION_RESULT`.

        Args:
            results: a list of FLModel containing training results of the clients.
            aggregate_fn: a function that turns the list of FLModel into one resulting (aggregated) FLModel.
                Defaults to `self.aggregate_fn`.

        Returns:
            aggregated FLModel, or an empty `FLModel()` if `aggregate_fn` raised (in which case
            `self.panic` has been called).

        Raises:
            ValueError: if any result has empty params (raised by `_check_results` before aggregation).
        """
        self.debug("Start aggregation.")
        self.event(AppEventType.BEFORE_AGGREGATION)
        self._check_results(results)

        if not aggregate_fn:
            aggregate_fn = self.aggregate_fn

        self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
        try:
            aggr_result = aggregate_fn(results)
        except Exception as e:
            error_msg = f"Exception in aggregate call: {secure_format_exception(e)}"
            self.exception(error_msg)
            self.panic(error_msg)
            return FLModel()
        # NOTE(review): `_results` is not read anywhere in this class; kept as-is in case a base
        # class or subclass relies on it being reset here -- confirm whether it is still needed.
        self._results = []

        self.fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False)
        self.event(AppEventType.AFTER_AGGREGATION)
        self.debug("End aggregation.")

        return aggr_result

    def update_model(self, model, aggr_result):
        """Called by the `run` routine to update the current global model given the aggregated result.

        The updated model is also published in the FL context under `AppConstants.GLOBAL_MODEL`
        (as a Learnable, sticky) between the BEFORE/AFTER_SHAREABLE_TO_LEARNABLE events.

        Args:
            model: FLModel to be updated.
            aggr_result: aggregated FLModel.

        Returns:
            the updated FLModel.
        """
        self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)

        model = FLModelUtils.update_model(model, aggr_result)

        # persistor uses Learnable format to save model
        ml = make_model_learnable(weights=model.params, meta_props=model.meta)
        self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)

        self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)

        return model

    def save_model(self, model: FLModel):
        """Persist `model` via `WFController.save_model` according to the persist schedule.

        The model is saved when `persist_every_n_rounds` is non-zero and `current_round + 1` is a
        multiple of it, and always when `current_round == num_rounds - 1`.

        Args:
            model: the global FLModel to persist.
        """
        # NOTE(review): the final-round check compares against `num_rounds - 1` and ignores
        # `start_round`. If the derived `run` iterates over
        # `range(start_round, start_round + num_rounds)`, that final round is never matched when
        # start_round > 0 -- confirm against the subclass's run loop.
        # The `!= 0` guard short-circuits before the modulo, so n == 0 never divides by zero.
        periodic = self.persist_every_n_rounds != 0 and (self.current_round + 1) % self.persist_every_n_rounds == 0
        is_last_round = self.current_round == self.num_rounds - 1
        if periodic or is_last_round:
            super().save_model(model)