Source code for nvflare.app_common.executors.ipc_exchanger

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from typing import Union

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.client import defs
from nvflare.fuel.f3.cellnet.cell import Cell, Message, MessageHeaderKey
from nvflare.fuel.f3.cellnet.cell import ReturnCode as CellReturnCode
from nvflare.fuel.f3.cellnet.cell import new_message
from nvflare.fuel.f3.cellnet.utils import make_reply as make_cell_reply


class _TaskContext:
    """Mutable record of one task's progress while it is exchanged with the agent.

    The send outcome and the result fields start out as None and the
    ``result_waiter`` event starts out cleared; code outside this class fills
    them in as the task progresses.
    """

    def __init__(self, task_name: str, task_id: str, fl_ctx: FLContext):
        """Remember the task identity and FL context; reset all outcome fields.

        Args:
            task_name: name of the task
            task_id: id of the task
            fl_ctx: FL context associated with the task
        """
        # identity of the task
        self.task_id = task_id
        self.task_name = task_name
        self.fl_ctx = fl_ctx

        # nothing has been sent or received yet
        self.send_rc = self.result_rc = self.result_error = self.result = self.result_received_time = None

        # set once a result has been recorded for this task
        self.result_waiter = threading.Event()

    def __str__(self):
        """Return the quoted "name id" label of the task."""
        label = f"'{self.task_name} {self.task_id}'"
        return label


class IPCExchanger(Executor):
    """Executor that hands each task to a Flare Agent over the cell network and returns its result.

    On START_RUN the exchanger registers a callback for submitted results, works out
    the agent's FQCN from the job meta and starts a daemon thread that first pings the
    agent until it answers and then keeps sending heartbeats. ``execute`` blocks until
    the agent is ready, sends the task, and waits for the result delivered through
    ``_receive_result``. Only one task is processed at a time.
    """

    def __init__(
        self,
        send_task_timeout=5.0,
        agent_ready_timeout=60.0,
        agent_heartbeat_timeout=600.0,
        agent_is_child=False,
    ):
        """Constructor of IPCExchanger

        Args:
            send_task_timeout: when sending task to Agent, how long to wait for response
            agent_ready_timeout: how long to wait for the agent to be connected
            agent_heartbeat_timeout: max time allowed to miss heartbeats from the agent
            agent_is_child: whether the agent will be a child cell.
        """
        Executor.__init__(self)
        self.flare_agent_fqcn = None
        self.agent_ready_waiter = threading.Event()
        self.agent_ready_timeout = agent_ready_timeout
        self.agent_heartbeat_timeout = agent_heartbeat_timeout
        self.send_task_timeout = send_task_timeout
        self.agent_is_child = agent_is_child
        self.internal_listener_url = None
        self.last_agent_ack_time = time.time()
        self.engine = None
        self.cell = None
        self.is_done = False
        self.task_ctx = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Set up the agent connection on START_RUN and tear it down on END_RUN.

        Args:
            event_type: type of the event being fired
            fl_ctx: FL context of the event
        """
        if event_type == EventType.START_RUN:
            self.engine = fl_ctx.get_engine()
            self.cell = self.engine.get_cell()
            self.cell.register_request_cb(
                channel=defs.CHANNEL,
                topic=defs.TOPIC_SUBMIT_RESULT,
                cb=self._receive_result,
            )

            # get meta
            meta = fl_ctx.get_prop(FLContextKey.JOB_META)
            if not isinstance(meta, dict):
                # explicit check instead of assert: asserts are stripped when running with -O
                self.system_panic(reason=f"job meta must be dict but got {type(meta)}", fl_ctx=fl_ctx)
                return

            agent_id = meta.get(defs.JOB_META_KEY_AGENT_ID)
            if not agent_id:
                self.system_panic(reason=f"missing {defs.JOB_META_KEY_AGENT_ID} from job meta", fl_ctx=fl_ctx)
                return

            client_name = fl_ctx.get_identity_name()
            self.flare_agent_fqcn = defs.agent_site_fqcn(client_name, agent_id)

            if self.agent_is_child:
                # a child agent gets a job-specific FQCN and connects through an internal listener
                job_id = fl_ctx.get_job_id()
                self.flare_agent_fqcn = defs.agent_site_fqcn(client_name, agent_id, job_id)
                self.cell.make_internal_listener()
                self.internal_listener_url = self.cell.get_internal_listener_url()
                self.logger.info(f"URL for Agent: {self.internal_listener_url}")

            self.log_info(fl_ctx, f"Flare Agent FQCN: {self.flare_agent_fqcn}")
            t = threading.Thread(target=self._maintain, daemon=True)
            t.start()
        elif event_type == EventType.END_RUN:
            self.is_done = True
            self._say_goodbye()

    def _say_goodbye(self):
        """Send a best-effort BYE request to the agent and log a non-OK return code."""
        if self.cell is None or not self.flare_agent_fqcn:
            # START_RUN never got far enough to know the cell or the agent - nobody to notify
            return

        # say goodbye to agent
        self.logger.info(f"job done - say goodbye to {self.flare_agent_fqcn}")
        reply = self.cell.send_request(
            channel=defs.CHANNEL,
            topic=defs.TOPIC_BYE,
            target=self.flare_agent_fqcn,
            request=new_message(),
            optional=True,
            timeout=2.0,
        )
        if reply:
            rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
            if rc != CellReturnCode.OK:
                self.logger.warning(f"return code from agent {self.flare_agent_fqcn} for bye: {rc}")

    def _maintain(self):
        """Background loop: connect to the agent, then monitor it with heartbeats.

        Phase 1 pings the agent every 2 seconds until it replies OK, and panics if
        ``agent_ready_timeout`` elapses first. Phase 2 sends a heartbeat every 10 seconds
        and panics if no OK reply has been seen for ``agent_heartbeat_timeout`` seconds.
        Both phases stop as soon as ``is_done`` is set.
        """
        # try to connect the flare agent
        self.logger.info(f"waiting for flare agent {self.flare_agent_fqcn} ...")
        assert isinstance(self.cell, Cell)
        start_time = time.time()
        while not self.is_done:
            self.logger.info(f"ping {self.flare_agent_fqcn}")
            reply = self.cell.send_request(
                channel=defs.CHANNEL,
                topic=defs.TOPIC_HELLO,
                target=self.flare_agent_fqcn,
                request=new_message(),
                timeout=2.0,
                optional=True,
            )
            rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
            if rc == CellReturnCode.OK:
                self.logger.info(f"connected to agent {self.flare_agent_fqcn}")
                # the OK reply is the first sign of life: start the heartbeat clock here,
                # not at construction time, so time spent connecting is not counted as silence
                self.last_agent_ack_time = time.time()
                self.agent_ready_waiter.set()
                break

            self.logger.info(f"get reply: {reply.headers}")
            if time.time() - start_time > self.agent_ready_timeout:
                # cannot connect to agent!
                with self.engine.new_context() as fl_ctx:
                    self.system_panic(
                        reason=f"cannot connect to agent {self.flare_agent_fqcn} after {self.agent_ready_timeout} secs",
                        fl_ctx=fl_ctx,
                    )
                self.is_done = True
                return
            time.sleep(2.0)

        # agent is now connected - heartbeats
        last_hb_time = 0
        hb_interval = 10.0
        while True:
            if self.is_done:
                return

            if time.time() - last_hb_time > hb_interval:
                reply = self.cell.send_request(
                    channel=defs.CHANNEL,
                    topic=defs.TOPIC_HEARTBEAT,
                    target=self.flare_agent_fqcn,
                    request=new_message(),
                    timeout=1.5,
                )
                last_hb_time = time.time()
                rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
                if rc == CellReturnCode.OK:
                    self.last_agent_ack_time = time.time()

            if time.time() - self.last_agent_ack_time > self.agent_heartbeat_timeout:
                with self.engine.new_context() as fl_ctx:
                    self.system_panic(
                        reason=f"agent dead: no heartbeat for {self.agent_heartbeat_timeout} secs",
                        fl_ctx=fl_ctx,
                    )
                self.is_done = True
                return

            # sleep only small amount of time, so we can check other conditions frequently
            time.sleep(0.2)

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Wait for the agent, then run the task through it and return the result.

        Args:
            task_name: name of the task
            shareable: task data
            fl_ctx: FL context of the task
            abort_signal: signal that is triggered when the task must be aborted

        Returns:
            The result Shareable; a TASK_ABORTED reply if the run ended or the abort signal
            fired while waiting for the agent; a BAD_REQUEST_DATA reply if another task is
            still in progress.
        """
        # wait for flare agent
        while True:
            if self.is_done or abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)

            # wait for agent to be ready
            # we only wait for short time, so we could check other conditions (is_done, abort_signal)
            if self.agent_ready_waiter.wait(0.5):
                break

        task_id = shareable.get_header(key=FLContextKey.TASK_ID)
        current_task = self.task_ctx
        if current_task:
            # still working on previous task!
            self.log_error(fl_ctx, f"got new task {task_name=} {task_id=} while still working on {current_task}")
            return make_reply(ReturnCode.BAD_REQUEST_DATA)

        self.task_ctx = _TaskContext(task_name, task_id, fl_ctx)
        try:
            return self._do_execute(task_name, shareable, fl_ctx, abort_signal)
        finally:
            # always release the slot, even if _do_execute raised; otherwise every
            # later task would be rejected as "still working on previous task"
            self.task_ctx = None

    def _send_task(self, task_ctx: _TaskContext, msg, abort_signal):
        """Keep sending the task message to the agent until it is accepted, rejected or aborted.

        The outcome is recorded in ``task_ctx.send_rc``: OK when the agent accepted the task
        (or its result has already arrived), BAD_REQUEST_DATA when the agent rejected it,
        TASK_ABORTED when the run ended or the abort signal fired.

        Args:
            task_ctx: context of the task being sent
            msg: the cell message that carries the task
            abort_signal: signal that is triggered when the task must be aborted
        """
        # keep sending until done
        fl_ctx = task_ctx.fl_ctx
        task_name = task_ctx.task_name
        task_id = task_ctx.task_id
        task_ctx.send_rc = ReturnCode.OK
        while True:
            if self.is_done or abort_signal.triggered:
                self.log_info(fl_ctx, "task aborted - ask agent to abort the task")

                # it's possible that the agent may have already received the task
                # we ask it to abort the task.
                self._ask_agent_to_abort_task(task_name, task_id)
                task_ctx.send_rc = ReturnCode.TASK_ABORTED
                return

            if task_ctx.result_received_time:
                # the result has been received
                # this could happen only when we thought the previous send didn't succeed, but it actually did!
                return

            self.log_info(fl_ctx, f"try to send task to {self.flare_agent_fqcn}")
            start = time.time()
            reply = self.cell.send_request(
                channel=defs.CHANNEL,
                topic=defs.TOPIC_GET_TASK,
                request=msg,
                target=self.flare_agent_fqcn,
                timeout=self.send_task_timeout,
            )

            rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
            if rc == CellReturnCode.OK:
                self.log_info(fl_ctx, f"Sent task to {self.flare_agent_fqcn} in {time.time() - start} secs")
                return
            elif rc == CellReturnCode.INVALID_REQUEST:
                self.log_error(fl_ctx, f"Task rejected by {self.flare_agent_fqcn}: {rc}")
                task_ctx.send_rc = ReturnCode.BAD_REQUEST_DATA
                return
            else:
                self.log_error(fl_ctx, f"Failed to send task to {self.flare_agent_fqcn}: {rc}. Will keep trying.")
                time.sleep(2.0)

    def _do_execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Convert the task to a cell message, send it to the agent and wait for the result.

        Args:
            task_name: name of the task
            shareable: task data; must hold a DXO whose data kind is WEIGHTS
            fl_ctx: FL context of the task
            abort_signal: signal that is triggered when the task must be aborted

        Returns:
            The agent's result converted to a Shareable, or an error reply: BAD_TASK_DATA for
            unusable task data, the send return code if sending failed, TASK_ABORTED if aborted
            while waiting, or the result's own return code if it is not OK.
        """
        try:
            dxo = from_shareable(shareable)
        except Exception:
            # only real errors are handled here; KeyboardInterrupt/SystemExit must propagate
            self.log_error(fl_ctx, "Unable to extract dxo from shareable.")
            return make_reply(ReturnCode.BAD_TASK_DATA)

        # Ensure data kind is weights.
        if dxo.data_kind != DataKind.WEIGHTS:
            self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.")
            return make_reply(ReturnCode.BAD_TASK_DATA)

        # send to flare agent
        task_ctx = self.task_ctx
        task_id = task_ctx.task_id

        data = dxo.data
        if not data:
            data = {}

        meta = dxo.meta
        if not meta:
            meta = {}

        current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None)
        total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None)

        meta[defs.MetaKey.DATA_KIND] = dxo.data_kind
        if current_round is not None:
            meta[defs.MetaKey.CURRENT_ROUND] = current_round
        if total_rounds is not None:
            meta[defs.MetaKey.TOTAL_ROUND] = total_rounds

        msg = new_message(
            headers={
                defs.MsgHeader.TASK_ID: task_id,
                defs.MsgHeader.TASK_NAME: task_name,
            },
            payload={defs.PayloadKey.DATA: data, defs.PayloadKey.META: meta},
        )

        # keep sending until done
        self._send_task(task_ctx, msg, abort_signal)
        if task_ctx.send_rc != ReturnCode.OK:
            # send_task failed
            return make_reply(task_ctx.send_rc)

        # wait for result
        self.log_info(fl_ctx, f"Waiting for result from {self.flare_agent_fqcn}")
        waiter_timeout = 0.5
        while True:
            if task_ctx.result_waiter.wait(timeout=waiter_timeout):
                # result available
                break
            else:
                # timed out - check other conditions
                if self.is_done or abort_signal.triggered:
                    self.log_info(fl_ctx, "task is aborted")

                    # notify the agent
                    self._ask_agent_to_abort_task(task_name, task_id)
                    self.task_ctx = None
                    return make_reply(ReturnCode.TASK_ABORTED)

        # convert the result
        if task_ctx.result_rc != defs.RC.OK:
            return make_reply(task_ctx.result_rc)

        result = task_ctx.result
        meta = result.get(defs.PayloadKey.META)
        data_kind = meta.get(defs.MetaKey.DATA_KIND, DataKind.WEIGHTS)
        dxo = DXO(
            data_kind=data_kind,
            data=result.get(defs.PayloadKey.DATA),
            meta=meta,
        )
        return dxo.to_shareable()

    def _ask_agent_to_abort_task(self, task_name, task_id):
        """Fire an ABORT message for the given task at the agent without waiting for a reply.

        Args:
            task_name: name of the task to abort
            task_id: id of the task to abort
        """
        msg = new_message(
            headers={
                defs.MsgHeader.TASK_ID: task_id,
                defs.MsgHeader.TASK_NAME: task_name,
            }
        )

        self.cell.fire_and_forget(
            channel=defs.CHANNEL,
            topic=defs.TOPIC_ABORT,
            message=msg,
            targets=[self.flare_agent_fqcn],
            optional=True,
        )

    @staticmethod
    def _finish_result(task_ctx: _TaskContext, result_rc="", result=None, result_is_valid=True):
        """Record the result on the task context, wake up the waiter, and build the reply to the agent.

        Args:
            task_ctx: context of the task the result belongs to
            result_rc: return code to store as the task's result code
            result: result payload to store
            result_is_valid: whether the submitted result was well-formed

        Returns:
            A cell reply with OK if the result is valid, INVALID_REQUEST otherwise.
        """
        task_ctx.result_rc = result_rc
        task_ctx.result = result
        task_ctx.result_received_time = time.time()
        task_ctx.result_waiter.set()
        if result_is_valid:
            return make_cell_reply(CellReturnCode.OK)
        else:
            return make_cell_reply(CellReturnCode.INVALID_REQUEST)

    def _receive_result(self, request: Message) -> Union[None, Message]:
        """Callback for results submitted by the agent.

        Rejects results that arrive when no task is pending or that carry a different task id,
        acknowledges duplicates, and finishes the pending task with EXECUTION_EXCEPTION when
        the payload is not a dict or lacks data or meta.

        Args:
            request: the message submitted by the agent

        Returns:
            The cell reply to send back to the agent.
        """
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_ctx = self.task_ctx
        if not task_ctx:
            self.logger.error(f"received result from {sender} for task {task_id} while not waiting for result!")
            return make_cell_reply(CellReturnCode.INVALID_REQUEST)

        fl_ctx = task_ctx.fl_ctx
        if task_id != task_ctx.task_id:
            self.log_error(fl_ctx, f"received task id {task_id} != expected {task_ctx.task_id}")
            return make_cell_reply(CellReturnCode.INVALID_REQUEST)

        if task_ctx.result_received_time:
            # already received - this is a dup
            self.log_info(fl_ctx, f"received duplicate result from {sender}")
            return make_cell_reply(CellReturnCode.OK)

        payload = request.payload
        if not isinstance(payload, dict):
            self.log_error(fl_ctx, f"bad result from {sender}: expect dict but got {type(payload)}")
            return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION)

        data = payload.get(defs.PayloadKey.DATA)
        if data is None:
            self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.DATA}")
            return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION)

        meta = payload.get(defs.PayloadKey.META)
        if meta is None:
            self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.META}")
            return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION)

        self.log_info(fl_ctx, f"received result from {sender}")
        return self._finish_result(
            task_ctx,
            result_is_valid=True,
            result_rc=request.get_header(defs.MsgHeader.RC, defs.RC.OK),
            result=payload,
        )