# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from nvflare.apis.dxo import DXO, from_shareable
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.aggregators.assembler import Assembler
from nvflare.app_common.app_constant import AppConstants
class CollectAndAssembleAggregator(Aggregator):
    """Collect all client submissions, then delegate aggregation to an assembler.

    This aggregator is meant for methods that need a special assembly mechanism
    over the client submissions. It first collects every valid submission (keyed
    by contributor name) into the assembler's ``collection``; at aggregation time
    it hands the whole collection to the assembler, whose logic is specific to a
    particular algorithm.

    Note that aggregation here is not "in-time": nothing is combined until
    ``aggregate`` is called, because the assembling function may not be an
    arithmetic mean.
    """

    def __init__(self, assembler_id: str):
        """
        Args:
            assembler_id: component ID of the ``Assembler`` to look up from the engine.
        """
        super().__init__()
        self.assembler_id = assembler_id
        # Resolved lazily from the engine on first use (see _get_assembler).
        self.assembler: Optional[Assembler] = None

    def _get_assembler(self, fl_ctx: FLContext) -> Assembler:
        """Resolve (once) and return the configured assembler component.

        Args:
            fl_ctx: FL context used to reach the engine's component registry.

        Returns:
            The ``Assembler`` registered under ``self.assembler_id``.

        Raises:
            TypeError: if no component is registered under ``assembler_id`` or
                the registered component is not an ``Assembler``.
        """
        if self.assembler is None:
            component = fl_ctx.get_engine().get_component(self.assembler_id)
            # Fail loudly on misconfiguration instead of a later AttributeError on None.
            if not isinstance(component, Assembler):
                raise TypeError(
                    f"component '{self.assembler_id}' must be an Assembler but got {type(component).__name__}"
                )
            self.assembler = component
        return self.assembler

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Validate a client submission and store it in the assembler's collection.

        Args:
            shareable: the client's submission.
            fl_ctx: FL context of the current round.

        Returns:
            True if the submission was stored; False if it was rejected
            (invalid DXO, non-OK return code, wrong data kind, stale round,
            missing data, or a duplicate from the same contributor).

        Raises:
            TypeError: if the configured assembler cannot be resolved.
        """
        self._get_assembler(fl_ctx)

        contributor_name = shareable.get_peer_prop(key=ReservedKey.IDENTITY_NAME, default="?")
        dxo = self._get_contribution(shareable, fl_ctx)
        if dxo is None or dxo.data is None:
            self.log_error(fl_ctx, "no data to aggregate")
            return False

        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        return self._accept_contribution(contributor_name, current_round, dxo, fl_ctx)

    def _accept_contribution(self, contributor: str, current_round: int, dxo: DXO, fl_ctx: FLContext) -> bool:
        """Store the contributor's model params unless it already contributed this round.

        Returns:
            True if stored, False if a contribution from ``contributor`` is already present.
        """
        collection = self._get_assembler(fl_ctx).collection
        if contributor in collection:
            # Only the first submission per contributor is kept until the next reset.
            self.log_info(
                fl_ctx,
                f"Discarded: Current round: {current_round} contributions already include client: {contributor}",
            )
            return False

        collection[contributor] = self._get_assembler(fl_ctx).get_model_params(dxo)
        return True

    def _get_contribution(self, shareable: Shareable, fl_ctx: FLContext) -> Optional[DXO]:
        """Extract and validate the DXO carried by ``shareable``.

        Returns:
            The DXO if it parses, the return code is OK (or unset), its data kind
            matches the assembler's expected kind, and its contribution round
            equals the current round; otherwise None (the reason is logged).
        """
        contributor_name = shareable.get_peer_prop(key=ReservedKey.IDENTITY_NAME, default="?")
        try:
            dxo = from_shareable(shareable)
        except Exception:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return None

        rc = shareable.get_return_code()
        if rc and rc != ReturnCode.OK:
            self.log_warning(
                fl_ctx,
                f"Contributor {contributor_name} returned rc: {rc}. Disregarding contribution.",
            )
            return None

        expected_data_kind = self._get_assembler(fl_ctx).get_expected_data_kind()
        if dxo.data_kind != expected_data_kind:
            self.log_error(
                fl_ctx,
                "expected {} but got {}".format(expected_data_kind, dxo.data_kind),
            )
            return None

        contribution_round = shareable.get_cookie(AppConstants.CONTRIBUTION_ROUND)
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        if contribution_round != current_round:
            # Late submissions from a previous round must not leak into this one.
            self.log_warning(
                fl_ctx,
                f"discarding DXO from {contributor_name} at round: "
                f"{contribution_round}. Current round is: {current_round}",
            )
            return None

        return dxo

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Assemble all collected contributions and reset the assembler.

        Args:
            fl_ctx: FL context of the current round.

        Returns:
            The assembled result converted to a ``Shareable``.

        Raises:
            TypeError: if the configured assembler cannot be resolved.
            Any exception raised by ``assembler.assemble`` is propagated, but the
            assembler is still reset first.
        """
        self.log_debug(fl_ctx, "Start aggregation")
        # Resolve here too: aggregate may run before any accept() in a round.
        assembler = self._get_assembler(fl_ctx)
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        collection = assembler.collection
        self.log_info(fl_ctx, f"aggregating {len(collection)} update(s) at round {current_round}")
        try:
            dxo = assembler.assemble(data=collection, fl_ctx=fl_ctx)
        finally:
            # Always reset for the next round. Otherwise a failed assembly would
            # leave this round's contributors in the collection, and their
            # next-round submissions would be rejected as duplicates.
            assembler.reset()
        self.log_debug(fl_ctx, "End aggregation")
        return dxo.to_shareable()