Source code for nvflare.edge.updaters.aggr

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Optional, Tuple

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.edge.updater import Updater
from nvflare.fuel.utils.validation_utils import check_positive_int, check_str


class AggrUpdater(Updater):
    """Updater that implements the required logic by delegating to an Aggregator.

    Child updates are handed to the aggregator's ``accept``; once at least
    ``min_accepted`` of them have been accepted, the aggregator's ``aggregate``
    result is what gets reported to the parent.
    """

    def __init__(self, aggregator_id: str, min_accepted: int = 2):
        """Constructor of AggrUpdater.

        AggrUpdater implements required logic by using an Aggregator.

        Args:
            aggregator_id: component ID of the aggregator
            min_accepted: minimum updates required before aggregating.
        """
        Updater.__init__(self)
        self.aggregator_id = aggregator_id
        # Resolved from the engine when START_RUN fires; stays None until then
        # (and also if the configured component is not an Aggregator).
        self.aggregator: Optional[Aggregator] = None
        # Number of child updates accepted since the last aggregation.
        self.num_accepted = 0
        self.min_accepted = min_accepted
        # Guards both the aggregator calls and the num_accepted counter.
        self.aggr_lock = threading.Lock()

        check_str("aggregator_id", aggregator_id)
        check_positive_int("min_accepted", min_accepted)

        self.register_event_handler(EventType.START_RUN, self._handle_start_run)

    def _handle_start_run(self, event_type: str, fl_ctx: FLContext) -> None:
        """Resolve the aggregator component when the run starts.

        Looks up ``self.aggregator_id`` through the engine. If the component is
        not an ``Aggregator``, ``system_panic`` is called and ``self.aggregator``
        is left unset.

        Args:
            event_type: the event that triggered this handler (unused here).
            fl_ctx: FL context used to obtain the engine.
        """
        engine = fl_ctx.get_engine()
        aggr = engine.get_component(self.aggregator_id)
        if not isinstance(aggr, Aggregator):
            self.system_panic(
                f"aggregator {self.aggregator_id} must be an Aggregator type object but got {type(aggr)}",
                fl_ctx,
            )
            return
        self.aggregator = aggr

    def process_parent_update_reply(self, reply: Shareable, fl_ctx: FLContext) -> None:
        """Handle the parent's reply to an update; intentionally a no-op.

        Args:
            reply: the reply received from the parent (ignored).
            fl_ctx: FL context (unused).
        """
        # do not update my state.
        return

    def prepare_update_for_parent(self, fl_ctx: FLContext) -> Optional[Shareable]:
        """Return the aggregation result once enough child updates were accepted.

        Args:
            fl_ctx: FL context passed through to the aggregator.

        Returns:
            The aggregator's ``aggregate`` result when at least ``min_accepted``
            updates have been accepted since the last aggregation; the
            aggregator is then reset and the counter cleared. ``None``
            otherwise, meaning there is nothing to send to the parent yet.
        """
        with self.aggr_lock:
            if self.num_accepted < self.min_accepted:
                # not enough accepted updates from children: don't update the parent
                return None

            update = self.aggregator.aggregate(fl_ctx)
            # start a fresh accumulation window for the next round of updates
            self.aggregator.reset(fl_ctx)
            self.num_accepted = 0
            return update

    def process_child_update(self, update: Shareable, fl_ctx: FLContext) -> Tuple[bool, Optional[Shareable]]:
        """Offer a child's update to the aggregator.

        Args:
            update: the update received from a child.
            fl_ctx: FL context passed through to the aggregator.

        Returns:
            A tuple ``(accepted, reply)``: ``accepted`` is whatever the
            aggregator's ``accept`` returned, and ``reply`` is always ``None``.
        """
        self.log_info(fl_ctx, f"accepting child update by {type(self.aggregator)}")
        with self.aggr_lock:
            accepted = self.aggregator.accept(update, fl_ctx)
            self.log_info(fl_ctx, f"done child update: {accepted=}")
            if accepted:
                # counted towards the min_accepted threshold used when reporting to the parent
                self.num_accepted += 1
        return accepted, None