# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.aggregators.weighted_aggregation_helper import (
WeightedAggregationHelper,
filter_aggregatable_metrics,
)
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.fuel.utils.memory_utils import cleanup_memory
from nvflare.fuel.utils.validation_utils import check_non_negative_int
from nvflare.security.logging import secure_format_exception
from .model_controller import ModelController
class BaseFedAvg(ModelController):
    """Base controller for FedAvg workflows; see ``__init__`` for the full description."""

    def __init__(
        self,
        *args,
        num_clients: int = 3,
        num_rounds: int = 5,
        start_round: int = 0,
        memory_gc_rounds: int = 0,
        **kwargs,
    ):
        """The base controller for FedAvg Workflow. *Note*: This class is based on the `ModelController`.

        Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

        A model persistor can be configured via the `persistor_id` argument of the `ModelController`.
        The model persistor is used to load the initial global model which is sent to a list of clients.
        Each client sends its updated weights after local training, and these are aggregated.
        Next, the global model is updated.
        The model_persistor will also save the model after training.

        Provides the default implementations for the following routines:
            - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
            - def update_model(self, model, aggr_result) -> FLModel

        The `run` routine needs to be implemented by the derived class:
            - def run(self)

        Args:
            *args: positional arguments forwarded unchanged to `ModelController`.
            num_clients (int, optional): The number of clients. Defaults to 3. NOTE: this argument should not be here
                we will remove this argument in next release.
            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
            start_round (int, optional): The starting round number. Defaults to 0.
            memory_gc_rounds (int, optional): Run memory cleanup (gc.collect + malloc_trim) every N rounds.
                Set to 0 to disable. Defaults to 0 (disabled). Validated with `check_non_negative_int`.
            **kwargs: keyword arguments forwarded unchanged to `ModelController`.
        """
        super().__init__(*args, **kwargs)

        check_non_negative_int("memory_gc_rounds", memory_gc_rounds)

        self.num_clients = num_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.memory_gc_rounds = memory_gc_rounds
        # Not set here; `_maybe_cleanup_memory` and `aggregate` read it, so the
        # derived `run` routine is expected to assign it before calling them.
        self.current_round = None

    def _maybe_cleanup_memory(self) -> None:
        """Perform memory cleanup if configured (every N rounds based on memory_gc_rounds).

        Does nothing while `current_round` is still None or when
        `memory_gc_rounds` is not positive.
        """
        if self.current_round is None or self.memory_gc_rounds <= 0:
            return

        # Both the schedule and the log message use `current_round + 1`, i.e. the
        # stored round value is treated as a 0-based index.
        round_number = self.current_round + 1
        if round_number % self.memory_gc_rounds == 0:
            self.info(f"Memory cleanup at round {round_number}")
            cleanup_memory()

    @staticmethod
    def _check_results(results: List[FLModel]) -> None:
        """Verify that every result carries params.

        Args:
            results: a list of FLModel containing training results of the clients.

        Raises:
            ValueError: if one or more results have falsy `params`; the message
                lists the affected client names (taken from `meta["client_name"]`).
        """
        empty_clients = [
            _result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN) for _result in results if not _result.params
        ]
        if empty_clients:
            raise ValueError(f"Result from client(s) {empty_clients} is empty!")

    @staticmethod
    def aggregate_fn(results: List[FLModel]) -> FLModel:
        """Aggregate model params and metrics across results with weighted averaging.

        Each result is weighted by `meta[FLMetaKey.NUM_STEPS_CURRENT_ROUND]`
        (1.0 when the key is absent).

        Note:
            Metric values that do not support weighted arithmetic are skipped during
            aggregation. If no aggregatable metrics remain after filtering, the
            aggregated metrics are returned as ``None``. Metrics are also returned
            as ``None`` when any result has ``metrics`` set to ``None``.

        Args:
            results: a non-empty list of FLModel containing training results of the clients.

        Returns:
            An FLModel holding the aggregated params and metrics. `params_type`
            and the `current_round` meta entry are taken from the first result;
            `nr_aggregated` is the number of results.

        Raises:
            ValueError: if `results` is empty.
        """
        if not results:
            raise ValueError("received empty results for aggregation.")

        # Metrics are all-or-nothing: a single result without metrics discards the
        # aggregated metrics, so decide once instead of accumulating values that
        # would be thrown away.
        all_metrics = all(_result.metrics is not None for _result in results)

        aggr_helper = WeightedAggregationHelper()
        aggr_metrics_helper = WeightedAggregationHelper()
        for _result in results:
            weight = _result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0)
            contributor_name = _result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN)

            aggr_helper.add(
                data=_result.params,
                weight=weight,
                contributor_name=contributor_name,
                contribution_round=_result.current_round,
            )

            if not all_metrics:
                continue

            aggregatable = filter_aggregatable_metrics(_result.metrics)
            if aggregatable:
                aggr_metrics_helper.add(
                    data=aggregatable,
                    weight=weight,
                    contributor_name=contributor_name,
                    contribution_round=_result.current_round,
                )

        aggr_params = aggr_helper.get_result()

        aggr_metrics = aggr_metrics_helper.get_result() if all_metrics else None
        # Normalize an empty aggregation result to None.
        aggr_metrics = aggr_metrics or None

        aggr_result = FLModel(
            params=aggr_params,
            params_type=results[0].params_type,
            metrics=aggr_metrics,
            meta={"nr_aggregated": len(results), "current_round": results[0].current_round},
        )
        return aggr_result

    def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
        """Called by the `run` routine to aggregate the training results of clients.

        Fires `AppEventType.BEFORE_AGGREGATION` before aggregating and
        `AppEventType.AFTER_AGGREGATION` (carrying the aggregated result) afterwards.

        Args:
            results: a list of FLModel containing training results of the clients.
            aggregate_fn: a function that turns the list of FLModel into one resulting (aggregated) FLModel.
                Defaults to `self.aggregate_fn` when not given.

        Returns:
            The aggregated FLModel. If `aggregate_fn` raises, the error is logged,
            `self.panic` is called and an empty `FLModel()` is returned.

        Raises:
            ValueError: if any result has empty params (raised by `_check_results`,
                which runs outside the try block and is therefore not converted into a panic).
        """
        self.debug("Start aggregation.")
        self.event(AppEventType.BEFORE_AGGREGATION)
        self._check_results(results)

        if not aggregate_fn:
            aggregate_fn = self.aggregate_fn

        self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
        try:
            aggr_result = aggregate_fn(results)
        except Exception as e:
            error_msg = f"Exception in aggregate call: {secure_format_exception(e)}"
            self.exception(error_msg)
            self.panic(error_msg)
            return FLModel()

        # Reset the result buffer now that it has been consumed.
        # NOTE(review): `_results` is not defined in this class; presumably it is
        # maintained by `ModelController` -- confirm.
        self._results = []

        self.fire_event_with_data(
            AppEventType.AFTER_AGGREGATION, self.fl_ctx, AppConstants.AGGREGATION_RESULT, aggr_result
        )
        self.debug("End aggregation.")

        return aggr_result

    def update_model(self, model: FLModel, aggr_result: FLModel) -> FLModel:
        """Called by the `run` routine to update the current global model given the aggregated result.

        Fires `AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE` before the update and
        `AppEventType.AFTER_SHAREABLE_TO_LEARNABLE` after the updated model has
        been stored in the FL context under `AppConstants.GLOBAL_MODEL`.

        Args:
            model: FLModel to be updated.
            aggr_result: aggregated FLModel.

        Returns:
            The updated FLModel, as returned by `FLModelUtils.update_model`.
        """
        self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)

        model = FLModelUtils.update_model(model, aggr_result)

        # persistor uses Learnable format to save model
        ml = make_model_learnable(weights=model.params, meta_props=model.meta)
        self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)

        self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)

        return model