# Source code for nvflare.app_common.executors.task_exchanger

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Optional

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.pipe import Message, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler, Topic
from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_number, check_str
from nvflare.security.logging import secure_format_exception


class TaskExchanger(Executor):
    """Executor that relays each task to a peer over a Pipe and waits for the peer's result.

    The pipe is looked up by component id at START_RUN and wrapped in a PipeHandler,
    which handles reading, heartbeats and resending. The handler is stopped at
    ABOUT_TO_END_RUN.
    """

    def __init__(
        self,
        pipe_id: str,
        read_interval: float = 0.5,
        heartbeat_interval: float = 5.0,
        heartbeat_timeout: Optional[float] = 60.0,
        resend_interval: float = 2.0,
        max_resends: Optional[int] = None,
        peer_read_timeout: Optional[float] = 60.0,
        task_wait_time: Optional[float] = None,
        result_poll_interval: float = 0.5,
        pipe_channel_name=PipeChannelName.TASK,
    ):
        """Constructor of TaskExchanger.

        Args:
            pipe_id (str): component id of pipe.
            read_interval (float): how often to read from pipe. Must be positive.
            heartbeat_interval (float): how often to send heartbeat to peer. Must be positive.
            heartbeat_timeout (float, optional): how long to wait for a heartbeat from the peer before
                treating the peer as dead. If given, it must be a positive number (0 is rejected).
                None skips this constructor's validation and is forwarded to PipeHandler unchanged.
            resend_interval (float): how often to resend a message if failing to send. Must be a
                positive number. Note that if the pipe does not support resending, then no resend.
            max_resends (int, optional): max number of resend. None means no limit. Defaults to None.
            peer_read_timeout (float, optional): time to wait for peer to accept sent message.
                If None (or 0), a message that the peer has not read is not treated as a failure.
            task_wait_time (float, optional): how long to wait for a task to complete.
                None means waiting forever. Defaults to None.
            result_poll_interval (float): how often to poll task result. Defaults to 0.5.
            pipe_channel_name: the channel name for sending task requests. Defaults to "task".
        """
        Executor.__init__(self)
        check_str("pipe_id", pipe_id)
        check_positive_number("read_interval", read_interval)
        check_positive_number("heartbeat_interval", heartbeat_interval)
        if heartbeat_timeout is not None:
            check_positive_number("heartbeat_timeout", heartbeat_timeout)
        check_positive_number("resend_interval", resend_interval)
        if max_resends is not None:
            check_non_negative_int("max_resends", max_resends)
        if peer_read_timeout is not None:
            check_positive_number("peer_read_timeout", peer_read_timeout)
        if task_wait_time is not None:
            check_positive_number("task_wait_time", task_wait_time)
        check_positive_number("result_poll_interval", result_poll_interval)
        check_str("pipe_channel_name", pipe_channel_name)

        self.pipe_id = pipe_id
        self.read_interval = read_interval
        self.heartbeat_interval = heartbeat_interval
        self.heartbeat_timeout = heartbeat_timeout
        self.resend_interval = resend_interval
        self.max_resends = max_resends
        self.peer_read_timeout = peer_read_timeout
        self.task_wait_time = task_wait_time
        self.result_poll_interval = result_poll_interval
        self.pipe_channel_name = pipe_channel_name
        # Both are created lazily in handle_event(START_RUN).
        self.pipe = None
        self.pipe_handler = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Set up the pipe handler at START_RUN and tear it down at ABOUT_TO_END_RUN.

        Args:
            event_type: the fired event.
            fl_ctx: the FL context of the event.
        """
        if event_type == EventType.START_RUN:
            engine = fl_ctx.get_engine()
            self.pipe = engine.get_component(self.pipe_id)
            if not isinstance(self.pipe, Pipe):
                # pipe_handler stays None in this case; execute() guards against that.
                self.system_panic(f"component of {self.pipe_id} must be Pipe but got {type(self.pipe)}", fl_ctx)
                return
            self.pipe_handler = PipeHandler(
                pipe=self.pipe,
                read_interval=self.read_interval,
                heartbeat_interval=self.heartbeat_interval,
                heartbeat_timeout=self.heartbeat_timeout,
                resend_interval=self.resend_interval,
                max_resends=self.max_resends,
            )
            self.pipe_handler.set_status_cb(self._pipe_status_cb)
            self.pipe.open(self.pipe_channel_name)
            self.pipe_handler.start()
        elif event_type == EventType.ABOUT_TO_END_RUN:
            self.log_info(fl_ctx, "Stopping pipe handler")
            if self.pipe_handler:
                self.pipe_handler.notify_end("end_of_job")
                self.pipe_handler.stop()

    def _pipe_status_cb(self, msg: Message):
        """Status callback registered with PipeHandler: log the new status and stop the handler.

        Stopping the handler causes any in-progress execute() loop to observe
        ``asked_to_stop`` and abort the current task.
        """
        self.logger.info(f"pipe status changed to {msg.topic}")
        self.pipe_handler.stop()

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Send the task to the peer and block until its result arrives, the task times out, or it is aborted.

        The TaskExchanger always sends the Shareable to the peer, and expects to receive a Shareable
        object from the peer. The peer can convert the Shareable object to whatever format that is best
        for its applications (e.g. DXO or FLModel object). Similarly, when submitting result, the peer
        must convert its result object to a Shareable object before sending it back to the TaskExchanger.

        This "late-binding" (binding of the Shareable object to an application-friendly object) strategy
        makes the TaskExchanger generic and can be reused for any applications (e.g. Shareable based,
        DXO based, or any custom data based).

        Returns:
            The peer's result Shareable, or an error reply (BAD_TASK_DATA, EXECUTION_EXCEPTION,
            TASK_ABORTED) when the task cannot be completed.
        """
        if not self.check_input_shareable(task_name, shareable, fl_ctx):
            self.log_error(fl_ctx, "bad input task shareable")
            return make_reply(ReturnCode.BAD_TASK_DATA)

        if self.pipe_handler is None:
            # START_RUN did not complete (e.g. pipe_id did not resolve to a Pipe), so there is
            # no channel to the peer; fail the task instead of crashing with AttributeError.
            self.log_error(fl_ctx, "pipe handler is not available - cannot send task to peer")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id())
        shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name())
        task_id = shareable.get_header(key=FLContextKey.TASK_ID)

        # send to peer
        self.log_info(fl_ctx, f"sending task to peer {self.peer_read_timeout=}")
        req = Message.new_request(topic=task_name, data=shareable, msg_id=task_id)
        start_time = time.time()
        has_been_read = self.pipe_handler.send_to_peer(req, timeout=self.peer_read_timeout, abort_signal=abort_signal)
        # Only a configured read timeout makes an unread message a failure.
        if self.peer_read_timeout and not has_been_read:
            self.log_error(
                fl_ctx,
                f"peer does not accept task '{task_name}' in {time.time()-start_time} secs - aborting task!",
            )
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        self.log_info(fl_ctx, f"task {task_name} sent to peer in {time.time()-start_time} secs")

        # wait for result
        self.log_debug(fl_ctx, "Waiting for result from peer")
        start = time.time()
        while True:
            if abort_signal.triggered:
                # notify peer that the task is aborted
                self.log_debug(fl_ctx, f"task '{task_name}' is aborted.")
                self.pipe_handler.notify_abort(task_id)
                self.pipe_handler.stop()
                return make_reply(ReturnCode.TASK_ABORTED)

            if self.pipe_handler.asked_to_stop:
                self.log_debug(fl_ctx, "task pipe stopped!")
                self.pipe_handler.notify_abort(task_id)
                abort_signal.trigger("task pipe stopped!")
                return make_reply(ReturnCode.TASK_ABORTED)

            reply: Optional[Message] = self.pipe_handler.get_next()
            if reply is None:
                if self.task_wait_time and time.time() - start > self.task_wait_time:
                    # timed out
                    self.log_error(fl_ctx, f"task '{task_name}' timeout after {self.task_wait_time} secs")
                    # also tell peer to abort the task
                    self.pipe_handler.notify_abort(task_id)
                    abort_signal.trigger(f"task '{task_name}' timeout after {self.task_wait_time} secs")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)
            elif reply.msg_type != Message.REPLY:
                self.log_warning(
                    fl_ctx, f"ignored reply: '{reply}' (wrong message type) while waiting for the result of {task_name}"
                )
            elif req.topic != reply.topic:
                # ignore wrong topic
                self.log_warning(
                    fl_ctx,
                    f"ignored reply: '{reply}' (reply topic does not match req: '{req}') while waiting for the result of {task_name}",
                )
            elif req.msg_id != reply.req_id:
                # a reply to some other (e.g. earlier) request
                self.log_warning(
                    fl_ctx,
                    f"ignored reply: '{reply}' (reply req_id does not match req msg_id: '{req}') while waiting for the result of {task_name}",
                )
            else:
                self.log_info(fl_ctx, f"got result '{reply}' for task '{task_name}'")
                try:
                    result = reply.data
                    if not isinstance(result, Shareable):
                        self.log_error(fl_ctx, f"bad task result from peer: expect Shareable but got {type(result)}")
                        return make_reply(ReturnCode.EXECUTION_EXCEPTION)

                    # carry the round number of the request over to the result
                    current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
                    if current_round is not None:
                        result.set_header(AppConstants.CURRENT_ROUND, current_round)

                    if not self.check_output_shareable(task_name, result, fl_ctx):
                        self.log_error(fl_ctx, "bad task result from peer")
                        return make_reply(ReturnCode.EXECUTION_EXCEPTION)

                    self.log_info(fl_ctx, f"received result of {task_name} from peer in {time.time()-start} secs")
                    return result
                except Exception as ex:
                    self.log_error(fl_ctx, f"Failed to convert result: {secure_format_exception(ex)}")
                    return make_reply(ReturnCode.EXECUTION_EXCEPTION)
            time.sleep(self.result_poll_interval)

    def check_input_shareable(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Checks input shareable before execute.

        The default implementation accepts everything; subclasses may override.

        Returns:
            True, if input shareable looks good; False, otherwise.
        """
        return True

    def check_output_shareable(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Checks output shareable after execute.

        The default implementation accepts everything; subclasses may override.

        Returns:
            True, if output shareable looks good; False, otherwise.
        """
        return True

    def ask_peer_to_end(self, fl_ctx: FLContext) -> bool:
        """Send an END request to the peer.

        Returns:
            False if a peer_read_timeout is configured and the peer did not read the
            message within it; True otherwise.
        """
        req = Message.new_request(topic=Topic.END, data="END")
        has_been_read = self.pipe_handler.send_to_peer(req, timeout=self.peer_read_timeout)
        if self.peer_read_timeout and not has_been_read:
            self.log_warning(
                fl_ctx,
                f"3rd party does not read END msg in {self.peer_read_timeout} secs!",
            )
            return False
        return True

    def peer_is_up_or_dead(self) -> bool:
        """Return whether the pipe handler's ``peer_is_up_or_dead`` event has been set."""
        return self.pipe_handler.peer_is_up_or_dead.is_set()

    def pause_pipe_handler(self):
        """Stops pipe_handler heartbeat."""
        self.pipe_handler.pause()

    def resume_pipe_handler(self):
        """Resumes pipe_handler heartbeat."""
        self.pipe_handler.resume()

    def get_pipe(self):
        """Gets pipe."""
        return self.pipe

    def get_pipe_channel_name(self):
        """Gets pipe_channel_name."""
        return self.pipe_channel_name