Source code for nvflare.app_common.aggregators.weighted_aggregation_helper

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import threading
from typing import Optional


class WeightedAggregationHelper(object):
    """Accumulate weighted per-key contributions and average them on demand.

    ``add`` folds one contribution into the running per-key sums and
    ``get_result`` turns those sums into weighted averages; both run while
    holding ``self.lock``. ``get_history`` and ``get_len`` read the history
    without taking the lock.
    """

    def __init__(self, exclude_vars: Optional[str] = None, weigh_by_local_iter: bool = True):
        """Set up the lock, the optional exclusion pattern and empty accumulators.

        Args:
            exclude_vars (str, optional): regex string; any key for which
                ``search`` finds a match is skipped during aggregation.
                Defaults to None. A falsy value (e.g. an empty string) also
                disables exclusion.
            weigh_by_local_iter (bool, optional): when `True` (the default),
                each value is multiplied by the ``weight`` passed to ``add``
                before it is summed. When `False`, values are summed as-is,
                which can be useful in applications such as homomorphic
                encryption to reduce the number of computations on encrypted
                ciphertext. In both modes ``get_result`` still divides the
                summed values by the summed weights.
                NOTE(review): the original documentation also mentions
                `aggregation_weights`; those are not handled in this class.
        """
        super().__init__()
        self.lock = threading.Lock()
        if exclude_vars:
            self.exclude_vars = re.compile(exclude_vars)
        else:
            self.exclude_vars = None
        self.weigh_by_local_iter = weigh_by_local_iter
        self.reset_stats()
        # The accumulators are also assigned explicitly after reset_stats(),
        # so they are always fresh containers at the end of construction.
        self.total, self.counts, self.history = {}, {}, []

    def reset_stats(self):
        """Replace the sums, the weight counts and the history with empty containers."""
        self.total = {}
        self.counts = {}
        self.history = []

    def add(self, data, weight, contributor_name, contribution_round):
        """Compute weighted sum and sum of weights.

        Args:
            data: mapping of key -> value to fold into the running sums.
            weight: weight of this contribution; always added to the per-key
                weight count, and multiplied into each value only when
                ``weigh_by_local_iter`` is set.
            contributor_name: recorded in the history entry.
            contribution_round: recorded in the history entry under "round".
        """
        with self.lock:
            for name, value in data.items():
                pattern = self.exclude_vars
                if pattern is not None and pattern.search(name):
                    continue

                # Unweighted mode is used in homomorphic encryption to reduce
                # computations on ciphertext.
                contribution = value * weight if self.weigh_by_local_iter else value

                running = self.total.get(name)
                if running is not None:
                    # Plain binary "+" (not "+="): never mutate stored or
                    # caller-owned objects in place.
                    self.total[name] = running + contribution
                    self.counts[name] = self.counts[name] + weight
                else:
                    self.total[name] = contribution
                    self.counts[name] = weight

            # One history record per call, regardless of how many keys were used.
            entry = {
                "contributor_name": contributor_name,
                "round": contribution_round,
                "weight": weight,
            }
            self.history.append(entry)

    def get_result(self):
        """Divide weighted sum by sum of weights.

        Returns:
            dict mapping each accumulated key to its sum multiplied by the
            reciprocal of its summed weight. All accumulated state (sums,
            counts and history) is cleared before returning.
        """
        with self.lock:
            averaged = {}
            for name, summed in self.total.items():
                averaged[name] = summed * (1.0 / self.counts[name])
            self.reset_stats()
        return averaged

    def get_history(self):
        """Return the list of history entries recorded since the last reset."""
        return self.history

    def get_len(self):
        """Return the number of history entries recorded since the last reset."""
        entries = self.get_history()
        return len(entries)