Source code for nvflare.app_common.aggregators.weighted_aggregation_helper

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import threading
from typing import Any, Callable, Dict, Optional, Set


def _is_aggregatable_metric_value(v: Any) -> bool:
    """Return True if the metric value supports weighted aggregation (v * weight and addition).

    Boolean values are considered aggregatable and treated as binary values
    (`True=1.0`, `False=0.0`) when averaged.
    """
    if v is None:
        return False
    if isinstance(v, (dict, list, set, tuple, str)):
        return False
    # Bool metrics are treated as binary values (True=1, False=0) and averaged.
    if isinstance(v, (int, float, bool)):
        return True
    try:
        _ = v * 1.0
        _ = v + v
        return True
    except (TypeError, ValueError, AttributeError):
        return False


[docs] def filter_aggregatable_metrics( metrics: Optional[Dict[str, Any]], warn_skipped: Optional[Callable[[str, str], None]] = None, warned_metric_keys: Optional[Set[str]] = None, ) -> Dict[str, Any]: """Return metric entries that support weighted aggregation. Note: Boolean metric values are included and aggregate as binary rates. Args: metrics: Dict of metric name -> value. warn_skipped: Optional callback invoked as warn_skipped(key, type_name) for skipped metrics. warned_metric_keys: Optional set of keys already warned about. If provided, warnings are emitted at most once per key and newly warned keys are added to this set. """ if not metrics: return {} filtered = {} for key, value in metrics.items(): if _is_aggregatable_metric_value(value): filtered[key] = value continue if warn_skipped is None: continue if warned_metric_keys is None or key not in warned_metric_keys: warn_skipped(key, type(value).__name__) if warned_metric_keys is not None: warned_metric_keys.add(key) return filtered
[docs] class WeightedAggregationHelper(object): def __init__(self, exclude_vars: Optional[str] = None, weigh_by_local_iter: bool = True): """Perform weighted aggregation. Args: exclude_vars (str, optional): regex string to match excluded vars during aggregation. Defaults to None. weigh_by_local_iter (bool, optional): Whether to weight the contributions by the number of iterations performed in local training in the current round. Defaults to `True`. Setting it to `False` can be useful in applications such as homomorphic encryption to reduce the number of computations on encrypted ciphertext. The aggregated sum will still be divided by the provided weights and `aggregation_weights` for the resulting weighted sum to be valid. """ super().__init__() self.lock = threading.Lock() self.exclude_vars = re.compile(exclude_vars) if exclude_vars else None self.weigh_by_local_iter = weigh_by_local_iter self.reset_stats() self.total = dict() self.counts = dict() self.history = list()
[docs] def reset_stats(self): self.total = dict() self.counts = dict() self.history = list()
@staticmethod def _is_pytorch_tensor(tensor): """Check if tensor is a PyTorch tensor with in-place operation support.""" return hasattr(tensor, "add_") and hasattr(tensor, "mul_") and hasattr(tensor, "clone")
[docs] def add(self, data, weight, contributor_name, contribution_round): """Compute weighted sum and sum of weights.""" with self.lock: for k, v in data.items(): if self.exclude_vars is not None and self.exclude_vars.search(k): continue current_total = self.total.get(k, None) if current_total is None: # First contribution: initialize accumulator # We must create a copy to avoid mutating caller's input tensors if self._is_pytorch_tensor(v): if self.weigh_by_local_iter: # Weigh by local iter: create weighted copy (multiply by weight) self.total[k] = v.mul(weight) else: self.total[k] = v.clone() else: # Fallback for non-PyTorch tensors if self.weigh_by_local_iter: # Multiply creates a new array/tensor, no aliasing issue self.total[k] = v * weight else: # For HE mode: try to copy to avoid aliasing # But encrypted tensors can't be copied (requires secret key) try: self.total[k] = v.copy() if hasattr(v, "copy") else v except (ValueError, RuntimeError): # Encrypted tensor copy failed, use reference (safe, immutable) self.total[k] = v self.counts[k] = weight else: # Subsequent contributions: use in-place operations if self._is_pytorch_tensor(v) and self._is_pytorch_tensor(current_total): if self.weigh_by_local_iter: # Weigh by local iter: weighted accumulation self.total[k].add_(v, alpha=weight) else: self.total[k].add_(v) else: # Fallback for non-PyTorch tensors if self.weigh_by_local_iter: self.total[k] = current_total + v * weight else: self.total[k] = current_total + v self.counts[k] = self.counts[k] + weight self.history.append( { "contributor_name": contributor_name, "round": contribution_round, "weight": weight, } )
[docs] def get_result(self): """Divide weighted sum by sum of weights.""" with self.lock: aggregated_dict = {} for k, v in self.total.items(): if self._is_pytorch_tensor(v): # For PyTorch tensors, use in-place division to avoid creating a copy aggregated_dict[k] = v.div_(self.counts[k]) else: # Fallback for non-PyTorch tensors (including encrypted tensors) aggregated_dict[k] = v * (1.0 / self.counts[k]) self.reset_stats() return aggregated_dict
[docs] def get_history(self): return self.history
[docs] def get_len(self): return len(self.get_history())