# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.aggregators.dxo_aggregator import DXOAggregator
from nvflare.app_common.app_constant import AppConstants

def _is_nested_aggregation_weights(aggregation_weights):
    """Report whether ``aggregation_weights`` is a dict of per-DXO weight dicts.

    Only the first value is inspected: if it is itself a dict, the whole
    mapping is treated as ``{dxo_name: {contrib_name: weight}}``.

    Returns:
        bool: False for empty/falsy input or a non-dict; otherwise whether the
        first value is a dict.
    """
    if not aggregation_weights or not isinstance(aggregation_weights, dict):
        return False
    sample = next(iter(aggregation_weights.values()))
    return isinstance(sample, dict)

def _get_missing_keys(ref_dict: dict, dict_to_check: dict):
    """Return the keys of ``ref_dict`` that are absent from ``dict_to_check``.

    Args:
        ref_dict: mapping whose keys are all required.
        dict_to_check: mapping to check for those keys.

    Returns:
        list: missing keys, in ``ref_dict`` iteration order (empty if none).
    """
    return [k for k in ref_dict if k not in dict_to_check]

class InTimeAccumulateWeightedAggregator(Aggregator):
    """Aggregator that accumulates weighted contributions as they arrive, one DXOAggregator per expected DXO key."""

    def __init__(
        self,
        exclude_vars: Union[str, Dict[str, str], None] = None,
        aggregation_weights: Union[Dict[str, Any], Dict[str, Dict[str, Any]], None] = None,
        expected_data_kind: Union[DataKind, Dict[str, DataKind]] = DataKind.WEIGHT_DIFF,
        weigh_by_local_iter: bool = True,
    ):
        """Perform accumulated weighted aggregation.

        This is often used as the default aggregation method and can be used for FedAvg. It parses the shareable
        and aggregates the contained DXO(s).

        Args:
            exclude_vars (Union[str, Dict[str, str]], optional):
                Regular expression string to match excluded vars during aggregation. Defaults to None.
                Can be one string or a dict of {dxo_name: regex strings} corresponding to each aggregated DXO
                when processing a DXO of `DataKind.COLLECTION`.
            aggregation_weights (Union[Dict[str, Any], Dict[str, Dict[str, Any]]], optional):
                Aggregation weight for each contributor. Defaults to None.
                Can be one dict of {contrib_name: aggr_weight} or a dict of dicts corresponding to each aggregated
                DXO when processing a DXO of `DataKind.COLLECTION`.
            expected_data_kind (Union[DataKind, Dict[str, DataKind]]):
                DataKind for DXO. Defaults to DataKind.WEIGHT_DIFF.
                Can be one DataKind or a dict of {dxo_name: DataKind} corresponding to each aggregated DXO
                when processing a DXO of `DataKind.COLLECTION`. Only the keys in this dict will be processed.
            weigh_by_local_iter (bool, optional): Whether to weight the contributions by the number of iterations
                performed in local training in the current round. Defaults to `True`.
                Setting it to `False` can be useful in applications such as homomorphic encryption to reduce
                the number of computations on encrypted ciphertext.
                The aggregated sum will still be divided by the provided weights and `aggregation_weights` for the
                resulting weighted sum to be valid.
        """
        super().__init__()
        self.logger.debug(f"exclude vars: {exclude_vars}")
        self.logger.debug(f"aggregation weights control: {aggregation_weights}")
        self.logger.debug(f"expected data kind: {expected_data_kind}")

        # Key used in the per-DXO dicts when a single (non-collection) DXO is expected.
        self._single_dxo_key = ""
        self._weigh_by_local_iter = weigh_by_local_iter
        # Raw configuration as passed in; _initialize() normalizes these into {dxo_key: value} dicts at START_RUN.
        self.aggregation_weights = aggregation_weights
        self.exclude_vars = exclude_vars
        self.expected_data_kind = expected_data_kind

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Build the per-DXO aggregators when the run starts.

        Args:
            event_type: the event being fired.
            fl_ctx: context provided by the engine.
        """
        # _initialize() can not be called from the constructor. Because it changes the data, even the data format
        # of the aggregation_weights and exclude_vars parameters. Inspect could not figure out the passed in
        # parameters when re-construct the object creation configuration.
        if event_type == EventType.START_RUN:
            self._initialize(self.aggregation_weights, self.exclude_vars, self.expected_data_kind)

    def _initialize(self, aggregation_weights, exclude_vars, expected_data_kind):
        """Validate and normalize the configuration, then create one DXOAggregator per expected DXO key.

        After this call, ``expected_data_kind``, ``exclude_vars`` and ``aggregation_weights`` are all dicts
        keyed by DXO name (``self._single_dxo_key`` when a single DXO is expected).

        Raises:
            ValueError: on an unsupported data kind, a malformed ``exclude_vars``, or per-DXO
                ``exclude_vars``/``aggregation_weights`` dicts that miss an expected key.
        """
        valid_kinds = [DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.METRICS]

        # Check expected data kind
        if isinstance(expected_data_kind, dict):
            for k, v in expected_data_kind.items():
                if v not in valid_kinds:
                    raise ValueError(
                        f"expected_data_kind[{k}] = {v} is not {DataKind.WEIGHT_DIFF} or {DataKind.WEIGHTS} or {DataKind.METRICS}"
                    )
            self.expected_data_kind = expected_data_kind
        else:
            if expected_data_kind not in valid_kinds:
                raise ValueError(
                    f"expected_data_kind = {expected_data_kind} is not {DataKind.WEIGHT_DIFF} or {DataKind.WEIGHTS} or {DataKind.METRICS}"
                )
            self.expected_data_kind = {self._single_dxo_key: expected_data_kind}

        # Check exclude_vars
        if exclude_vars:
            if not isinstance(exclude_vars, dict) and not isinstance(exclude_vars, str):
                raise ValueError(
                    f"exclude_vars = {exclude_vars} should be a regex string but got {type(exclude_vars)}."
                )
            if isinstance(exclude_vars, dict):
                # Compare against the normalized dict: the raw argument may be a single DataKind
                # rather than a dict of DXO keys.
                missing_keys = _get_missing_keys(self.expected_data_kind, exclude_vars)
                if len(missing_keys) != 0:
                    raise ValueError(
                        "A dict exclude_vars should specify exclude_vars for every key in expected_data_kind. "
                        f"But missed these keys: {missing_keys}"
                    )
        exclude_vars_dict = {}
        for k in self.expected_data_kind:
            if isinstance(exclude_vars, dict):
                if k in exclude_vars:
                    if not isinstance(exclude_vars[k], str):
                        raise ValueError(
                            f"exclude_vars[{k}] = {exclude_vars[k]} should be a regex string but got {type(exclude_vars[k])}."
                        )
                    exclude_vars_dict[k] = exclude_vars[k]
                else:
                    # Only reachable for an empty dict (a non-empty one was validated above):
                    # exclude nothing for this key rather than failing later with a KeyError.
                    exclude_vars_dict[k] = None
            else:
                # assume same exclude vars for each entry of DXO collection.
                exclude_vars_dict[k] = exclude_vars
        self.exclude_vars = exclude_vars_dict

        # Check aggregation weights
        nested_weights = _is_nested_aggregation_weights(aggregation_weights)
        if nested_weights:
            missing_keys = _get_missing_keys(self.expected_data_kind, aggregation_weights)
            if len(missing_keys) != 0:
                raise ValueError(
                    "A dict of dict aggregation_weights should specify aggregation_weights "
                    f"for every key in expected_data_kind. But missed these keys: {missing_keys}"
                )
        aggregation_weights = aggregation_weights or {}
        if nested_weights:
            # One {contrib_name: weight} dict per DXO key (presence validated above).
            self.aggregation_weights = {k: aggregation_weights[k] for k in self.expected_data_kind}
        else:
            # Assume same aggregation weights for each entry of DXO collection. Deciding by the
            # nested check (not `k in aggregation_weights`) keeps a contributor named like a DXO key
            # from being mistaken for a per-DXO entry.
            self.aggregation_weights = {k: aggregation_weights for k in self.expected_data_kind}

        # Set up DXO aggregators
        self.dxo_aggregators = {
            k: DXOAggregator(
                exclude_vars=self.exclude_vars[k],
                aggregation_weights=self.aggregation_weights[k],
                expected_data_kind=self.expected_data_kind[k],
                name_postfix=k,
                weigh_by_local_iter=self._weigh_by_local_iter,
            )
            for k in self.expected_data_kind
        }

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Store shareable and update aggregator's internal state.

        Args:
            shareable: information from contributor
            fl_ctx: context provided by workflow

        Returns:
            True if at least one expected DXO was accepted and none was rejected by its DXOAggregator;
            False if the shareable is invalid, has an unsupported data kind, carries a non-OK return code,
            or any expected DXO was rejected / none were found.
        """
        try:
            dxo = from_shareable(shareable)
        except Exception:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return False

        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.METRICS, DataKind.COLLECTION):
            self.log_error(
                fl_ctx,
                f"cannot handle data kind {dxo.data_kind}, "
                f"expecting DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.METRICS, or DataKind.COLLECTION.",
            )
            return False

        contributor_name = shareable.get_peer_prop(key=ReservedKey.IDENTITY_NAME, default="?")
        contribution_round = shareable.get_cookie(AppConstants.CONTRIBUTION_ROUND)

        rc = shareable.get_return_code()
        if rc and rc != ReturnCode.OK:
            self.log_warning(fl_ctx, f"Contributor {contributor_name} returned rc: {rc}. Disregarding contribution.")
            return False

        # Accept expected DXO(s) in shareable
        n_accepted = 0
        for key in self.expected_data_kind:
            if key == self._single_dxo_key:
                # expecting a single DXO
                sub_dxo = dxo
            else:
                # expecting a collection of DXOs keyed by name; guard against a non-dict payload
                sub_dxo = dxo.data.get(key) if isinstance(dxo.data, dict) else None
            if not isinstance(sub_dxo, DXO):
                self.log_warning(fl_ctx, f"Collection does not contain DXO for key {key} but {type(sub_dxo)}.")
                continue
            accepted = self.dxo_aggregators[key].accept(
                dxo=sub_dxo, contributor_name=contributor_name, contribution_round=contribution_round, fl_ctx=fl_ctx
            )
            if not accepted:
                # NOTE(review): DXOs for earlier keys may already have been accepted by their aggregators.
                return False
            n_accepted += 1

        if n_accepted > 0:
            return True
        self.log_warning(fl_ctx, f"Did not accept any DXOs from {contributor_name} in round {contribution_round}!")
        return False

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Called when workflow determines to generate shareable to send back to contributors.

        Args:
            fl_ctx (FLContext): context provided by workflow

        Returns:
            Shareable: the weighted mean of accepted shareables from contributors; a single DXO when a single
            DXO is expected, otherwise a `DataKind.COLLECTION` DXO keyed by DXO name.
        """
        self.log_debug(fl_ctx, "Start aggregation")
        result_dxo_dict = {}
        # Aggregate the expected DXO(s)
        for key in self.expected_data_kind:
            aggregated_dxo = self.dxo_aggregators[key].aggregate(fl_ctx)
            if key == self._single_dxo_key:
                # return single DXO with aggregation results
                return aggregated_dxo.to_shareable()
            self.log_info(fl_ctx, f"Aggregated contributions matching key '{key}'.")
            result_dxo_dict[key] = aggregated_dxo

        # return collection of DXOs with aggregation results
        collection_dxo = DXO(data_kind=DataKind.COLLECTION, data=result_dxo_dict)
        return collection_dxo.to_shareable()