# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import time
from enum import Enum
from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.edge.assessor import Assessment, Assessor
from nvflare.edge.constants import EdgeTaskHeaderKey
from nvflare.edge.utils import message_topic_for_task_end, message_topic_for_task_update, process_update_from_child
from nvflare.fuel.utils.validation_utils import check_positive_number, check_str
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector
class TaskDoneReason(Enum):
    """Reason why monitoring of a broadcast task stopped.

    Members:
        ALL_CHILDREN_DONE: every targeted child client completed the task.
        ABORTED: the abort signal fired while waiting on the task.
        ASSESSED_TASK_DONE: the assessor decided the current task should end.
        ASSESSED_WORKFLOW_DONE: the assessor decided the whole workflow should end.
    """

    # normal completion: the broadcast task reports completion
    ALL_CHILDREN_DONE = "all_children_done"
    # the workflow's abort signal was observed while waiting
    ABORTED = "aborted"
    # early stop of only the current task, as decided by the assessor
    ASSESSED_TASK_DONE = "assessed_task_done"
    # early stop of the entire workflow, as decided by the assessor
    ASSESSED_WORKFLOW_DONE = "assessed_workflow_done"
class ScatterAndGatherForEdge(Controller):
    """Scatter-and-gather workflow for hierarchically organized edge devices.

    Each round, task data obtained from the assessor is broadcast to the root clients of
    the client hierarchy. While the task runs, the assessor is polled periodically to
    decide whether the task (or the whole workflow) should end early.
    """

    # Class-level counter used to hand out task sequence numbers; shared by all instances.
    next_task_seq = 0

    def __init__(
        self,
        num_rounds: int = 5,
        assessor_id: str = "assessor",
        task_name=AppConstants.TASK_TRAIN,
        task_check_period: float = 0.5,
        assess_interval: float = 0.5,
        update_interval: float = 1.0,
    ):
        """ScatterAndGatherForEdge Workflow.

        The ScatterAndGatherForEdge workflow is a Fed Average algorithm for hierarchically organized edge devices.
        During the execution of a task, the assessor (specified by assessor_id) is invoked periodically to assess
        the quality of training results to determine whether the task should be continued.

        Args:
            num_rounds (int, optional): The total number of training rounds. Must be a non-negative int.
                Defaults to 5.
            assessor_id (str): ID of the assessor component.
            task_name (str): Name of the train task. Defaults to "train".
            task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5.
            assess_interval: how often to invoke the assessor during task execution
            update_interval: how often for children to send updates

        Raises:
            TypeError: when any of input arguments does not have correct type
            ValueError: when any of input arguments is out of range
        """
        super().__init__(task_check_period=task_check_period)

        # Check arguments. num_rounds feeds range() in control_flow, so it must be a
        # non-negative int; validating here honors the documented Raises contract.
        if not isinstance(num_rounds, int):
            raise TypeError(f"num_rounds must be int but got {type(num_rounds)}")
        if num_rounds < 0:
            raise ValueError(f"num_rounds must be >= 0 but got {num_rounds}")
        check_str("assessor_id", assessor_id)
        check_str("task_name", task_name)
        check_positive_number("task_check_period", task_check_period)
        check_positive_number("assess_interval", assess_interval)
        check_positive_number("update_interval", update_interval)

        self.assessor_id = assessor_id
        self.task_name = task_name
        self.assessor = None  # resolved from the engine in start_controller

        # config data
        self._num_rounds = num_rounds
        self._assess_interval = assess_interval
        self._update_interval = update_interval

        # runtime state
        self._current_round = None
        self._current_task_seq = 0  # 0 while no task is being executed
        self._num_children = 0
        self._children = None
        self._end_task_topic = message_topic_for_task_end(self.task_name)
        self._wf_done = False

    @classmethod
    def get_next_task_seq(cls):
        """Return the next task sequence number from the class-level counter."""
        cls.next_task_seq += 1
        return cls.next_task_seq

    def start_controller(self, fl_ctx: FLContext) -> None:
        """Resolve the assessor, register the update-report handler and discover child clients.

        Triggers a system panic (and returns) if the assessor component is not an
        Assessor or if no client hierarchy is available in the FL context.
        """
        self.log_info(fl_ctx, f"Initializing {self._name} workflow.")
        engine = fl_ctx.get_engine()
        self.assessor = engine.get_component(self.assessor_id)
        if not isinstance(self.assessor, Assessor):
            self.system_panic(
                f"Assessor {self.assessor_id} must be an Assessor but got {type(self.assessor)}",
                fl_ctx,
            )
            return

        # register aux message handler for receiving aggr results from children
        engine.register_aux_message_handler(message_topic_for_task_update(self.task_name), self._process_update_report)

        # get children clients
        client_hierarchy = fl_ctx.get_prop(FLContextKey.CLIENT_HIERARCHY)
        if client_hierarchy is None:
            # fail with a clear message instead of an AttributeError on `.roots`
            self.system_panic(f"{self._name}: client hierarchy is not available in FL context", fl_ctx)
            return
        self._children = client_hierarchy.roots
        self._num_children = len(self._children)
        self.log_info(fl_ctx, f"my child clients: {self._children}")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
        """Run up to ``num_rounds`` rounds of broadcast-and-monitor.

        Each round:
            1. get the task data from ``assessor.start_task``;
            2. broadcast it to the child clients;
            3. monitor it until all children are done, the assessor ends it, or abort;
            4. call ``assessor.end_task`` and cancel the task if it is still incomplete.

        The loop stops early on abort or when the assessor declares the workflow done.
        Any exception raised by the assessor or the monitor triggers a system panic and
        returns immediately.
        """
        try:
            self.log_info(fl_ctx, f"Starting {self._name}")
            fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False)
            self.fire_event(AppEventType.TRAINING_STARTED, fl_ctx)

            for r in range(self._num_rounds):
                # monotonic clock: elapsed time must not be affected by wall-clock adjustments
                round_start = time.monotonic()
                self._current_round = r
                if self._check_abort_signal(fl_ctx, abort_signal):
                    break

                self.log_info(fl_ctx, f"Round {r} started.")
                self._current_task_seq = self.get_next_task_seq()

                # Create train_task
                fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)
                try:
                    task_data = self.assessor.start_task(fl_ctx)
                except Exception as ex:
                    self.log_exception(fl_ctx, f"exception in 'start_task' from {type(self.assessor)}")
                    self.system_panic(
                        f"Task execution encountered exception: {secure_format_exception(ex)}",
                        fl_ctx,
                    )
                    # Like the other panic paths below, stop here rather than falling through
                    # to the normal "Finished" path (which also sleeps and marks the workflow done).
                    return

                task_data.set_header(AppConstants.CURRENT_ROUND, self._current_round)
                task_data.set_header(AppConstants.NUM_ROUNDS, self._num_rounds)
                task_data.set_header(EdgeTaskHeaderKey.TASK_SEQ, self._current_task_seq)
                task_data.set_header(EdgeTaskHeaderKey.UPDATE_INTERVAL, self._update_interval)
                task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round)
                self.fire_event_with_data(AppEventType.ROUND_STARTED, fl_ctx, FLContextKey.TASK_DATA, task_data)

                task = Task(
                    name=self.task_name,
                    data=task_data,
                    before_task_sent_cb=self._prepare_train_task_data,
                    result_received_cb=self._process_train_result,
                )

                self.broadcast(
                    task=task,
                    fl_ctx=fl_ctx,
                    targets=self._children,
                    min_responses=self._num_children,
                    wait_time_after_min_received=0,
                )

                # monitor the task until it's done
                seq = self._current_task_seq
                try:
                    task_done_reason = self._monitor_task(task, fl_ctx, abort_signal)
                except Exception as ex:
                    self.system_panic(
                        f"Task {seq} execution encountered exception: {secure_format_exception(ex)}",
                        fl_ctx,
                    )
                    self.log_exception(fl_ctx, "exception in monitor_task")
                    return

                try:
                    self.assessor.end_task(fl_ctx)
                except Exception as ex:
                    self.log_exception(fl_ctx, f"exception in 'end_task' from {type(self.assessor)}")
                    self.system_panic(
                        f"Task execution encountered exception: {secure_format_exception(ex)}",
                        fl_ctx,
                    )
                    return

                self._current_task_seq = 0
                # the task may have been ended early (assessor decision or abort)
                if not task.completion_status:
                    self.cancel_task(task, fl_ctx=fl_ctx)

                if task_done_reason in [TaskDoneReason.ABORTED, TaskDoneReason.ASSESSED_WORKFLOW_DONE]:
                    break

                self.fire_event(AppEventType.ROUND_DONE, fl_ctx)
                self.log_info(
                    fl_ctx, f"Round {self._current_round} finished in {time.monotonic() - round_start} seconds"
                )
                gc.collect()

            self._wf_done = True
            self._current_task_seq = 0
            self.log_info(fl_ctx, f"Finished {self._name}")

            # give some time for clients to end gracefully when sync task seq
            time.sleep(self._update_interval + 1.0)
        except Exception as e:
            error_msg = f"Exception in {self._name} workflow: {secure_format_exception(e)}"
            self.log_exception(fl_ctx, error_msg)
            self.system_panic(error_msg, fl_ctx)

    def _monitor_task(self, task: Task, fl_ctx: FLContext, abort_signal: Signal) -> TaskDoneReason:
        """Poll a broadcast task until it completes, the assessor ends it, or abort fires.

        The assessor is consulted every ``assess_interval`` seconds. When it returns anything
        other than ``Assessment.CONTINUE``, an end-task request carrying the current task
        sequence is sent to the children before returning.

        Returns:
            The TaskDoneReason describing why monitoring stopped.
        """
        seq = self._current_task_seq
        while True:
            if task.completion_status:
                # all children are done with their current task
                self.log_info(fl_ctx, f"Task seq {seq} is completed: {task.completion_status=}")
                return TaskDoneReason.ALL_CHILDREN_DONE

            assessment = self.assessor.assess(fl_ctx)
            if assessment != Assessment.CONTINUE:
                self.log_info(fl_ctx, f"Task seq {seq} is done: {assessment=}")

                # notify children to end task
                req = Shareable()
                req.set_header(EdgeTaskHeaderKey.TASK_SEQ, seq)
                engine = fl_ctx.get_engine()
                engine.send_aux_request(
                    targets=self._children,
                    topic=self._end_task_topic,
                    request=req,
                    timeout=0,  # fire and forget
                    fl_ctx=fl_ctx,
                    optional=True,
                )

                if assessment == Assessment.WORKFLOW_DONE:
                    return TaskDoneReason.ASSESSED_WORKFLOW_DONE
                else:
                    return TaskDoneReason.ASSESSED_TASK_DONE

            wrc = conditional_wait(
                waiter=None,
                timeout=self._assess_interval,
                abort_signal=abort_signal,
            )
            if wrc == WaiterRC.ABORTED:
                self.log_info(fl_ctx, f"Task seq {seq} is done: ABORTED")
                return TaskDoneReason.ABORTED

    def stop_controller(self, fl_ctx: FLContext):
        """Finalize the assessor; errors are logged, not raised."""
        if not isinstance(self.assessor, Assessor):
            # start_controller did not resolve a valid Assessor, so there is nothing to finalize
            return
        try:
            self.assessor.finalize(fl_ctx)
        except Exception as e:
            self.log_exception(fl_ctx, f"error finalizing assessor: {secure_format_exception(e)}")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Report round progress to the stats collector on ``EVENT_TYPE_GET_STATS``.

        Raises:
            TypeError: if the stats collector in the context is not a GroupInfoCollector.
        """
        super().handle_event(event_type, fl_ctx)
        if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
            collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None)
            if collector:
                if not isinstance(collector, GroupInfoCollector):
                    raise TypeError(f"collector must be GroupInfoCollector but got {type(collector)}")

                collector.add_info(
                    group_name=self._name,
                    info={"current_round": self._current_round, "num_rounds": self._num_rounds},
                )

    def _prepare_train_task_data(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
        """Fire BEFORE_TRAIN_TASK with the task data just before it is sent to a client."""
        self.fire_event_with_data(
            AppEventType.BEFORE_TRAIN_TASK, fl_ctx, AppConstants.TRAIN_SHAREABLE, client_task.task.data
        )

    def _process_train_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
        """Handle a task submission from a child client.

        A non-Shareable result or a non-OK return code triggers a system panic. If the
        result carries update data, it is passed to the assessor via ``_accept_update``.
        """
        result = client_task.result
        # clear the reference on the client task (presumably so the result can be freed sooner)
        client_task.result = None
        client_name = client_task.client.name

        # explicit check rather than `assert`, which is stripped when Python runs with -O
        if not isinstance(result, Shareable):
            self.system_panic(
                f"Result from {client_name} must be Shareable but got {type(result)}. "
                f"{self.__class__.__name__} exiting at round {self._current_round}.",
                fl_ctx=fl_ctx,
            )
            return

        rc = result.get_return_code()

        # Raise errors if bad peer context or execution exception.
        if rc and rc != ReturnCode.OK:
            self.system_panic(
                f"Result from {client_name} is bad, error code: {rc}. "
                f"{self.__class__.__name__} exiting at round {self._current_round}.",
                fl_ctx=fl_ctx,
            )
            return

        has_update_data = result.get_header(EdgeTaskHeaderKey.HAS_UPDATE_DATA, False)
        if has_update_data:
            accepted = self._accept_update(result, fl_ctx)
            self.log_debug(fl_ctx, f"processed update from task submission: {accepted=}")

    def process_result_of_unknown_task(
        self, client: Client, task_name, client_task_id, result: Shareable, fl_ctx: FLContext
    ) -> None:
        """Warn about results for tasks this controller does not know; silent once the workflow is done."""
        if not self._wf_done:
            self.log_warning(fl_ctx, f"Ignoring late result from {client.name} for task '{task_name}' {client_task_id}")

    def _process_update_report(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        """Aux-message handler (registered in start_controller) for update reports from children."""
        accepted, reply = process_update_from_child(
            processor=self,
            update=request,
            update_f=self._accept_update,
            current_task_seq=self._current_task_seq,
            fl_ctx=fl_ctx,
        )
        self.log_debug(fl_ctx, f"processed update from report: {accepted=}")
        return reply

    def _check_abort_signal(self, fl_ctx, abort_signal: Signal):
        """Return True (and log) if the abort signal has been triggered."""
        if abort_signal.triggered:
            self.log_info(fl_ctx, f"Abort signal received. Exiting at round {self._current_round}.")
            return True
        return False

    def _accept_update(self, update: Shareable, fl_ctx: FLContext):
        """Pass a child's update to the assessor and return its ``process_child_update`` result."""
        return self.assessor.process_child_update(update, fl_ctx)