# Source code for nvflare.edge.executors.simple_edge_executor

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, Shareable
from nvflare.edge.constants import EdgeApiStatus, MsgKey
from nvflare.edge.executors.ete import EdgeTaskExecutor
from nvflare.edge.executors.hug import TaskInfo
from nvflare.edge.web.models.result_report import ResultReport
from nvflare.edge.web.models.result_response import ResultResponse
from nvflare.edge.web.models.selection_request import SelectionRequest
from nvflare.edge.web.models.selection_response import SelectionResponse
from nvflare.edge.web.models.task_request import TaskRequest
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.security.logging import secure_format_exception


class SimpleEdgeExecutor(EdgeTaskExecutor):
    """A very simple edge executor that only does aggregation.

    While a task is active, the executor remembers which task id was last handed
    to each device so the same device is not given the same task twice.
    """

    def __init__(self, updater_id, update_timeout=60):
        """Initialize the executor.

        Args:
            updater_id: forwarded unchanged to ``EdgeTaskExecutor.__init__``.
            update_timeout: forwarded unchanged to ``EdgeTaskExecutor.__init__``.
        """
        EdgeTaskExecutor.__init__(self, updater_id, update_timeout)
        # Maps device_id -> id of the task last handed to that device.
        # None until task_started() runs and again after task_ended().
        self.devices = None

    def convert_task(self, task_data: Shareable, current_task: TaskInfo, fl_ctx: FLContext) -> dict:
        """Convert task_data to a plain dict.

        Args:
            task_data: task payload; only its "weights" entry is read.
            current_task: the task being served; its id is included in the output.
            fl_ctx: FL context, used for logging.

        Returns:
            A dict holding the "weights" entry (None when absent) and the task id
            under ``MsgKey.TASK_ID``.
        """
        self.log_debug(fl_ctx, f"Converting task for task: {current_task.id}")
        return {"weights": task_data.get("weights", None), MsgKey.TASK_ID: current_task.id}

    def convert_result(self, result: dict, current_task: TaskInfo, fl_ctx: FLContext) -> Shareable:
        """Convert result from device to shareable.

        Args:
            result: result dict reported by the device.
            current_task: the current task; its id is written to the shareable header.
            fl_ctx: FL context, used for logging.

        Returns:
            A Shareable built from ``result`` with ``ReservedHeaderKey.TASK_ID`` set
            to the current task's id.
        """
        self.log_debug(fl_ctx, f"Converting result for task: {current_task.id}")
        shareable = Shareable(result)
        shareable.set_header(ReservedHeaderKey.TASK_ID, current_task.id)
        return shareable

    def process_edge_task_request(
        self, request: TaskRequest, current_task: TaskInfo, fl_ctx: FLContext
    ) -> TaskResponse:
        """Handle task request from device.

        Args:
            request: the device's task request; supplies the device id.
            current_task: the task currently being served.
            fl_ctx: FL context; supplies the job id.

        Returns:
            A RETRY response when the device table is not initialized or when this
            device was already handed the current task; otherwise an OK response
            carrying the converted task data.
        """
        device_id = request.get_device_id()
        job_id = fl_ctx.get_job_id()
        task_id = current_task.id

        # Work on a local reference: self.devices is None before task_started() and
        # is reset to None by task_ended(), so reading the attribute twice could fail.
        devices = self.devices
        if devices is None:
            # Without a device table there is nothing to record against, so ask the
            # device to come back later instead of raising AttributeError.
            # NOTE(review): 30 mirrors the value used by the RETRY response below.
            msg = f"Task {task_id} is not accepting devices at this time"
            return TaskResponse(EdgeApiStatus.RETRY, job_id, 30, message=msg)

        # This device already processed current task
        if devices.get(device_id) == task_id:
            msg = f"Task {task_id} is already processed by this device"
            return TaskResponse(EdgeApiStatus.RETRY, job_id, 30, message=msg)

        task_data = self.convert_task(current_task.task, current_task, fl_ctx)
        # Record the hand-out only after conversion succeeded.
        devices[device_id] = task_id
        return TaskResponse(EdgeApiStatus.OK, job_id, 0, task_id, current_task.name, task_data)

    def process_edge_result_report(
        self, report: ResultReport, current_task: TaskInfo, fl_ctx: FLContext
    ) -> ResultResponse:
        """Handle result report from device.

        The report task_id may be different from current task_id. Let HAM deal with it.
        Note that the converted result's header carries the current task's id, while
        ``accept_update`` is called with the task id from the report.

        Args:
            report: the device's result report (result, task_id, task_name).
            current_task: the task currently being served.
            fl_ctx: FL context, used for logging.

        Returns:
            An OK response when the update was accepted; an ERROR response with the
            formatted exception message when conversion or acceptance raised.
        """
        try:
            result = self.convert_result(report.result, current_task, fl_ctx)
            self.accept_update(task_id=report.task_id, update=result, fl_ctx=fl_ctx)
            return ResultResponse(EdgeApiStatus.OK, task_id=report.task_id, task_name=report.task_name)
        except Exception as ex:
            # Any failure is reported back to the device rather than propagated.
            msg = f"Error accepting contribution: {secure_format_exception(ex)}"
            self.log_error(fl_ctx, msg)
            return ResultResponse(EdgeApiStatus.ERROR, task_id=report.task_id, task_name=report.task_name, message=msg)

    def task_started(self, task: TaskInfo, fl_ctx: FLContext):
        """Log the start of a task and begin with an empty device table."""
        self.log_info(fl_ctx, f"Got task_started: {task.id} (seq {task.seq})")
        self.devices = {}

    def task_ended(self, task: TaskInfo, fl_ctx: FLContext):
        """Log the end of a task and drop the device table."""
        self.log_info(fl_ctx, f"Got task_ended: {task.id} (seq {task.seq})")
        self.devices = None

    def process_edge_selection_request(
        self, request: SelectionRequest, current_task: TaskInfo, fl_ctx: FLContext
    ) -> Optional[SelectionResponse]:
        """Handle a selection request from a device; this executor always returns None."""
        return None