# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.edge.constants import EdgeTaskHeaderKey
from nvflare.security.logging import secure_format_exception
# Prefix prepended to every topic built by message_topic_for_task_update / message_topic_for_task_end.
TOPIC_PREFIX = "SAGE"


def message_topic_for_task_update(task_name: str) -> str:
    """Build the update message topic for a task.

    Args:
        task_name: name of the task.

    Returns:
        The string ``"<TOPIC_PREFIX>__<task_name>_update"``.
    """
    topic = f"{TOPIC_PREFIX}__{task_name}_update"
    return topic


def message_topic_for_task_end(task_name: str) -> str:
    """Build the end message topic for a task.

    Args:
        task_name: name of the task.

    Returns:
        The string ``"<TOPIC_PREFIX>__<task_name>_end"``.
    """
    topic = f"{TOPIC_PREFIX}__{task_name}_end"
    return topic
def _make_update_reply(rc: str, seq: int, data: Shareable = None) -> Shareable:
    """Stamp a return code and a task sequence number onto a reply.

    Args:
        rc: return code to set on the reply.
        seq: task sequence number, stored under the ``EdgeTaskHeaderKey.TASK_SEQ`` header.
        data: object to use as the reply. When it is falsy (e.g. None), a new
            Shareable is created and used instead.

    Returns:
        The reply: ``data`` itself when it is truthy, otherwise the new Shareable.
    """
    reply = data if data else Shareable()
    reply.set_return_code(rc)
    reply.set_header(EdgeTaskHeaderKey.TASK_SEQ, seq)
    return reply
def process_update_from_child(
    processor: FLComponent,
    update: Shareable,
    current_task_seq: int,
    fl_ctx: FLContext,
    update_f,
    **kwargs,
) -> Tuple[bool, Shareable]:
    """Process an aggregation report sent from a child client.

    The report is handed to ``update_f`` only when its TASK_SEQ header matches the
    current task and it is flagged as carrying update data. Every reply returned
    here has its return code set and carries ``EdgeTaskHeaderKey.TASK_SEQ`` set to
    ``current_task_seq``.

    Args:
        processor: the component that received the update report from the child; used for logging.
        update: the report request.
        current_task_seq: sequence number of the current task; 0 means there is no current task.
        fl_ctx: FLContext object; its peer context identifies the reporting child.
        update_f: the function called as ``update_f(update, fl_ctx, **kwargs)`` to process the
            report. It must return a tuple of (accepted, reply_data); reply_data becomes the
            reply body and may be None.
        **kwargs: extra keyword args to be passed to update_f.

    Returns:
        A tuple of (whether the report is accepted, reply to be sent back to the reporter).
        The reply's return code is:

        - ReturnCode.BAD_REQUEST_DATA if the report's TASK_SEQ header is missing (or falsy);
        - ReturnCode.TASK_ABORTED if the report's TASK_SEQ does not match current_task_seq;
        - ReturnCode.OK otherwise, even when there is no update data, update_f rejects
          the report, or update_f raises.
    """
    peer_ctx = fl_ctx.get_peer_context()
    assert isinstance(peer_ctx, FLContext)
    child_name = peer_ctx.get_identity_name()

    task_seq = update.get_header(EdgeTaskHeaderKey.TASK_SEQ)
    if not task_seq:
        # A seq of 0 is rejected along with a missing header: 0 denotes "no current task".
        processor.log_error(fl_ctx, f"missing {EdgeTaskHeaderKey.TASK_SEQ} from update header")
        return False, _make_update_reply(ReturnCode.BAD_REQUEST_DATA, current_task_seq)

    if task_seq != current_task_seq:
        if current_task_seq == 0:
            # this means no current task
            processor.log_warning(
                fl_ctx, f"dropped update from {child_name}: got task seq {task_seq} but no current task"
            )
        else:
            processor.log_warning(
                fl_ctx, f"dropped update from {child_name}: expect task seq {current_task_seq} but got {task_seq}"
            )
        # Use _make_update_reply (not make_reply) so this reply also carries the TASK_SEQ header.
        return False, _make_update_reply(ReturnCode.TASK_ABORTED, current_task_seq)

    has_update_data = update.get_header(EdgeTaskHeaderKey.HAS_UPDATE_DATA)
    if has_update_data is None:
        processor.log_info(fl_ctx, f"request does not have header {EdgeTaskHeaderKey.HAS_UPDATE_DATA}")
    processor.log_debug(fl_ctx, f"result has update data: {has_update_data=}")
    if not has_update_data:
        # Nothing to process: acknowledge with OK, but report the update as not accepted.
        return False, _make_update_reply(ReturnCode.OK, current_task_seq)

    reply_data = None
    try:
        accepted, reply_data = update_f(update, fl_ctx, **kwargs)
    except Exception as ex:
        # update_f may be a functools.partial or a callable object without __name__;
        # the log line must not raise and mask the original exception.
        func_name = getattr(update_f, "__name__", repr(update_f))
        processor.log_exception(
            fl_ctx, f"exception accepting update result from {func_name}: {secure_format_exception(ex)}"
        )
        accepted = False
    return accepted, _make_update_reply(ReturnCode.OK, current_task_seq, reply_data)