Source code for nvflare.app_common.aggregators.dxo_aggregator

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, Optional

from nvflare.apis.dxo import DXO, DataKind, MetaKey
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper
from nvflare.app_common.app_constant import AppConstants


class DXOAggregator(FLComponent):
    def __init__(
        self,
        exclude_vars: Optional[str] = None,
        aggregation_weights: Optional[Dict[str, Any]] = None,
        expected_data_kind: DataKind = DataKind.WEIGHT_DIFF,
        name_postfix: str = "",
        weigh_by_local_iter: bool = True,
    ):
        """Perform accumulated weighted aggregation for one kind of corresponding DXO from contributors.

        Args:
            exclude_vars (str, optional): Regex to match excluded vars during aggregation. Defaults to None.
            aggregation_weights (Dict[str, Any], optional): Aggregation weight for each contributor. Defaults to None.
            expected_data_kind (DataKind): Expected DataKind for this DXO.
            name_postfix: optional postfix to give to class name and show in logger output.
            weigh_by_local_iter (bool, optional): Whether to weight the contributions by the number of iterations
                performed in local training in the current round. Defaults to `True`.
                Setting it to `False` can be useful in applications such as homomorphic encryption to reduce the
                number of computations on encrypted ciphertext. The aggregated sum will still be divided by the
                provided weights and `aggregation_weights` for the resulting weighted sum to be valid.
        """
        super().__init__()
        self.expected_data_kind = expected_data_kind
        self.aggregation_weights = aggregation_weights or {}
        self.logger.debug(f"aggregation weights control: {aggregation_weights}")
        self.aggregation_helper = WeightedAggregationHelper(
            exclude_vars=exclude_vars, weigh_by_local_iter=weigh_by_local_iter
        )
        self.warning_count = {}
        self.warning_limit = 10
        self.processed_algorithm = None
        if name_postfix:
            self._name += name_postfix
            self.logger = logging.getLogger(self._name)

    def reset_aggregation_helper(self):
        if self.aggregation_helper:
            self.aggregation_helper.reset_stats()

    def accept(self, dxo: DXO, contributor_name, contribution_round, fl_ctx: FLContext) -> bool:
        """Store DXO and update the aggregator's internal state.

        Args:
            dxo: information from the contributor
            contributor_name: name of the contributor
            contribution_round: round of the contribution
            fl_ctx: context provided by the workflow

        Returns:
            A boolean indicating whether the DXO is accepted.
        """
        if not isinstance(dxo, DXO):
            self.log_error(fl_ctx, f"Expected DXO but got {type(dxo)}")
            return False

        if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.METRICS):
            self.log_error(fl_ctx, "cannot handle data kind {}".format(dxo.data_kind))
            return False

        if dxo.data_kind != self.expected_data_kind:
            self.log_error(fl_ctx, "expected {} but got {}".format(self.expected_data_kind, dxo.data_kind))
            return False

        processed_algorithm = dxo.get_meta_prop(MetaKey.PROCESSED_ALGORITHM)
        if processed_algorithm is not None:
            if self.processed_algorithm is None:
                self.processed_algorithm = processed_algorithm
            elif self.processed_algorithm != processed_algorithm:
                self.log_error(
                    fl_ctx,
                    f"Only supports aggregation of data processed with the same algorithm ({self.processed_algorithm}) "
                    f"but got algorithm: {processed_algorithm}",
                )
                return False

        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)

        if contribution_round != current_round:
            self.log_warning(
                fl_ctx,
                f"discarding DXO from {contributor_name} at round: "
                f"{contribution_round}. Current round is: {current_round}",
            )
            return False

        self.log_debug(fl_ctx, f"current_round: {current_round}")

        data = dxo.data
        if data is None:
            self.log_error(fl_ctx, "no data to aggregate")
            return False

        for item in self.aggregation_helper.get_history():
            if contributor_name == item["contributor_name"]:
                prev_round = item["round"]
                self.log_warning(
                    fl_ctx,
                    f"discarding DXO from {contributor_name} at round: "
                    f"{contribution_round} as {prev_round} accepted already",
                )
                return False

        n_iter = dxo.get_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND)
        if n_iter is None:
            if self.warning_count.get(contributor_name, 0) <= self.warning_limit:
                self.log_warning(
                    fl_ctx,
                    f"NUM_STEPS_CURRENT_ROUND missing in meta of DXO"
                    f" from {contributor_name} and set to default value, 1.0. "
                    f" This kind of message will show {self.warning_limit} times at most.",
                )
                if contributor_name in self.warning_count:
                    self.warning_count[contributor_name] = self.warning_count[contributor_name] + 1
                else:
                    self.warning_count[contributor_name] = 0
            n_iter = 1.0
        float_n_iter = float(n_iter)

        aggregation_weight = self.aggregation_weights.get(contributor_name)
        if aggregation_weight is None:
            if self.warning_count.get(contributor_name, 0) <= self.warning_limit:
                self.log_warning(
                    fl_ctx,
                    f"Aggregation_weight missing for {contributor_name} and set to default value, 1.0"
                    f" This kind of message will show {self.warning_limit} times at most.",
                )
                if contributor_name in self.warning_count:
                    self.warning_count[contributor_name] = self.warning_count[contributor_name] + 1
                else:
                    self.warning_count[contributor_name] = 0
            aggregation_weight = 1.0

        # aggregate
        self.aggregation_helper.add(data, aggregation_weight * float_n_iter, contributor_name, contribution_round)
        self.log_debug(fl_ctx, "End accept")
        return True
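
    # Note on weighting (descriptive, based on accept() above): with weigh_by_local_iter=True
    # (the default), each accepted contribution is passed to the WeightedAggregationHelper with
    # an effective weight of aggregation_weight * n_iter, where n_iter comes from the
    # NUM_STEPS_CURRENT_ROUND meta property (defaulting to 1.0). The helper is then expected to
    # produce the weighted mean
    #     sum_i(w_i * data_i) / sum_i(w_i),  with w_i = aggregation_weight_i * n_iter_i,
    # which aggregate() below wraps into a new DXO.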

    def aggregate(self, fl_ctx: FLContext) -> DXO:
        """Called when the workflow determines it is time to generate a DXO to send back to contributors.

        Args:
            fl_ctx (FLContext): context provided by the workflow

        Returns:
            DXO: the weighted mean of accepted DXOs from contributors
        """
        self.log_debug(fl_ctx, f"Start aggregation with weights {self.aggregation_weights}")
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.log_info(fl_ctx, f"aggregating {self.aggregation_helper.get_len()} update(s) at round {current_round}")
        self.log_debug(fl_ctx, f"complete history {self.aggregation_helper.get_len()}")
        aggregated_dict = self.aggregation_helper.get_result()
        self.log_debug(fl_ctx, "End aggregation")

        dxo = DXO(data_kind=self.expected_data_kind, data=aggregated_dict)
        if self.processed_algorithm is not None:
            dxo.set_meta_prop(MetaKey.PROCESSED_ALGORITHM, self.processed_algorithm)
            self.processed_algorithm = None
        return dxo
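
# Illustrative usage sketch (not part of this module). The contributor names, weights, and
# update data below are hypothetical, and the FLContext is assumed to already carry
# AppConstants.CURRENT_ROUND, as an NVFlare workflow/controller would normally set it.
def _example_dxo_aggregation(fl_ctx: FLContext) -> DXO:
    aggregator = DXOAggregator(
        exclude_vars=None,
        aggregation_weights={"site-1": 1.0, "site-2": 2.0},  # hypothetical per-site weights
        expected_data_kind=DataKind.WEIGHT_DIFF,
    )
    current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)

    # One DXO per contributor; NUM_STEPS_CURRENT_ROUND further weights it by local iterations.
    for name, steps in (("site-1", 100), ("site-2", 50)):
        dxo = DXO(data_kind=DataKind.WEIGHT_DIFF, data={"layer.bias": 0.05})  # hypothetical update
        dxo.set_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, steps)
        aggregator.accept(dxo, contributor_name=name, contribution_round=current_round, fl_ctx=fl_ctx)

    # Returns a DXO whose data is the weighted mean of the accepted contributions.
    return aggregator.aggregate(fl_ctx)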