# Source code for nvflare.app_common.aggregators.lazy

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import threading

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.fuel.utils.validation_utils import check_positive_number, check_str
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
from nvflare.security.logging import secure_format_exception

_SHORT_WAIT = 0.1


class _AcceptWaitRC(WaiterRC):
    """Wait result codes used in this module: everything from WaiterRC plus END_RUN."""

    # Extra code returned by the wait condition callback once the run has ended.
    END_RUN = 10


class _Contribution:
    """A queued client update together with its own copy of the FL context."""

    def __init__(self, data: Shareable, fl_ctx: FLContext):
        # The same fl_ctx object could be used for multiple contributions and have its content
        # overwritten, so hold on to a clone instead of the caller's instance.
        ctx_copy = fl_ctx.clone()

        self.fl_ctx = ctx_copy
        self.data = data


class LazyAggregator(Aggregator):
    def __init__(self, aggregator_id: str, accept_timeout: float = 600.0):
        """Constructor of LazyAggregator.

        LazyAggregator is a wrapper for other aggregators to do accept processing in a separate thread.

        During a typical SAG-based training, updates from clients are processed by the aggregator's "accept"
        method. To ensure the integrity of the training task, the SAG workflow processes client updates
        sequentially. If the "accept" method is time-consuming and there are many clients, then the update
        processing becomes a bottleneck.

        Using the LazyAggregator, the updates from clients are simply added to a queue quickly. The actual
        "accept" processing is done in a separate thread that processes the queued updates sequentially.

        Args:
            aggregator_id: component ID of the real aggregator
            accept_timeout: max amount of time to wait for accept to finish
        """
        Aggregator.__init__(self)
        check_str("aggregator_id", aggregator_id)
        check_positive_number("accept_timeout", accept_timeout)

        self.aggregator_id = aggregator_id
        self.accept_timeout = accept_timeout

        # the wrapped aggregator; resolved from aggregator_id when the run starts
        self.aggregator = None

        # contributions waiting to be processed by the accept thread
        self.contributions = queue.Queue()
        self._q_lock = threading.Lock()

        # True from the start of aggregate() until the state is reset
        self.aggregating = False

        # set by the END_RUN handler; tells the accept thread and any pending wait to stop
        self.run_ended = False

        # set by the accept thread when aggregation is requested and the queue is empty
        self.accept_done = threading.Event()

        self.register_event_handler(EventType.START_RUN, self._lazy_aggr_start_run)
        self.register_event_handler(EventType.END_RUN, self._lazy_aggr_end_run)

    def _clear_contributions(self):
        """Discard every contribution that is still in the queue."""
        with self._q_lock:
            pending = self.contributions
            while True:
                try:
                    # non-blocking get: raises queue.Empty as soon as nothing is left
                    pending.get(block=False)
                except queue.Empty:
                    break
                # mark the discarded item as done
                pending.task_done()

    def _add_contribution(self, contrib: Shareable, fl_ctx: FLContext):
        """Wrap a client update in a _Contribution and put it on the queue.

        Args:
            contrib: the update received from a client
            fl_ctx: FL context of the update
        """
        self.log_debug(fl_ctx, "adding contribution to queue")
        with self._q_lock:
            self.contributions.put(_Contribution(contrib, fl_ctx))
        self.log_debug(fl_ctx, "done adding contribution to queue")

    def _lazy_aggr_start_run(self, event_type: str, fl_ctx: FLContext):
        """Handle START_RUN: resolve the wrapped aggregator and start the accept thread.

        The component registered under aggregator_id must be an Aggregator and must not itself be
        a LazyAggregator; otherwise system_panic is called and the accept thread is not started.

        Args:
            event_type: the fired event type
            fl_ctx: FL context of the event
        """
        engine = fl_ctx.get_engine()
        aggr = engine.get_component(self.aggregator_id)

        if not isinstance(aggr, Aggregator):
            self.system_panic(f"component {self.aggregator_id} must be Aggregator but got {type(aggr)}", fl_ctx)
            return

        if isinstance(aggr, LazyAggregator):
            self.system_panic(f"component {self.aggregator_id} must not be LazyAggregator", fl_ctx)
            return

        self.aggregator = aggr
        accept_thread = threading.Thread(target=self._do_accept, daemon=True)
        accept_thread.start()

    def _lazy_aggr_end_run(self, event_type: str, fl_ctx: FLContext):
        """Handle END_RUN: flag the run as ended so the accept thread exits its loop.

        Args:
            event_type: the fired event type
            fl_ctx: FL context of the event
        """
        self.run_ended = True

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Queue a client update for later processing by the accept thread.

        Args:
            shareable: the update received from a client
            fl_ctx: FL context of the update

        Returns:
            True if the update was queued; False if it was dropped because aggregation is in progress.
        """
        if self.aggregating:
            # when the aggregation is started, we no longer accept new contributions
            self.log_warning(fl_ctx, "dropped contribution while aggregating")
            return False

        self._add_contribution(shareable, fl_ctx)
        return True

    def _do_accept(self):
        """Thread body: take contributions from the queue and pass them to the wrapped aggregator one by one.

        The loop runs until run_ended is set. An exception raised by the wrapped aggregator's accept
        is logged and the loop moves on.
        """
        self.logger.debug("Started accept thread")
        while True:
            if self.run_ended:
                # the job is already done or aborted
                self.logger.debug("run ended - exit")
                break

            try:
                # we wait very shortly when trying to get a contribution, so we can check other conditions
                contrib = self.contributions.get(timeout=_SHORT_WAIT)
            except queue.Empty:
                contrib = None

            if contrib is not None:
                assert isinstance(contrib, _Contribution)
                self.log_debug(contrib.fl_ctx, "Accepting contribution")
                try:
                    accepted = self.aggregator.accept(contrib.data, contrib.fl_ctx)
                    self.log_debug(contrib.fl_ctx, f"{type(self.aggregator)} processed contribution: {accepted=}")
                except Exception as ex:
                    self.log_exception(
                        contrib.fl_ctx,
                        f"exception from {type(self.aggregator)} when accept: {secure_format_exception(ex)}",
                    )

            if self.aggregating and self.contributions.empty() and not self.accept_done.is_set():
                # When self.aggregating is set, the "aggregate" process is started and waiting for all
                # contributions to be accepted.
                # We set "accept_done" when all pending contributions are done.
                self.logger.debug("Finished accept for one round")
                self.accept_done.set()

        self.logger.debug("Finished accept thread")

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Wait for the accept thread to drain the queue, then run the wrapped aggregator's aggregate.

        Args:
            fl_ctx: FL context

        Returns:
            The result of the wrapped aggregator's aggregate, or an empty Shareable if the wait ended
            because the job was aborted or the run ended.

        Raises:
            Any exception raised by the wrapped aggregator's aggregate is propagated, after the state
            has been reset.
        """
        self.log_debug(fl_ctx, "starting to aggregate")

        # set self.aggregating to notify the "accept" thread that we are starting aggregation.
        self.aggregating = True

        # we wait until the "accept" thread is done with pending contributions
        rc = conditional_wait(
            waiter=self.accept_done,
            timeout=self.accept_timeout,
            abort_signal=fl_ctx.get_run_abort_signal(),
            condition_cb=self._check_end_run,
        )

        if rc in (_AcceptWaitRC.ABORTED, _AcceptWaitRC.END_RUN):
            self.log_info(fl_ctx, "skipped aggregation since job is aborted")
            return Shareable()

        if rc != _AcceptWaitRC.IS_SET:
            self.log_warning(fl_ctx, f"abnormal result {rc} waiting for accept thread")

        try:
            # we then call the aggregator to perform actual aggregation
            result = self.aggregator.aggregate(fl_ctx)
        finally:
            # reset state - some controllers may not call the aggregator's reset method for historical reason
            # we call it here to make sure the aggregator state is reset.
            # This is done in "finally" so that a failure in the wrapped aggregator cannot leave
            # self.aggregating set, which would make accept() drop every later contribution.
            self._reset(fl_ctx)

        self.log_debug(fl_ctx, "Finished aggregate for one round")
        return result

    def _reset(self, fl_ctx: FLContext):
        """Return to the "accepting" state, drop queued contributions and reset the wrapped aggregator.

        Args:
            fl_ctx: FL context
        """
        self.aggregating = False
        self.accept_done.clear()
        self._clear_contributions()
        self.aggregator.reset(fl_ctx)

    def _check_end_run(self):
        """Condition callback for the wait in aggregate().

        Returns:
            _AcceptWaitRC.END_RUN if the run has ended, otherwise _AcceptWaitRC.OK.
        """
        return _AcceptWaitRC.END_RUN if self.run_ended else _AcceptWaitRC.OK

    def reset(self, fl_ctx: FLContext):
        """Reset this aggregator and the wrapped aggregator.

        Args:
            fl_ctx: FL context
        """
        self._reset(fl_ctx)