Source code for nvflare.app_common.workflows.base_fedavg

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, List, Optional

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.aggregators.weighted_aggregation_helper import (
    WeightedAggregationHelper,
    filter_aggregatable_metrics,
)
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.fuel.utils.memory_utils import cleanup_memory
from nvflare.fuel.utils.validation_utils import check_non_negative_int
from nvflare.security.logging import secure_format_exception

from .model_controller import ModelController


def make_fedavg_metrics_aggregation_info(
    key_metric: Optional[str] = None,
    key_metric_mode: Optional[str] = None,
    key_metric_mode_source: Optional[str] = None,
    weight_key: str = FLMetaKey.NUM_STEPS_CURRENT_ROUND,
    weight_formula: Optional[str] = None,
    site_weights: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build the metadata dict that describes how client metrics were aggregated.

    The result always contains ``"metric_source"`` and an ``"aggregation"``
    sub-dict (method, weight key and metric policy). The remaining entries are
    optional and only appear when the matching argument is usable.

    Args:
        key_metric: name of the key metric. Only recorded together with a
            valid ``key_metric_mode``.
        key_metric_mode: ``"max"`` or ``"min"``. Any other value causes the
            ``"key_metric"`` entry to be left out entirely.
        key_metric_mode_source: optional note on where the mode came from;
            stored as ``"mode_source"`` inside the ``"key_metric"`` entry.
        weight_key: meta key used as aggregation weight.
        weight_formula: optional description of the weight formula; stored in
            the ``"aggregation"`` sub-dict when truthy.
        site_weights: optional per-site weight records; stored as
            ``"site_weights"`` when truthy.

    Returns:
        The aggregation-info dict.
    """
    weighting: Dict[str, Any] = {
        "method": "weighted_average",
        "weight_key": weight_key,
        "metric_policy": "finite_numeric_metrics_only_per_key_denominator",
    }
    if weight_formula:
        weighting["weight_formula"] = weight_formula

    description: Dict[str, Any] = {
        "metric_source": "client_reported_flmodel_metrics",
        "aggregation": weighting,
    }

    # A key metric is only meaningful when we also know its direction.
    has_directed_key_metric = bool(key_metric) and key_metric_mode in ("max", "min")
    if has_directed_key_metric:
        selector = {"name": key_metric, "mode": key_metric_mode}
        if key_metric_mode_source:
            selector["mode_source"] = key_metric_mode_source
        description["key_metric"] = selector

    if site_weights:
        description["site_weights"] = site_weights

    return description
def make_key_metric_info_from_stop_condition(stop_cond, stop_condition) -> Optional[Dict[str, Any]]:
    """Derive key-metric info from a stop condition.

    Args:
        stop_cond: the stop condition as text. It must consist of exactly
            three tokens separated by single spaces, with the comparison
            operator as the middle token (e.g. ``"accuracy >= 80"``).
        stop_condition: an indexable object whose first element is used as
            the metric name.

    Returns:
        ``{"name", "mode", "mode_source"}`` where mode is ``"max"`` for
        ``>`` / ``>=`` and ``"min"`` for ``<`` / ``<=``. ``None`` when either
        argument is falsy, the text does not split into three tokens, or the
        operator is not one of the four above.
    """
    if not (stop_cond and stop_condition):
        return None

    # Split on a single space on purpose: repeated spaces yield extra (empty)
    # tokens and are therefore rejected by the length check below.
    parts = stop_cond.split(" ")
    if len(parts) != 3:
        return None

    mode_by_operator = {">": "max", ">=": "max", "<": "min", "<=": "min"}
    direction = mode_by_operator.get(parts[1])
    if direction is None:
        return None

    return {
        "name": stop_condition[0],
        "mode": direction,
        "mode_source": "derived_from_stop_condition",
    }
def _get_client_name(result: FLModel) -> str:
    """Return the client name stored in ``result.meta``.

    Falls back to ``AppConstants.CLIENT_UNKNOWN`` when the meta is missing,
    has no ``"client_name"`` entry, or the entry is not a non-empty string.
    """
    fallback = AppConstants.CLIENT_UNKNOWN
    name = (result.meta or {}).get("client_name", fallback)
    if isinstance(name, str) and name:
        return name
    return fallback


def _get_num_steps_weight(result: FLModel) -> float:
    """Return the aggregation weight for ``result``.

    The weight is read from ``result.meta[FLMetaKey.NUM_STEPS_CURRENT_ROUND]``.
    A weight of ``1.0`` is used instead when the value is absent, is a bool,
    cannot be converted to ``float``, is not finite, or is not positive.
    """
    default_weight = 1.0
    raw = (result.meta or {}).get(FLMetaKey.NUM_STEPS_CURRENT_ROUND)

    # bool is a subclass of int; reject it explicitly so True is not read as 1 step.
    if raw is None or isinstance(raw, bool):
        return default_weight

    try:
        steps = float(raw)
    except (TypeError, ValueError, OverflowError):
        return default_weight

    usable = math.isfinite(steps) and steps > 0
    return steps if usable else default_weight


def _aggregate_fl_model_metrics(results: List[FLModel]) -> Optional[Dict[str, Any]]:
    """Compute the weighted average of the metrics reported in ``results``.

    Each result contributes the metrics kept by ``filter_aggregatable_metrics``,
    weighted by ``_get_num_steps_weight``. Averaging client metric values is not
    the same as computing a non-linear metric (e.g. AUROC) on the pooled
    predictions of all clients.

    Returns:
        The aggregated metrics, or ``None`` when any result has
        ``metrics is None`` or when the aggregated result is empty.
    """
    accumulator = WeightedAggregationHelper()

    for client_result in results:
        reported = client_result.metrics
        if reported is None:
            # One client without metrics invalidates the aggregate for all.
            return None

        usable = filter_aggregatable_metrics(reported)
        if not usable:
            continue

        accumulator.add(
            data=usable,
            weight=_get_num_steps_weight(client_result),
            contributor_name=_get_client_name(client_result),
            contribution_round=client_result.current_round,
        )

    return accumulator.get_result() or None
class BaseFedAvg(ModelController):
    def __init__(
        self,
        *args,
        num_clients: int = 3,
        num_rounds: int = 5,
        start_round: int = 0,
        memory_gc_rounds: int = 0,
        **kwargs,
    ):
        """The base controller for the FedAvg workflow.

        *Note*: This class is based on the `ModelController`.
        Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).

        A model persistor can be configured via the `persistor_id` argument of the `ModelController`.
        The model persistor is used to load the initial global model, which is sent to a list of clients.
        Each client sends its updated weights after local training, and these are aggregated.
        Next, the global model is updated.
        The model persistor will also save the model after training.

        Provides default implementations for the following routines:
            - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
            - def update_model(self, model, aggr_result)

        The `run` routine needs to be implemented by the derived class:
            - def run(self)

        Args:
            num_clients (int, optional): The number of clients. Defaults to 3.
                NOTE: this argument does not belong here and is planned for removal in the next release.
            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
            start_round (int, optional): The starting round number. Defaults to 0.
            memory_gc_rounds (int, optional): Call `cleanup_memory()` every N rounds.
                Set to 0 to disable. Defaults to 0 (disabled).
        """
        super().__init__(*args, **kwargs)

        check_non_negative_int("memory_gc_rounds", memory_gc_rounds)

        self.num_clients = num_clients
        self.num_rounds = num_rounds
        self.start_round = start_round
        self.memory_gc_rounds = memory_gc_rounds
        # Unknown until a round has been set on the controller.
        self.current_round = None

    def _maybe_cleanup_memory(self):
        """Call `cleanup_memory()` when the current round is a multiple of `memory_gc_rounds`.

        Does nothing when no round is set or when `memory_gc_rounds` is not positive.
        """
        if self.current_round is None:
            return

        interval = self.memory_gc_rounds
        if not interval > 0:
            return

        # Rounds are counted from 1 for the purpose of the interval check.
        round_number = self.current_round + 1
        if round_number % interval == 0:
            self.info(f"Memory cleanup at round {round_number}")
            cleanup_memory()

    @staticmethod
    def _check_results(results: List[FLModel]):
        """Raise `ValueError` naming every client whose result has no params."""
        empty_clients = [_get_client_name(item) for item in results if not item.params]
        if empty_clients:
            raise ValueError(f"Result from client(s) {empty_clients} is empty!")

    @staticmethod
    def aggregate_fn(results: List[FLModel]) -> FLModel:
        """Weighted-average the params and metrics of `results` into one FLModel.

        Each result is weighted by `_get_num_steps_weight`. `params_type` and
        `current_round` of the aggregate are taken from the first result.

        Note:
            Metrics are combined by `_aggregate_fl_model_metrics`; the aggregated
            metrics are ``None`` when that helper returns ``None``.

        Args:
            results: FLModel results from the clients; must not be empty.

        Returns:
            The aggregated FLModel.

        Raises:
            ValueError: if `results` is empty.
        """
        if not results:
            raise ValueError("received empty results for aggregation.")

        params_accumulator = WeightedAggregationHelper()
        for client_result in results:
            params_accumulator.add(
                data=client_result.params,
                weight=_get_num_steps_weight(client_result),
                contributor_name=_get_client_name(client_result),
                contribution_round=client_result.current_round,
            )
        averaged_params = params_accumulator.get_result()
        averaged_metrics = _aggregate_fl_model_metrics(results)

        return FLModel(
            params=averaged_params,
            params_type=results[0].params_type,
            metrics=averaged_metrics,
            meta={
                "nr_aggregated": len(results),
                "current_round": results[0].current_round,
                AppConstants.METRICS_AGGREGATION_INFO: make_fedavg_metrics_aggregation_info(),
            },
        )

    def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel:
        """Called by the `run` routine to aggregate the training results of clients.

        Fires `BEFORE_AGGREGATION`, validates the results, aggregates them, attaches the
        metrics aggregation info and fires `AFTER_AGGREGATION` with the aggregated result.

        Args:
            results: a list of FLModel containing training results of the clients.
            aggregate_fn: a function that turns the list of FLModel into one resulting
                (aggregated) FLModel. `self.aggregate_fn` is used when this is falsy.

        Returns:
            The aggregated FLModel. If the aggregation function raises, the error is
            reported via `self.exception` and `self.panic` and an empty `FLModel()` is returned.

        Raises:
            ValueError: if any result has empty params (raised by `_check_results`).
        """
        self.debug("Start aggregation.")
        self.event(AppEventType.BEFORE_AGGREGATION)
        self._check_results(results)

        combine = aggregate_fn or self.aggregate_fn

        self.info(f"aggregating {len(results)} update(s) at round {self.current_round}")
        try:
            combined = combine(results)
        except Exception as e:
            error_msg = f"Exception in aggregate call: {secure_format_exception(e)}"
            self.exception(error_msg)
            self.panic(error_msg)
            return FLModel()

        self._results = []
        self._set_metrics_aggregation_info(combined)

        self.fire_event_with_data(
            AppEventType.AFTER_AGGREGATION,
            self.fl_ctx,
            AppConstants.AGGREGATION_RESULT,
            combined,
        )
        self.debug("End aggregation.")
        return combined

    def _set_metrics_aggregation_info(self, aggr_result: FLModel):
        """Ensure `aggr_result.meta` carries a metrics-aggregation-info dict.

        An existing dict entry is copied; otherwise a default one is created. If this
        controller has `stop_cond` / `stop_condition` attributes from which a key metric
        can be derived, it is added unless the info already has a `"key_metric"` entry.
        Non-FLModel arguments are ignored.
        """
        if not isinstance(aggr_result, FLModel):
            return

        aggr_result.meta = aggr_result.meta or {}
        meta = aggr_result.meta
        info_key = AppConstants.METRICS_AGGREGATION_INFO

        existing = meta.get(info_key)
        # Copy so that the dict supplied by the aggregation function is not mutated.
        info = dict(existing) if isinstance(existing, dict) else make_fedavg_metrics_aggregation_info()

        derived_key_metric = make_key_metric_info_from_stop_condition(
            getattr(self, "stop_cond", None),
            getattr(self, "stop_condition", None),
        )
        if derived_key_metric and "key_metric" not in info:
            info["key_metric"] = derived_key_metric

        meta[info_key] = info

    def update_model(self, model, aggr_result):
        """Called by the `run` routine to update the current global model given the aggregated result.

        Fires `BEFORE_SHAREABLE_TO_LEARNABLE`, updates the model, stores it in the FL context
        under `AppConstants.GLOBAL_MODEL` (private, sticky) and fires `AFTER_SHAREABLE_TO_LEARNABLE`.

        Args:
            model: FLModel to be updated.
            aggr_result: aggregated FLModel.

        Returns:
            The model returned by `FLModelUtils.update_model`.
        """
        self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE)

        updated = FLModelUtils.update_model(model, aggr_result)

        # persistor uses Learnable format to save model
        learnable = make_model_learnable(weights=updated.params, meta_props=updated.meta)
        self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, learnable, private=True, sticky=True)

        self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)
        return updated