# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_number, check_str
from nvflare.security.logging import secure_format_exception
class HierarchicalAggregationManager(Executor):
    """Executor that either runs a task locally or aggregates results from child clients.

    When the FL context marks this client as a leaf (``ReservedKey.IS_LEAF``), the task is
    delegated to the configured learner. Otherwise the manager waits for results submitted
    by child clients, feeds each accepted result to the configured aggregator, and returns
    the aggregated result as its own task result.
    """

    # seconds to sleep between two polls of the child-result status in _do_execute
    _POLL_INTERVAL = 0.5

    def __init__(
        self,
        learner_id: str,
        aggregator_id: str,
        aggr_timeout: float,
        min_responses: int,
        wait_time_after_min_resps_received: float,
    ):
        """Validate and store the configuration, and register event handlers.

        Args:
            learner_id: component ID of the Executor that runs the task when this client is a leaf.
            aggregator_id: component ID of the Aggregator used to combine child results.
            aggr_timeout: max number of seconds to wait for child results before aggregating.
            min_responses: number of child results that must be received before the
                post-minimum wait period starts. May be 0.
            wait_time_after_min_resps_received: number of seconds to keep waiting for more
                results once ``min_responses`` results have been received.

        Note: argument validation is delegated to the ``check_*`` helpers; presumably they
        raise on invalid values - confirm against ``validation_utils``.
        """
        Executor.__init__(self)

        check_str("learner_id", learner_id)
        check_str("aggregator_id", aggregator_id)
        check_positive_number("aggr_timeout", aggr_timeout)
        check_non_negative_int("min_responses", min_responses)
        check_positive_number("wait_time_after_min_resps_received", wait_time_after_min_resps_received)

        self.learner_id = learner_id
        self.aggregator_id = aggregator_id
        self.aggr_timeout = aggr_timeout
        self.min_responses = min_responses
        self.wait_time_after_min_resps_received = wait_time_after_min_resps_received

        # components resolved from the engine when the run starts
        self.aggregator = None
        self.learner = None

        # state of the task currently being aggregated (None / empty when idle)
        self.pending_task_id = None
        self.current_round = None
        # child client name => time at which its result was received (None until then)
        self.pending_clients = {}

        self._status_lock = threading.Lock()  # guards pending_clients
        self._aggr_lock = threading.Lock()  # serializes calls to aggregator.accept
        self._process_error = False  # set to True when accepting a result fails

        self.register_event_handler(EventType.START_RUN, self._handle_start_run)
        self.register_event_handler(EventType.TASK_ASSIGNMENT_SENT, self._handle_task_sent)
        self.register_event_handler(EventType.TASK_RESULT_RECEIVED, self._handle_result_received)

    def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
        """Resolve the aggregator and learner components from the engine.

        A component of the wrong type is reported with an error log but is still stored.
        NOTE(review): the run is not stopped here on a bad component - confirm this is intended.
        """
        self.log_debug(fl_ctx, f"handling event {event_type}")
        engine = fl_ctx.get_engine()

        aggr = engine.get_component(self.aggregator_id)
        if not isinstance(aggr, Aggregator):
            self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}")
        self.aggregator = aggr

        learner = engine.get_component(self.learner_id)
        if not isinstance(learner, Executor):
            self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner)}")
        self.learner = learner

    def _handle_task_sent(self, event_type: str, fl_ctx: FLContext):
        """Record that the pending task was sent to a child client."""
        self.log_debug(fl_ctx, f"handling event {event_type}")
        if not self.pending_task_id:
            # I don't have a pending task
            return

        child_client_ctx = fl_ctx.get_peer_context()
        assert isinstance(child_client_ctx, FLContext)
        child_client_name = child_client_ctx.get_identity_name()

        # None means "task sent, result not yet received"
        self._update_client_status(child_client_name, None)
        task_id = fl_ctx.get_prop(FLContextKey.TASK_ID)

        # indicate that this event has been processed by me
        fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)
        self.log_info(fl_ctx, f"sent task {task_id} to child {child_client_name}")

    def _handle_result_received(self, event_type: str, fl_ctx: FLContext):
        """Process a result submitted by a child client.

        Results for a task other than the pending one are dropped with a warning. A result
        with return code OK is passed to the aggregator; any other return code is logged as
        an error. In both cases the child is then marked as having responded.
        """
        self.log_debug(fl_ctx, f"handling event {event_type}")
        if not self.pending_task_id:
            # I don't have a pending task
            return

        # indicate that this event has been processed by me
        fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)

        result = fl_ctx.get_prop(FLContextKey.TASK_RESULT)
        assert isinstance(result, Shareable)
        task_id = result.get_header(ReservedKey.TASK_ID)

        peer_ctx = fl_ctx.get_peer_context()
        assert isinstance(peer_ctx, FLContext)
        child_client_name = peer_ctx.get_identity_name()
        self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}")

        if task_id != self.pending_task_id:
            self.log_warning(
                fl_ctx,
                f"dropped the received result from child {child_client_name} "
                f"for task {task_id} while waiting for task {self.pending_task_id}",
            )
            return

        rc = result.get_return_code(ReturnCode.OK)
        if rc == ReturnCode.OK:
            self.log_info(fl_ctx, f"accepting result from client {child_client_name}")
            self._do_aggregation(result, fl_ctx)
        else:
            self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}")

        self.log_info(fl_ctx, f"received result from child {child_client_name}")
        # a bad result still counts as a response from this child
        self._update_client_status(child_client_name, time.time())

    def _do_aggregation(self, result: Shareable, fl_ctx: FLContext):
        """Feed one child result to the aggregator; flag a process error if that fails."""
        with self._aggr_lock:
            try:
                # some aggregators expect current_round to be in the fl_ctx!
                fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.current_round, private=True, sticky=False)
                self.aggregator.accept(result, fl_ctx)
            except Exception as ex:
                self.log_error(
                    fl_ctx,
                    f"exception when 'accept' from aggregator {type(self.aggregator)}: {secure_format_exception(ex)}",
                )
                # picked up by the wait loop in _do_execute, which then bails out
                self._process_error = True

    def _pending_clients_status(self):
        """Return ``(received, total)``: children that have responded, and children tracked."""
        with self._status_lock:
            total = len(self.pending_clients)
            received = sum(1 for received_time in self.pending_clients.values() if received_time)
        return received, total

    def _update_client_status(self, client_name, status):
        """Set the status of a child: None when the task is sent, a timestamp when its result arrives."""
        with self._status_lock:
            self.pending_clients[client_name] = status

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Execute the assigned task.

        If we are a leaf node in the client hierarchy, the task is delegated to the configured learner.

        If we are not a leaf node, we wait for results from child clients and then aggregate their
        results using the configured aggregator.

        Args:
            task_name: name of the assigned task
            shareable: task data
            fl_ctx: FLContext object
            abort_signal: signal to notify abort

        Returns: task result
        """
        is_leaf = fl_ctx.get_prop(ReservedKey.IS_LEAF)
        if is_leaf:
            return self.learner.execute(task_name, shareable, fl_ctx, abort_signal)

        self.log_info(fl_ctx, "waiting for results from children ...")
        self.current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
        self.log_debug(fl_ctx, f"got current_round: {self.current_round}")
        self.pending_task_id = shareable.get_header(ReservedKey.TASK_ID)

        # Set header to indicate that we are ready to manage child clients
        # Note: when a child comes to pull task, the communicator only sends it after the task is ready.
        # This is to avoid the potential race condition that the client gets the task and then quickly submits
        # result before we are even ready.
        shareable.set_header(ReservedKey.TASK_IS_READY, True)

        try:
            return self._do_execute(fl_ctx, abort_signal)
        finally:
            # reset state even if waiting/aggregation raised, so that events arriving later
            # are not mistaken as belonging to a still-pending task
            self.pending_task_id = None
            with self._status_lock:
                self.pending_clients = {}
            self.aggregator.reset(fl_ctx)
            self._process_error = False

    def _do_execute(self, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Wait for child results and return the aggregated result.

        The wait ends when the abort signal is triggered, a processing error is flagged,
        ``aggr_timeout`` seconds have elapsed, or ``wait_time_after_min_resps_received``
        seconds have passed since ``min_responses`` results were first seen.

        Returns:
            TASK_ABORTED reply if aborted; EXECUTION_EXCEPTION reply if accepting or
            aggregating results failed; otherwise the aggregator's result.
        """
        start_time = time.time()
        min_received_time = None
        while True:
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)

            if self._process_error:
                # we bail out when any processing error encountered
                break

            current_time = time.time()
            if current_time - start_time > self.aggr_timeout:
                # we have waited long enough
                break

            received, _ = self._pending_clients_status()
            if received >= self.min_responses:
                if min_received_time is None:
                    # received min responses - remember the time at which this happened
                    min_received_time = current_time

                if current_time - min_received_time >= self.wait_time_after_min_resps_received:
                    # we have waited long enough after min responses received
                    break

            # Always pause between polls - also while still below min responses - so that
            # waiting does not become a busy loop hammering the CPU and the status lock.
            time.sleep(self._POLL_INTERVAL)

        # return aggregation result
        received, total = self._pending_clients_status()
        self.log_info(fl_ctx, f"process done after {time.time() - start_time} secs: {received=} {total=}")

        if self._process_error:
            self.log_error(fl_ctx, "there is process error")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        if received == 0:
            # Nothing received. This may be acceptable, so we still ask the aggregator for
            # its result instead of replying with a TIMEOUT return code.
            self.log_warning(fl_ctx, "nothing received - timeout")

        try:
            self.log_info(fl_ctx, "return aggregation result")
            return self.aggregator.aggregate(fl_ctx)
        except Exception as ex:
            self.log_error(fl_ctx, f"exception 'aggregate' from {type(self.aggregator)}: {secure_format_exception(ex)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)