# Source code for nvflare.client.ipc_agent

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time
import traceback
from typing import Union

from nvflare.app_common.decomposers import common_decomposers
from nvflare.client import defs
from nvflare.fuel.f3.cellnet.cell import Cell, Message
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply, new_message
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.config_service import ConfigService

# File name of the root CA certificate. In secure mode IPCAgent looks this file up
# through ConfigService (whose config path is the agent's workspace_dir) and passes
# it to the Cell as the CA_CERT credential.
SSL_ROOT_CERT = "rootCA.pem"


class _TaskContext:

    NEW = 0
    FETCHED = 1
    PROCESSED = 2

    def __init__(self, sender: str, task_name: str, task_id: str, meta: dict, data):
        self.sender = sender
        self.task_name = task_name
        self.task_id = task_id
        self.meta = meta
        self.data = data
        self.status = _TaskContext.NEW
        self.last_send_result_time = None
        self.aborted = False
        self.already_received = False

    def __str__(self):
        return f"'{self.task_name} {self.task_id}'"


class IPCAgent(defs.FlareAgent):
    """Flare agent that exchanges tasks and results with the Flare Client Job cell (CJ).

    CJ pushes tasks to this agent's cell; the application pulls them with
    ``get_task`` and returns results with ``submit_result``. CJ liveness is
    tracked through hello / heartbeat / bye requests.
    """

    def __init__(
        self,
        root_url: str,
        flare_site_name: str,
        agent_id: str,
        workspace_dir: str,
        secure_mode=False,
        submit_result_timeout=30.0,
        flare_site_heartbeat_timeout=60.0,
        job_id=None,
        flare_site_url=None,
    ):
        """Constructor of Flare Agent.

        The agent is responsible for communicating with the Flare Client Job cell (CJ)
        to get task and to submit task result.

        Args:
            root_url: the URL to the server parent cell (SP)
            flare_site_name: the CJ's site name (client name)
            agent_id: the unique ID of the agent
            workspace_dir: directory that contains startup folder and comm_config.json
            secure_mode: whether the connection is in secure mode or not
            submit_result_timeout: when submitting task result, how long to wait for response from the CJ
            flare_site_heartbeat_timeout: max time allowed for missing heartbeats from CJ
            job_id: ID of the current Flare Job. Only needed for child-based communication with CJ
            flare_site_url: URL for connection to CJ. Only needed for child-based communication with CJ

        Raises:
            ValueError: if secure_mode is on and the root CA cert cannot be found.
        """
        ConfigService.initialize(section_files={}, config_path=[workspace_dir])
        self.logger = logging.getLogger(self.__class__.__name__)
        self.cell_name = defs.agent_site_fqcn(flare_site_name, agent_id, job_id)
        self.workspace_dir = workspace_dir
        self.secure_mode = secure_mode
        self.root_url = root_url
        self.submit_result_timeout = submit_result_timeout
        self.flare_site_heartbeat_timeout = flare_site_heartbeat_timeout
        self.job_id = job_id
        self.flare_site_url = flare_site_url

        # set by _handle_hello once the flare site has connected
        self.connect_waiter = threading.Event()

        # task_lock guards current_task / pending_task (and their state flags)
        self.current_task = None
        self.pending_task = None
        self.task_lock = threading.Lock()

        self.last_hb_time = time.time()
        self.is_done = False
        self.is_started = False
        self.is_stopped = False

        self.credentials = {}
        if secure_mode:
            root_cert_path = ConfigService.find_file(SSL_ROOT_CERT)
            if not root_cert_path:
                raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {workspace_dir}")
            self.credentials = {
                DriverParams.CA_CERT.value: root_cert_path,
            }

        self.cell = Cell(
            fqcn=self.cell_name,
            root_url=self.root_url,
            secure=self.secure_mode,
            credentials=self.credentials,
            create_internal_listener=False,
            parent_url=self.flare_site_url,
        )
        self.net_agent = NetAgent(self.cell)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_GET_TASK, cb=self._receive_task)
        self.logger.info(f"registered task CB for {defs.CHANNEL} {defs.TOPIC_GET_TASK}")
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HELLO, cb=self._handle_hello)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HEARTBEAT, cb=self._handle_heartbeat)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_BYE, cb=self._handle_bye)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_ABORT, cb=self._handle_abort_task)
        common_decomposers.register()

    def start(self):
        """Start the agent.

        This method must be called to enable CJ/Agent communication.

        Returns: None

        Raises:
            CallStateError: if the agent has already been stopped.
        """
        if self.is_started:
            self.logger.warning("the agent is already started")
            return

        if self.is_stopped:
            raise defs.CallStateError("cannot start the agent since it is already stopped")

        self.is_started = True
        self.logger.info(f"starting agent {self.cell_name} ...")
        self.cell.start()
        t = threading.Thread(target=self._maintain, daemon=True)
        t.start()

    def stop(self):
        """Stop the agent.

        After this is called, there will be no more communications between CJ and agent.

        Returns: None
        """
        if not self.is_started:
            self.logger.warning("cannot stop the agent since it is not started")
            return

        if self.is_stopped:
            self.logger.warning("agent is already stopped")
            return

        self.is_stopped = True
        self.cell.stop()
        self.net_agent.close()

    def _maintain(self):
        """Background loop started by start().

        First waits (up to flare_site_heartbeat_timeout) for the flare site's hello,
        then monitors heartbeats. Sets ``is_done`` when either times out. Exits
        quietly once the agent is done or stopped.
        """
        self.logger.info("waiting for flare site to connect ...")
        start = time.time()
        while not self.connect_waiter.wait(0.5):
            if self.is_done or self.is_stopped:
                return

            if time.time() - start > self.flare_site_heartbeat_timeout:
                self.logger.error(
                    f"closing agent {self.cell_name}: flare site not connected "
                    f"in {self.flare_site_heartbeat_timeout} seconds"
                )
                self.is_done = True
                return

        # connected: watch heartbeats until the agent is closed (bye / stop) or the site goes silent
        while not (self.is_done or self.is_stopped):
            if time.time() - self.last_hb_time > self.flare_site_heartbeat_timeout:
                self.logger.error(
                    f"closing agent {self.cell_name}: no heartbeat from flare site "
                    f"for {self.flare_site_heartbeat_timeout} seconds"
                )
                self.is_done = True
                return
            time.sleep(1.0)

    def _handle_hello(self, request: Message) -> Union[None, Message]:
        """Mark the flare site as connected and reset the heartbeat clock."""
        self.logger.info(f"got hello: {request.headers}")
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        self.logger.info(f"connected to the flare site {sender}")
        self.last_hb_time = time.time()
        self.connect_waiter.set()
        return make_reply(ReturnCode.OK)

    def _handle_bye(self, request: Message) -> Union[None, Message]:
        """The flare site is leaving: mark the agent as done."""
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        self.logger.info(f"got goodbye from {sender}")
        self.is_done = True
        return make_reply(ReturnCode.OK)

    def _handle_heartbeat(self, request: Message) -> Union[None, Message]:
        """Record the time of the latest heartbeat from the flare site."""
        self.last_hb_time = time.time()
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        self.logger.info(f"got heartbeat from {sender}")
        return make_reply(ReturnCode.OK)

    def _handle_abort_task(self, request: Message) -> Union[None, Message]:
        """Abort a task: flag it if current, or drop it if it is only pending."""
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)
        self.logger.warning(f"received from {sender} to abort {task_name=} {task_id=}")
        with self.task_lock:
            if self.current_task and task_id == self.current_task.task_id:
                self.current_task.aborted = True
            elif self.pending_task and task_id == self.pending_task.task_id:
                self.pending_task = None
        return make_reply(ReturnCode.OK)

    def _receive_task(self, request: Message) -> Union[None, Message]:
        """Cell callback for an incoming task; processes it under task_lock."""
        self.logger.info("receiving task ...")
        with self.task_lock:
            return self._do_receive_task(request)

    def _create_task(self, request: Message):
        """Build a _TaskContext from a task request, or return None if the payload is malformed.

        The payload must be a dict carrying both the DATA and META keys. Empty
        values are accepted; only absent (None) values are rejected.
        """
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)

        task_data = request.payload
        if not isinstance(task_data, dict):
            self.logger.error(f"bad task data from {sender}: expect dict but got {type(task_data)}")
            return None

        data = task_data.get(defs.PayloadKey.DATA)
        if data is None:
            self.logger.error(f"bad task data from {sender}: missing {defs.PayloadKey.DATA}")
            return None

        meta = task_data.get(defs.PayloadKey.META)
        if meta is None:
            self.logger.error(f"bad task data from {sender}: missing {defs.PayloadKey.META}")
            return None

        return _TaskContext(sender, task_name, task_id, meta, data)

    def _do_receive_task(self, request: Message) -> Union[None, Message]:
        """Accept or reject an incoming task. Must be called with task_lock held.

        At most one current task and one pending task are kept. A new task is
        queued as pending only after a result send for the current task has been
        attempted. Otherwise a second task is rejected with INVALID_REQUEST.
        Duplicate deliveries of a known task are acknowledged with OK.
        """
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)
        self.logger.info(f"_do_receive_task from {sender}: {task_name=} {task_id=}")

        if self.pending_task:
            if task_id == self.pending_task.task_id:
                return make_reply(ReturnCode.OK)
            self.logger.error("got a new task while already have a pending task!")
            return make_reply(ReturnCode.INVALID_REQUEST)

        current_task = self.current_task
        if current_task:
            assert isinstance(current_task, _TaskContext)
            if task_id == current_task.task_id:
                self.logger.info(f"received duplicate task {task_id} from {sender}")
                return make_reply(ReturnCode.OK)

            if current_task.last_send_result_time:
                # We already tried to send the result back; a new task from the flare
                # site implies it has received that result. Setting the flag makes the
                # sending loop end quickly, and the new task is queued as pending until
                # submit_result releases the current one.
                current_task.already_received = True
                self.pending_task = self._create_task(request)
                if self.pending_task:
                    return make_reply(ReturnCode.OK)
                return make_reply(ReturnCode.INVALID_REQUEST)

            # error - one task at a time
            self.logger.error(
                f"got task {task_name} {task_id} from {sender} "
                f"while still working on {current_task.task_name} {current_task.task_id}"
            )
            return make_reply(ReturnCode.INVALID_REQUEST)

        self.current_task = self._create_task(request)
        if self.current_task:
            return make_reply(ReturnCode.OK)
        return make_reply(ReturnCode.INVALID_REQUEST)

    def _release_current_task(self):
        """Promote the pending task (if any) to current. Must be called with task_lock held."""
        self.current_task = self.pending_task
        self.pending_task = None

    def get_task(self, timeout=None):
        """Get a task from FLARE. This is a blocking call.

        If timeout is specified, this call is blocked only for the specified amount of time.
        If timeout is not specified, this call is blocked forever until a task is received
        or agent is closed. A task that the flare site aborted before it was submitted is
        discarded, and the next task (if any) is returned instead.

        Args:
            timeout: amount of time to block

        Returns: None if no task becomes available before the timeout; otherwise a Task object.

        Raises:
            AgentClosed exception if the agent is closed before timeout.
            CallStateError exception if the call is not made properly.

        Note: the application must make the call only when it is just started or after a previous task's result
        has been submitted.
        """
        if timeout is not None:
            if not isinstance(timeout, (int, float)):
                raise TypeError(f"timeout must be (int, float) but got {type(timeout)}")
            if timeout <= 0:
                raise ValueError(f"timeout must > 0, but got {timeout}")

        start = time.time()
        while True:
            if self.is_done or self.is_stopped:
                self.logger.info("no more tasks - agent closed")
                raise defs.AgentClosed("flare agent is closed")

            with self.task_lock:
                current_task = self.current_task
                if current_task:
                    assert isinstance(current_task, _TaskContext)
                    if current_task.aborted:
                        if current_task.status != _TaskContext.PROCESSED:
                            # Nothing else would ever clear an aborted, unsubmitted task, and
                            # while it sits here every new task is rejected: drop it now.
                            self.logger.info(f"dropping aborted task {current_task}")
                            self._release_current_task()
                            continue
                        # otherwise a submission is in flight; submit_result will release it
                    elif current_task.status == _TaskContext.NEW:
                        current_task.status = _TaskContext.FETCHED
                        return defs.Task(
                            current_task.task_name, current_task.task_id, current_task.meta, current_task.data
                        )
                    else:
                        raise defs.CallStateError(
                            f"application called get_task while the current task is in status {current_task.status}"
                        )

            if timeout and time.time() - start > timeout:
                # no task available before timeout
                self.logger.info(f"get_task timeout after {timeout} seconds")
                return None

            time.sleep(0.5)

    def submit_result(self, result: defs.TaskResult) -> bool:
        """Submit the result of the current task.

        This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or
        the task is aborted or the agent is closed.

        Args:
            result: result to be submitted

        Returns: whether the result is submitted successfully

        Raises:
            TypeError if result is not a TaskResult.
            the CallStateError exception if the submit_result call is not made properly.

        Notes: the application must only make this call after the received task is processed. The call can only be
        made a single time regardless whether the submission is successful.
        """
        if not isinstance(result, defs.TaskResult):
            raise TypeError(f"result must be TaskResult but got {type(result)}")

        with self.task_lock:
            current_task = self.current_task
            if not current_task:
                self.logger.error("submit_result is called but there is no current task!")
                return False

            assert isinstance(current_task, _TaskContext)
            if current_task.aborted:
                # release the aborted task, otherwise all later tasks would be rejected
                self._release_current_task()
                return False

            if current_task.status != _TaskContext.FETCHED:
                raise defs.CallStateError(
                    f"submit_result is called while current task is in status {current_task.status}"
                )
            current_task.status = _TaskContext.PROCESSED

        try:
            submitted = self._do_submit_result(current_task, result)
        except Exception:
            self.logger.error(f"exception submitting result to {current_task.sender}")
            traceback.print_exc()
            submitted = False

        with self.task_lock:
            # a new task may be waiting for the current task to finish; only release
            # if the current task has not been replaced in the meantime
            if self.current_task is current_task:
                self._release_current_task()
        return submitted

    def _do_submit_result(self, current_task: _TaskContext, result: defs.TaskResult):
        """Send the task result to the flare site, retrying until it is settled.

        Returns True when the site acknowledged (OK) or is known to have received the
        result already. Returns False when the site rejected it (INVALID_REQUEST), the
        task was aborted, or the agent closed. Any other outcome (including no reply)
        is retried after a 2-second pause.
        """
        msg = new_message(
            headers={
                defs.MsgHeader.TASK_NAME: current_task.task_name,
                defs.MsgHeader.TASK_ID: current_task.task_id,
                defs.MsgHeader.RC: result.return_code,
            },
            payload={
                defs.PayloadKey.META: result.meta,
                defs.PayloadKey.DATA: result.data,
            },
        )

        while True:
            if current_task.already_received:
                if not current_task.last_send_result_time:
                    self.logger.warning(f"task {current_task} was marked already_received but has not been sent!")
                return True

            if self.is_done or self.is_stopped:
                self.logger.error(f"quit submitting result for task {current_task} since agent is closed")
                return False

            if current_task.aborted:
                self.logger.error(f"quit submitting result for task {current_task} since it is aborted")
                return False

            current_task.last_send_result_time = time.time()
            self.logger.info(f"sending result to {current_task.sender} for task {current_task}")
            reply = self.cell.send_request(
                channel=defs.CHANNEL,
                topic=defs.TOPIC_SUBMIT_RESULT,
                target=current_task.sender,
                request=msg,
                timeout=self.submit_result_timeout,
            )

            # keep the reply's code separate from the task's return code
            reply_rc = reply.get_header(MessageHeaderKey.RETURN_CODE) if reply else None
            if reply_rc == ReturnCode.OK:
                return True
            if reply_rc == ReturnCode.INVALID_REQUEST:
                sender = reply.get_header(MessageHeaderKey.ORIGIN)
                self.logger.error(f"received return code from {sender}: {reply_rc}")
                return False

            self.logger.info(f"failed to send to {current_task.sender}: {reply_rc} - will retry")
            time.sleep(2.0)