Source code for nvflare.app_opt.he.intime_accumulate_model_aggregator

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union

import tenseal as ts

from nvflare.apis.dxo import DataKind, MetaKey, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator
from nvflare.app_opt.he import decomposers


class HEInTimeAccumulateWeightedAggregator(InTimeAccumulateWeightedAggregator):
    def __init__(
        self,
        exclude_vars: Union[str, Dict[str, str], None] = None,
        aggregation_weights: Union[Dict[str, Any], Dict[str, Dict[str, Any]], None] = None,
        expected_data_kind: Union[DataKind, Dict[str, DataKind]] = DataKind.WEIGHT_DIFF,
        weigh_by_local_iter=False,
    ):
        """In-time aggregator for `Shareables` encrypted with homomorphic encryption (HE).

        Encryption is done with TenSEAL (https://github.com/OpenMined/TenSEAL).
        This subclass exists so that the FOBS decomposers for HE objects
        (e.g. `CKKSVector`) get registered when the aggregator is constructed.

        Args:
            exclude_vars (optional): variable names that should be excluded from
                aggregation (regular expression). Defaults to None.
            aggregation_weights (dict, optional): dictionary of client aggregation
                weights. Defaults to None.
            expected_data_kind (optional): the data kind this aggregator can
                process. Defaults to `DataKind.WEIGHT_DIFF`.
            weigh_by_local_iter (bool, optional): if true, multiply client weights
                in encryption space. Defaults to `False`, which is recommended
                for HE because the first multiply happens in `HEModelEncryptor`.
        """
        # All configuration is handled by the base aggregator; it is forwarded unchanged.
        super().__init__(
            exclude_vars=exclude_vars,
            aggregation_weights=aggregation_weights,
            expected_data_kind=expected_data_kind,
            weigh_by_local_iter=weigh_by_local_iter,
        )

        # Register the HE decomposers so encrypted objects can be (de)serialized by FOBS.
        decomposers.register()

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Aggregate via the base class and tag each entry as encrypted or not.

        The result of the parent's `aggregate` is converted to a DXO, and a
        mapping `{key: bool}` is stored under `MetaKey.PROCESSED_KEYS`, where
        the bool tells whether the value for that key is a `ts.CKKSVector`.

        Args:
            fl_ctx (FLContext): FL context passed through to the parent aggregator.

        Returns:
            Shareable: the aggregated result with the processed-keys meta property set.

        Raises:
            ValueError: if the aggregated DXO data is not a dict.
        """
        aggregated = super().aggregate(fl_ctx=fl_ctx)

        # Unwrap the result so the meta property can be attached to it.
        dxo = from_shareable(aggregated)
        weights = dxo.data
        if not isinstance(weights, dict):
            raise ValueError(f"Expected weights to be of type dict but got type {type(weights)}")

        # True for entries that are still encrypted CKKS vectors, False otherwise.
        processed_keys = {name: isinstance(value, ts.CKKSVector) for name, value in weights.items()}
        dxo.set_meta_prop(MetaKey.PROCESSED_KEYS, processed_keys)

        return dxo.to_shareable()