# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import threading
import time
from typing import Optional, Tuple

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.edge.constants import EdgeTaskHeaderKey
from nvflare.edge.updater import Updater
from nvflare.edge.utils import message_topic_for_task_end, message_topic_for_task_update, process_update_from_child
from nvflare.fuel.utils.tree_utils import Forest, Node
from nvflare.fuel.utils.validation_utils import check_positive_number, check_str
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
from nvflare.security.logging import secure_format_exception
[docs]
class TaskInfo:
    """Per-task bookkeeping extracted from a received task.

    The values are read from the task's headers once, at construction time, so the rest of the
    code does not have to look them up repeatedly.
    """

    def __init__(self, task: Shareable):
        """Capture the task and the header values needed while the task is being processed.

        Args:
            task: the task data (Shareable) as received
        """
        header = task.get_header

        # task data; may be replaced later while the task is being processed
        self.task = task
        # cookie jar of the task; attached to every update report sent to the parent
        self.cookie_jar = task.get_cookie_jar()

        self.round = header(AppConstants.CURRENT_ROUND)
        self.id = header(ReservedKey.TASK_ID)
        self.name = header(ReservedKey.TASK_NAME)
        self.seq = header(EdgeTaskHeaderKey.TASK_SEQ)

        # 1.0 is passed as the default for the update interval header
        self.update_interval = header(EdgeTaskHeaderKey.UPDATE_INTERVAL, 1.0)
[docs]
class HierarchicalUpdateGatherer(Executor):
def __init__(
    self,
    learner_id: str,
    updater_id: str,
    update_timeout,
):
    """Set up a gatherer that collects updates from child clients and reports them to its parent.

    Args:
        learner_id: ID of the Executor component that runs the task when this client is a leaf
        updater_id: ID of the Updater component; when empty, get_updater() is consulted at START_RUN
        update_timeout: timeout of the update request sent to the parent; checked to be a positive number
    """
    Executor.__init__(self)

    # validate the configuration before keeping any of it
    check_str("learner_id", learner_id)
    check_str("updater_id", updater_id)
    check_positive_number("update_timeout", update_timeout)

    # configuration
    self.learner_id = learner_id
    self.updater_id = updater_id
    self.update_timeout = update_timeout

    # state of the task currently being processed
    self._pending_task = None
    self._pending_clients = {}
    self._updater = None
    self._learner = None

    # _status_lock guards _pending_clients; _update_lock serializes calls into the updater
    self._status_lock = threading.Lock()
    self._update_lock = threading.Lock()

    self._process_error = None
    self._task_start_time = None

    # position of this client in the client hierarchy (filled in at START_RUN)
    self._children = None
    self._num_children = 0
    self._num_children_done = 0
    self._parent_name = None

    self._task_done = False
    self._msg_handler_registered = {}  # topic => bool

    event_handlers = (
        (EventType.START_RUN, self._hug_handle_start_run),
        (EventType.POST_TASK_ASSIGNMENT_SENT, self._handle_task_sent),
        (EventType.POST_TASK_RESULT_RECEIVED, self._handle_result_received),
    )
    for event_type, handler in event_handlers:
        self.register_event_handler(event_type, handler)
[docs]
def get_updater(self, fl_ctx: FLContext) -> Optional[Updater]:
return None
def _hug_handle_start_run(self, event_type: str, fl_ctx: FLContext):
    """Handle START_RUN: resolve the updater and learner, and find my place in the client hierarchy.

    Any inconsistency (wrong component type, missing hierarchy or node) triggers system_panic and
    stops the setup.

    Args:
        event_type: the event being handled
        fl_ctx: FLContext object

    Returns: None
    """
    self.log_debug(fl_ctx, f"handling event {event_type}")
    engine = fl_ctx.get_engine()

    # the updater comes from the configured component, or from the get_updater() hook
    if self.updater_id:
        updater = engine.get_component(self.updater_id)
        bad_updater_msg = f"component '{self.updater_id}' must be Updater but got {type(updater)}"
    else:
        updater = self.get_updater(fl_ctx)
        bad_updater_msg = f"get_updater() must return Updater but got {type(updater)}"

    if not isinstance(updater, Updater):
        self.system_panic(bad_updater_msg, fl_ctx)
        return

    self._updater = updater
    self.log_info(fl_ctx, f"got updater: {type(self._updater)}")

    # the learner is optional
    if self.learner_id:
        learner = engine.get_component(self.learner_id)
        if not isinstance(learner, Executor):
            self.system_panic(f"component '{self.learner_id}' must be Executor but got {type(learner)}", fl_ctx)
            return
        self._learner = learner

    hierarchy = fl_ctx.get_prop(FLContextKey.CLIENT_HIERARCHY)
    if not isinstance(hierarchy, Forest):
        self.system_panic(
            f"cannot get client hierarchy from fl-ctx: expect Forest but got {type(hierarchy)}", fl_ctx
        )
        return

    me = hierarchy.nodes.get(fl_ctx.get_identity_name())
    if not isinstance(me, Node):
        self.system_panic(f"cannot get my node from client hierarchy: expect Node but got {type(me)}", fl_ctx)
        return

    self._children = [child.obj.name for child in me.children]
    self._num_children = len(self._children)
    self.log_info(fl_ctx, f"got {self._num_children} child clients: {self._children}")

    # no parent node means my parent is the server
    parent = me.parent
    self._parent_name = parent.obj.name if parent else None
    self.log_info(fl_ctx, f"my parent is: {self._parent_name}")
def _handle_task_sent(self, event_type: str, fl_ctx: FLContext):
    """Handle POST_TASK_ASSIGNMENT_SENT: note that the pending task was sent to a child client.

    The child is entered into the pending-clients table with status None (no result received yet).
    The event is ignored when there is no pending task or when the sent task is not the pending one.

    Args:
        event_type: the event being handled
        fl_ctx: FLContext object; its peer context identifies the child client

    Returns: None
    """
    self.log_debug(fl_ctx, f"handling event {event_type}")
    fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)

    pending = self._pending_task
    if not pending:
        # nothing is being gathered right now
        return
    assert isinstance(pending, TaskInfo)

    sent_task_id = fl_ctx.get_prop(FLContextKey.TASK_ID)
    if sent_task_id == pending.id:
        peer_ctx = fl_ctx.get_peer_context()
        assert isinstance(peer_ctx, FLContext)
        child_name = peer_ctx.get_identity_name()
        self._update_client_status(child_name, None)
        self.log_info(fl_ctx, f"sent task {sent_task_id} to child {child_name}")
    else:
        # the task that was sent is not the one I am working on
        self.log_warning(fl_ctx, f"task sent {sent_task_id} is not the same as what I have {pending.id}")
def _handle_result_received(self, event_type: str, fl_ctx: FLContext):
    """Handle POST_TASK_RESULT_RECEIVED: record a child's task submission.

    A submission for the pending task marks the child as done (its status becomes the receive time).
    If the submission carries update data, that data is passed to the updater as well.
    Submissions for any other task are dropped.

    Args:
        event_type: the event being handled
        fl_ctx: FLContext object; carries the result and the child's peer context

    Returns: None
    """
    self.log_debug(fl_ctx, f"handling event {event_type}")

    # tell the framework that this event has been taken care of
    fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)

    pending = self._pending_task
    if not pending:
        # nothing is being gathered right now
        return
    assert isinstance(pending, TaskInfo)

    result = fl_ctx.get_prop(FLContextKey.TASK_RESULT)
    assert isinstance(result, Shareable)

    peer_ctx = fl_ctx.get_peer_context()
    assert isinstance(peer_ctx, FLContext)
    child_name = peer_ctx.get_identity_name()

    rc = result.get_return_code(ReturnCode.OK)
    self.log_info(fl_ctx, f"received task submission from child {child_name}: {rc}")

    result_task_id = result.get_header(ReservedKey.TASK_ID)
    if result_task_id != pending.id:
        dropped_msg = (
            f"dropped task submission from child {child_name} for task {result_task_id}: "
            f"we are working on task {pending.id}"
        )
        self.log_info(fl_ctx, dropped_msg)
        return

    # a truthy status (the receive time) marks this child as done
    self._update_client_status(child_name, time.time())

    if result.get_header(EdgeTaskHeaderKey.HAS_UPDATE_DATA, False):
        accepted, _ = self._accept_child_update(result, fl_ctx, pending.round)
        self.log_debug(fl_ctx, f"processed update from task submission: {accepted=}")
def _pending_clients_status(self):
with self._status_lock:
if not self._pending_clients:
return 0, 0
received = 0
for received_time in self._pending_clients.values():
if received_time:
received += 1
return received, len(self._pending_clients)
def _update_client_status(self, client_name, status):
with self._status_lock:
self._pending_clients[client_name] = status
[docs]
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Execute the assigned task.

    If we are a leaf node in the client hierarchy and a learner is configured, the task is delegated to
    the learner and its result is returned.

    Otherwise we act as a gatherer: the updater is told that a task has started, and we keep reporting
    the updater's current update to our parent until the task is done (see _do_task).

    Per-task state is always cleared when this method exits, even when an exception is raised, so a
    failed task cannot leave a stale pending task behind.

    Args:
        task_name: name of the assigned task
        shareable: task data
        fl_ctx: FLContext object
        abort_signal: signal to notify abort

    Returns: task result
    """
    is_leaf = fl_ctx.get_prop(ReservedKey.IS_LEAF)
    if is_leaf and self._learner:
        try:
            return self._learner.execute(task_name, shareable, fl_ctx, abort_signal)
        except Exception as ex:
            self.log_error(fl_ctx, f"exception from {type(self._learner)}: {secure_format_exception(ex)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    # Check the updater explicitly (an assert would vanish under -O), and do it before any task state
    # is published so that a failure here leaves nothing to clean up.
    updater = self._updater
    if not isinstance(updater, Updater):
        self.log_error(fl_ctx, f"cannot execute task {task_name}: expect Updater but got {type(updater)}")
        return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    # register msg handlers for update reports from children and end-task notifications from parent
    if not self._msg_handler_registered.get(task_name):
        engine = fl_ctx.get_engine()
        engine.register_aux_message_handler(message_topic_for_task_update(task_name), self._process_child_update)
        engine.register_aux_message_handler(message_topic_for_task_end(task_name), self._process_task_end)
        self._msg_handler_registered[task_name] = True

    # Start from clean flags: a notification that raced with the end of the previous task may have set
    # them after that task's state was reset.
    self._task_done = False
    self._process_error = False

    task_info = TaskInfo(shareable)
    self._pending_task = task_info
    updater_started = False
    try:
        task_info.task = updater.start_task(shareable, fl_ctx)
        updater_started = True
        self.task_started(task_info, fl_ctx)
        self.log_info(fl_ctx, f"got current_round: {task_info.round}")

        # Set header to indicate that we are ready to manage child clients
        # Note: when a child comes to pull task, the communicator only sends it after the task is ready.
        # This is to avoid the potential race condition that the client gets the task and then quickly submits
        # result before we are even ready.
        shareable.set_header(ReservedKey.TASK_IS_READY, True)

        self._task_start_time = time.time()
        result = self._do_task(fl_ctx, abort_signal)
        self.task_ended(task_info, fl_ctx)
        return result
    finally:
        # reset state, whether the task completed normally or failed
        self._task_done = False
        self._task_start_time = None
        self._pending_task = None
        self._pending_clients = {}
        self._process_error = False
        if updater_started:
            # only end what was actually started
            updater.end_task(fl_ctx)
def _do_task(self, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
    """Run the reporting loop for the pending task until it is done, fails, or is aborted.

    Each iteration sends an update report to the parent and inspects the reply, then waits for the
    task's update interval (plus a small random jitter). The loop ends when:
      - the task is flagged done (set_task_done, or an end-task notification from the parent);
      - a processing error is flagged;
      - the parent's reply is missing or is not a Shareable (both are flagged as processing errors);
      - the parent's reply has return code TASK_ABORTED or any other non-OK code;
      - the parent's reply carries a different task seq;
      - all child clients have submitted their results.

    Args:
        fl_ctx: FLContext object
        abort_signal: signal to notify abort

    Returns: a TASK_ABORTED reply if aborted while waiting; an EXECUTION_EXCEPTION reply if a processing
        error was flagged; otherwise a final update report.
    """
    task_info = self._pending_task
    assert isinstance(task_info, TaskInfo)

    self.log_info(fl_ctx, f"Starting task seq {task_info.seq} and waiting for results from children ...")

    # these do not change during the task, so look them up only once
    update_interval = task_info.update_interval
    engine = fl_ctx.get_engine()
    update_topic = message_topic_for_task_update(task_info.name)

    while True:
        # has the task_done been set?
        if self._task_done:
            self.log_info(fl_ctx, f"task {task_info.seq} is done: task was set to done")
            break

        if self._process_error:
            # we bail out when any processing error encountered
            self.log_info(fl_ctx, f"task seq {task_info.seq} is done: processing error occurred")
            break

        # send the current update to the parent periodically
        report = self._make_update_report(task_info, fl_ctx)
        replies = engine.send_aux_request(
            targets=self._parent_name,
            topic=update_topic,
            request=report,
            timeout=self.update_timeout,
            fl_ctx=fl_ctx,
        )
        assert isinstance(replies, dict)

        if len(replies) != 1:
            # this should never happen since the engine should always return a reply
            self.log_error(fl_ctx, f"no reply from parent {self._parent_name}")
            self._process_error = True
            break

        reply = next(iter(replies.values()))
        if not isinstance(reply, Shareable):
            self.log_error(
                fl_ctx,
                f"bad reply from parent {self._parent_name}: expect reply to be Shareable but got {type(reply)}",
            )
            self._process_error = True
            break

        rc = reply.get_return_code()
        if rc == ReturnCode.TASK_ABORTED:
            self.log_info(fl_ctx, f"task {task_info.seq} is done: parent task is gone")
            break

        if rc != ReturnCode.OK:
            # either way the task is over; only the severity of the log differs
            if rc == ReturnCode.TIMEOUT and self._task_done:
                self.log_info(
                    fl_ctx, f"parent update timeout after task {task_info.seq} completion - this is expected"
                )
            else:
                self.log_error(fl_ctx, f"error updating parent {self._parent_name}: {rc}")
            break

        parent_task_seq = reply.get_header(EdgeTaskHeaderKey.TASK_SEQ, task_info.seq)
        if parent_task_seq != task_info.seq:
            # this task is done
            self.log_info(fl_ctx, f"task {task_info.seq} is done: parent moved to task {parent_task_seq}")
            break

        # have I received all possible responses from my children?
        if self._num_children > 0:
            received, _ = self._pending_clients_status()
            if received >= self._num_children:
                self.log_info(fl_ctx, f"task {task_info.seq} is done: all {received} child clients are done!")
                break

        # a failure to process the parent's reply is logged but does not end the task
        try:
            assert isinstance(self._updater, Updater)
            self._updater.process_parent_update_reply(reply, fl_ctx)
        except Exception as ex:
            self.log_exception(
                fl_ctx,
                f"exception 'process_parent_update_reply' from {type(self._updater)}: "
                f"{secure_format_exception(ex)}",
            )

        # wait for the next reporting time; the jitter spreads out the reports of sibling clients
        wrc = conditional_wait(
            waiter=None,
            timeout=update_interval + random.uniform(0.0, 0.5),
            abort_signal=abort_signal,
            condition_cb=self._check_task_done,
        )
        if wrc == WaiterRC.ABORTED:
            return make_reply(ReturnCode.TASK_ABORTED)
        elif wrc == WaiterRC.IS_SET:
            # task was flagged done while waiting
            break

    received, total = self._pending_clients_status()
    self.log_info(fl_ctx, f"task done after {time.time() - self._task_start_time} secs: {received=} {total=}")

    if self._process_error:
        self.log_error(fl_ctx, "there is process error")
        return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    # report whatever is still held by the updater as the task result
    return self._make_update_report(task_info, fl_ctx)
def _check_task_done(self):
if self._task_done:
# force the conditional wait to stop
return WaiterRC.IS_SET
def _make_update_report(self, task_info: TaskInfo, fl_ctx: FLContext):
    """Build the update report to be sent to the parent for the given task.

    The report is whatever the updater prepared, or an empty Shareable when there is nothing to report.
    It is stamped with the task seq, a flag telling whether it carries update data, return code OK and
    the task's cookie jar.

    Args:
        task_info: info of the task being reported on
        fl_ctx: FLContext object

    Returns: the report (Shareable)
    """
    fl_ctx.set_prop(AppConstants.CURRENT_ROUND, task_info.round, private=True, sticky=False)

    # the updater is shared with the handlers that accept child updates
    with self._update_lock:
        report = self._prepare_update(fl_ctx)
        has_update_data = bool(report)
        if not has_update_data:
            # nothing to report: still send a report so that the parent can reply
            report = Shareable()

        self.log_debug(fl_ctx, f"making update report to parent for task {task_info.seq}: {has_update_data=}")

        report.set_header(EdgeTaskHeaderKey.TASK_SEQ, task_info.seq)
        report.set_header(EdgeTaskHeaderKey.HAS_UPDATE_DATA, has_update_data)
        report.set_return_code(ReturnCode.OK)
        report.set_cookie_jar(task_info.cookie_jar)
        return report
def _prepare_update(self, fl_ctx: FLContext):
    """Ask the updater for the update to be reported to the parent.

    Any exception from the updater is logged and flags a processing error, which makes the
    reporting loop bail out.

    Args:
        fl_ctx: FLContext object

    Returns: what the updater's prepare_update_for_parent returned, or None if it raised
    """
    try:
        assert isinstance(self._updater, Updater)
        return self._updater.prepare_update_for_parent(fl_ctx)
    except Exception as ex:
        # name the method that was actually called, so the log points at the right place
        self.log_exception(
            fl_ctx,
            f"exception 'prepare_update_for_parent' from {type(self._updater)}: {secure_format_exception(ex)}",
        )
        self._process_error = True
        return None
def _process_task_end(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
    """Process the notification from the parent that a task has ended.

    The notification is relayed to all my children first. Then my current task, if any, is flagged
    done when its seq is not after the seq of the ended task.

    Args:
        topic: topic of the notification message
        request: the notification; its TASK_SEQ header tells which task has ended
        fl_ctx: FLContext object

    Returns: a reply with return code OK
    """
    task_seq = request.get_header(EdgeTaskHeaderKey.TASK_SEQ)

    if self._num_children > 0:
        # pass the notification down the hierarchy without waiting for replies
        relay = Shareable()
        relay.set_header(EdgeTaskHeaderKey.TASK_SEQ, task_seq)
        fl_ctx.get_engine().send_aux_request(
            targets=self._children,
            topic=topic,
            request=relay,
            timeout=0,  # fire and forget
            fl_ctx=fl_ctx,
            optional=True,
        )

    current = self._pending_task
    if not current:
        self.log_info(fl_ctx, f"ignored end_task from parent for task {task_seq} since I have no current task")
    elif current.seq <= task_seq:
        # my current task is not after the ended task - end my task
        self.log_info(
            fl_ctx, f"ended current task seq {current.seq}: got end_task from parent for task {task_seq}"
        )
        self._task_done = True
    else:
        self.log_info(
            fl_ctx, f"ignored end_task from parent for task {task_seq} since it's < my task seq {current.seq}"
        )

    return make_reply(ReturnCode.OK)
def _process_child_update(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
    """Process an update report received from a child.

    The report is handed to process_update_from_child together with my current task seq and round
    (0 and -1 when I have no current task), using _accept_child_update to accept the update.

    Args:
        topic: topic of the update message
        request: the update report from the child
        fl_ctx: FLContext object; its peer context identifies the child

    Returns: the reply produced by process_update_from_child
    """
    self.log_debug(fl_ctx, f"processing child update: {topic}")

    pending = self._pending_task
    if pending:
        assert isinstance(pending, TaskInfo)
        seq, current_round = pending.seq, pending.round
    else:
        # no current task
        seq, current_round = 0, -1

    accepted, reply = process_update_from_child(
        processor=self,
        update=request,
        current_task_seq=seq,
        fl_ctx=fl_ctx,
        update_f=self._accept_child_update,
        current_round=current_round,
    )

    client_name = fl_ctx.get_peer_context().get_identity_name()
    self.log_debug(fl_ctx, f"processed aggr result report from {client_name} at round {current_round}: {accepted=}")
    return reply
def _accept_child_update(
    self, update: Shareable, fl_ctx: FLContext, current_round
) -> Tuple[bool, Optional[Shareable]]:
    """Pass an update from a child to the updater.

    Calls into the updater are serialized with the update lock. An exception from the updater is
    logged and treated as "not accepted".

    Args:
        update: the update to be accepted
        fl_ctx: FLContext object
        current_round: the round to be set in fl_ctx for the updater

    Returns: a tuple (accepted, reply) as returned by the updater's process_child_update;
        (False, None) if the updater raised
    """
    fl_ctx.set_prop(AppConstants.CURRENT_ROUND, current_round, private=True, sticky=False)
    with self._update_lock:
        reply = None
        try:
            accepted, reply = self._updater.process_child_update(update, fl_ctx)
        except Exception as ex:
            self.log_exception(fl_ctx, f"exception accepting update: {secure_format_exception(ex)}")
            accepted = False
        return accepted, reply
[docs]
def accept_update(self, task_id: str, update: Shareable, fl_ctx: FLContext) -> bool:
"""This is to be called by subclass to accept a specified update
Args:
task_id: ID of the task
update: the update to be accepted.
fl_ctx: FLContext object
Returns: whether the update is accepted
"""
task_info = self._pending_task
if not task_info:
self.log_warning(fl_ctx, f"update dropped for task_id {task_id}: no current task")
return False
if task_id and task_id != task_info.id:
self.log_warning(
fl_ctx, f"contribution dropped for task_id {task_id}: it does not match current task {task_info.id}"
)
return False
accepted, _ = self._accept_child_update(update, fl_ctx, task_info.round)
return accepted
[docs]
def set_task_done(self, task_id: str, fl_ctx: FLContext) -> bool:
"""This method is to be called by subclass to forcefully end the specified task
Args:
task_id: ID of the task to be ended
fl_ctx: FLContext object
Returns: whether this request is accepted
"""
task_info = self._pending_task
if not task_info:
self.log_info(fl_ctx, f"ignored set_task_done for task_id {task_id}: no current task.")
return False
if task_id != task_info.id:
self.log_info(
fl_ctx, f"ignored set_task_done for task_id {task_id}: it does not match current task {task_info.id}"
)
return False
self._task_done = True
self.log_info(fl_ctx, f"accepted set_task_done for task_id {task_id}")
return True
[docs]
def get_current_task(self, fl_ctx: FLContext) -> Optional[TaskInfo]:
"""Get the info of current task
Returns: TaskInfo of current task or None if no current task
Note: During the life of the task processing, the "task" data of the TaskInfo could be updated many times.
"""
if not self._updater:
# we are not ready yet
return None
if not self._pending_task:
return None
try:
self._pending_task.task = self._updater.get_current_state(fl_ctx)
except Exception as ex:
self.log_exception(
fl_ctx, f"exception get_current_state from {type(self._updater)}: {secure_format_exception(ex)}"
)
return self._pending_task
[docs]
def task_started(self, task: TaskInfo, fl_ctx: FLContext):
"""This method is called when a task assignment is received from the controller.
Subclass can implement this method to prepare for task processing.
Args:
task: info of the received task
fl_ctx: FLContext object
Returns: None
"""
pass
[docs]
def task_ended(self, task: TaskInfo, fl_ctx: FLContext):
"""This method is called when the current task is ended.
Subclass can implement this method to finish task processing.
Args:
task: info of the task that is ended
fl_ctx: FLContext object
Returns: None
"""
pass