Source code for nvflare.client.ipc.ipc_agent

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time
import traceback
from typing import Union

from nvflare.app_common.decomposers import numpy_decomposers
from nvflare.client.ipc import defs
from nvflare.fuel.f3.cellnet.cell import Cell, Message
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.config_service import ConfigService

_SSL_ROOT_CERT = "rootCA.pem"
_SHORT_SLEEP_TIME = 0.2


class IPCAgent:
    def __init__(
        self,
        flare_site_url: str,
        flare_site_name: str,
        agent_id: str,
        workspace_dir: str,
        secure_mode=False,
        submit_result_timeout=30.0,
        flare_site_connection_timeout=60.0,
        flare_site_heartbeat_timeout=None,
        resend_result_interval=2.0,
    ):
        """Constructor of Flare Agent.

        The agent is responsible for communicating with the Flare Client Job cell (CJ)
        to get tasks and to submit task results.

        Args:
            flare_site_url: the URL to the client parent cell (CP)
            flare_site_name: the CJ's site name (client name)
            agent_id: the unique ID of the agent
            workspace_dir: directory that contains startup folder and comm_config.json
            secure_mode: whether the connection is in secure mode or not
            submit_result_timeout: when submitting task result, how long to wait for response from the CJ
            flare_site_connection_timeout: time for missing heartbeats from CJ before considering it disconnected
            flare_site_heartbeat_timeout: time for missing heartbeats from CJ before considering it dead
            resend_result_interval: how often to retry sending a task result that has not been acknowledged
        """
        ConfigService.initialize(section_files={}, config_path=[workspace_dir])

        self.logger = logging.getLogger(self.__class__.__name__)
        self.cell_name = defs.agent_site_fqcn(flare_site_name, agent_id)
        self.workspace_dir = workspace_dir
        self.secure_mode = secure_mode
        self.flare_site_url = flare_site_url
        self.submit_result_timeout = submit_result_timeout
        self.flare_site_heartbeat_timeout = flare_site_heartbeat_timeout
        self.flare_site_connection_timeout = flare_site_connection_timeout
        self.resend_result_interval = resend_result_interval
        self.num_results_submitted = 0
        self.current_task = None
        self.pending_task = None
        self.task_lock = threading.Lock()
        self.last_msg_time = time.time()  # last time to get msg from flare site
        self.peer_fqcn = None
        self.is_done = False
        self.is_started = False  # has the agent been started?
        self.is_stopped = False  # has the agent been stopped?
        self.is_connected = False  # is the agent connected to the flare site?
        self.credentials = {}  # security credentials for secure connection

        if secure_mode:
            root_cert_path = ConfigService.find_file(_SSL_ROOT_CERT)
            if not root_cert_path:
                raise ValueError(f"cannot find {_SSL_ROOT_CERT} from config path {workspace_dir}")

            self.credentials = {
                DriverParams.CA_CERT.value: root_cert_path,
            }

        self.cell = Cell(
            fqcn=self.cell_name,
            root_url="",
            parent_url=self.flare_site_url,
            secure=self.secure_mode,
            credentials=self.credentials,
            create_internal_listener=False,
        )
        self.net_agent = NetAgent(self.cell)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_GET_TASK, cb=self._receive_task)
        self.logger.info(f"registered task CB for {defs.CHANNEL} {defs.TOPIC_GET_TASK}")
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HEARTBEAT, cb=self._handle_heartbeat)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_BYE, cb=self._handle_bye)
        self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_ABORT, cb=self._handle_abort_task)
        self.cell.core_cell.add_incoming_request_filter(
            channel="*",
            topic="*",
            cb=self._msg_received,
        )
        self.cell.core_cell.add_incoming_reply_filter(
            channel="*",
            topic="*",
            cb=self._msg_received,
        )
        numpy_decomposers.register()
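    # A minimal construction sketch. The URL, site name, agent ID, and workspace
    # path below are hypothetical placeholders, not shipped defaults:
    #
    #   agent = IPCAgent(
    #       flare_site_url="grpc://localhost:8002",
    #       flare_site_name="site-1",
    #       agent_id="agent-1",
    #       workspace_dir="/path/to/workspace",
    #       secure_mode=False,
    #   )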
    def start(self):
        """Start the agent. This method must be called to enable CJ/Agent communication.

        Returns: None

        """
        if self.is_started:
            self.logger.warning("the agent is already started")
            return

        if self.is_stopped:
            raise defs.CallStateError("cannot start the agent since it is already stopped")

        self.is_started = True
        self.logger.info(f"starting agent {self.cell_name} ...")
        self.cell.start()
        t = threading.Thread(target=self._monitor, daemon=True)
        t.start()
    def stop(self):
        """Stop the agent. After this is called, there will be no more communications between CJ and agent.

        Returns: None

        """
        if not self.is_started:
            self.logger.warning("cannot stop the agent since it is not started")
            return

        if self.is_stopped:
            self.logger.warning("agent is already stopped")
            return

        self.is_stopped = True
        self.cell.stop()
        self.net_agent.close()
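    # Typical lifecycle, sketched: start() must precede any get_task/submit_result
    # calls, and stop() ends all CJ/agent communication (run_tasks is a
    # hypothetical application loop):
    #
    #   agent.start()
    #   try:
    #       run_tasks(agent)  # get_task / submit_result loop
    #   finally:
    #       agent.stop()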
    def _monitor(self):
        while True:
            since_last_msg = time.time() - self.last_msg_time
            if since_last_msg > self.flare_site_connection_timeout:
                if self.is_connected:
                    self.logger.error(
                        "flare site disconnected since no message received "
                        f"for {self.flare_site_connection_timeout} seconds"
                    )
                    self.is_connected = False

            if self.flare_site_heartbeat_timeout and since_last_msg > self.flare_site_heartbeat_timeout:
                self.logger.error(
                    f"flare site is dead since no message received for {self.flare_site_heartbeat_timeout} seconds"
                )
                self.is_done = True
                return
            time.sleep(_SHORT_SLEEP_TIME)

    def _handle_bye(self, request: Message) -> Union[None, Message]:
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        self.logger.info(f"got goodbye from {peer}")
        self.is_done = True
        return make_reply(ReturnCode.OK)

    def _msg_received(self, request: Message):
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        if self.peer_fqcn and self.peer_fqcn != peer:
            # this could happen when a new job is started for the same training
            self.logger.warning(f"got peer FQCN '{peer}' while expecting '{self.peer_fqcn}'")

        self.peer_fqcn = peer
        self.last_msg_time = time.time()
        if not self.is_connected:
            self.is_connected = True
            self.logger.info(f"connected to flare site {peer}")

    def _handle_heartbeat(self, request: Message) -> Union[None, Message]:
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        self.logger.debug(f"got heartbeat from {peer}")
        return make_reply(ReturnCode.OK)

    def _handle_abort_task(self, request: Message) -> Union[None, Message]:
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)
        self.logger.warning(f"received from {peer} to abort {task_name=} {task_id=}")

        with self.task_lock:
            if self.current_task and task_id == self.current_task.task_id:
                self.current_task.aborted = True
            elif self.pending_task and task_id == self.pending_task.task_id:
                self.pending_task = None

        return make_reply(ReturnCode.OK)

    def _receive_task(self, request: Message) -> Union[None, Message]:
        with self.task_lock:
            return self._do_receive_task(request)

    def _create_task(self, request: Message):
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)
        self.logger.info(f"received task from {peer}: {task_name=} {task_id=}")

        task_data = request.payload
        if not isinstance(task_data, dict):
            self.logger.error(f"bad task data from {peer}: expect dict but got {type(task_data)}")
            return None

        data = task_data.get(defs.PayloadKey.DATA)
        if not data:
            self.logger.error(f"bad task data from {peer}: missing {defs.PayloadKey.DATA}")
            return None

        meta = task_data.get(defs.PayloadKey.META)
        if not meta:
            self.logger.error(f"bad task data from {peer}: missing {defs.PayloadKey.META}")
            return None

        return defs.Task(task_name, task_id, meta, data)

    def _do_receive_task(self, request: Message) -> Union[None, Message]:
        peer = request.get_header(MessageHeaderKey.ORIGIN)
        task_id = request.get_header(defs.MsgHeader.TASK_ID)
        task_name = request.get_header(defs.MsgHeader.TASK_NAME)

        # create a new task
        new_task = self._create_task(request)
        if not new_task:
            return make_reply(ReturnCode.INVALID_REQUEST)

        if self.pending_task:
            assert isinstance(self.pending_task, defs.Task)
            if task_id == self.pending_task.task_id:
                return make_reply(ReturnCode.OK)
            else:
                # this could happen when the CJ is restarted
                self.logger.warning(f"got new task from {peer} while already having a pending task!")
                # replace the pending task
                self.pending_task = new_task
                return make_reply(ReturnCode.OK)

        current_task = self.current_task
        if current_task:
            assert isinstance(current_task, defs.Task)
            if task_id == current_task.task_id:
                self.logger.info(f"received duplicate task {task_id} from {peer}")
                return make_reply(ReturnCode.OK)

            if current_task.last_send_result_time:
                # we already tried to send result back
                # assume that the flare site has received
                # we set the flag so the sending process will end quickly
                # in the meanwhile we ask flare site to retry later
                current_task.already_received = True
            else:
                # error - one task at a time
                self.logger.warning(
                    f"got task {task_name} {task_id} from {peer} "
                    f"while still working on {current_task.task_name} {current_task.task_id}"
                )
                # this could happen when CJ is restarted while we are processing current task
                # we set the current_task to be aborted. App should check this flag frequently to abort processing
                current_task.aborted = True

            # treat the new task as pending task - it will become current after the current_task is submitted
            self.pending_task = new_task
            return make_reply(ReturnCode.OK)
        else:
            # no current task
            self.current_task = new_task
            return make_reply(ReturnCode.OK)
    def get_task(self, timeout=None):
        """Get a task from FLARE. This is a blocking call.

        If timeout is specified, this call is blocked only for the specified amount of time.
        If timeout is not specified, this call is blocked forever until a task is received or the agent is closed.

        Args:
            timeout: amount of time to block

        Returns: None if no task is available before timeout; or a Task object if a task is available.

        Raises:
            AgentClosed exception if the agent is closed before timeout.
            CallStateError exception if the call is not made properly.

        Note: the application must make the call only when it is just started or after a previous task's
        result has been submitted.

        """
        if timeout is not None:
            if not isinstance(timeout, (int, float)):
                raise TypeError(f"timeout must be (int, float) but got {type(timeout)}")
            if timeout <= 0:
                raise ValueError(f"timeout must be > 0, but got {timeout}")

        start = time.time()
        while True:
            if self.is_done or self.is_stopped:
                self.logger.info("no more tasks - agent closed")
                raise defs.AgentClosed("flare agent is closed")

            with self.task_lock:
                current_task = self.current_task
                if current_task:
                    assert isinstance(current_task, defs.Task)
                    if current_task.aborted:
                        pass
                    elif current_task.status == defs.Task.NEW:
                        current_task.status = defs.Task.FETCHED
                        return current_task
                    else:
                        raise defs.CallStateError(
                            f"application called get_task while the current task is in status {current_task.status}"
                        )
            if timeout and time.time() - start > timeout:
                # no task available before timeout
                self.logger.info(f"get_task timeout after {timeout} seconds")
                return None
            time.sleep(_SHORT_SLEEP_TIME)
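    # A sketch of the polling contract above; process() is a hypothetical
    # application function:
    #
    #   while True:
    #       try:
    #           task = agent.get_task(timeout=5.0)
    #       except defs.AgentClosed:
    #           break                  # job ended or agent stopped
    #       if task is None:
    #           continue               # timed out; poll again
    #       result = process(task)     # must finish before submit_result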
    def submit_result(self, result: defs.TaskResult) -> bool:
        """Submit the result of the current task. This is a blocking call.

        The agent will try to send the result to the flare site until it is successfully sent,
        the task is aborted, or the agent is closed.

        Args:
            result: result to be submitted

        Returns: whether the result is submitted successfully

        Raises: the CallStateError exception if the submit_result call is not made properly.

        Notes: the application must only make this call after the received task is processed. The call can
        only be made a single time regardless of whether the submission is successful.

        """
        try:
            result_submitted = self._do_submit_result(result)
        except Exception as ex:
            self.logger.error(f"exception encountered: {ex}")
            result_submitted = False

        with self.task_lock:
            self.current_task = None
            if self.pending_task:
                # a new task is waiting for the current task to finish
                self.current_task = self.pending_task
                self.pending_task = None
        return result_submitted
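    # A hedged sketch of one task round. TaskResult is assumed to carry the
    # attributes that _send_result reads (meta, data, return_code); check
    # nvflare.client.ipc.defs for the exact constructor signature. train() is
    # a hypothetical training step:
    #
    #   task = agent.get_task()
    #   output = train(task.data)
    #   agent.submit_result(defs.TaskResult(meta=task.meta, data=output))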
    def _do_submit_result(self, result: defs.TaskResult) -> bool:
        if not isinstance(result, defs.TaskResult):
            raise TypeError(f"result must be TaskResult but got {type(result)}")

        with self.task_lock:
            current_task = self.current_task
            if current_task:
                if current_task.aborted:
                    return False
                if current_task.status != defs.Task.FETCHED:
                    raise defs.CallStateError(
                        f"submit_result is called while current task is in status {current_task.status}"
                    )
                current_task.status = defs.Task.PROCESSED
            elif self.num_results_submitted > 0:
                self.logger.error("submit_result is called but there is no current task!")
                return False
            else:
                # if the agent is restarted, it may pick up from previous checkpoint and continue training.
                # then it can send the result after finish training.
                pass
            self.num_results_submitted += 1

        try:
            return self._send_result(current_task, result)
        except Exception:
            self.logger.error(f"exception submitting result to {current_task.sender}")
            traceback.print_exc()
            return False

    def _send_result(self, current_task: defs.Task, result: defs.TaskResult):
        meta = result.meta
        rc = result.return_code
        data = result.data

        msg = Message(
            headers={
                defs.MsgHeader.TASK_NAME: current_task.task_name if current_task else "",
                defs.MsgHeader.TASK_ID: current_task.task_id if current_task else "",
                defs.MsgHeader.RC: rc,
            },
            payload={
                defs.PayloadKey.META: meta,
                defs.PayloadKey.DATA: data,
            },
        )

        last_send_time = 0
        while True:
            if self.is_done or self.is_stopped:
                self.logger.error(f"quit submitting result for task {current_task} since agent is closed")
                raise defs.AgentClosed("agent is stopped")

            if current_task and current_task.already_received:
                if not current_task.last_send_result_time:
                    self.logger.warning(f"task {current_task} was marked already_received but was never sent!")
                return True

            if current_task and current_task.aborted:
                self.logger.error(f"quit submitting result for task {current_task} since it is aborted")
                return False

            if self.is_connected and time.time() - last_send_time > self.resend_result_interval:
                self.logger.info(f"sending result to {self.peer_fqcn} for task {current_task}")
                if current_task:
                    current_task.last_send_result_time = time.time()
                reply = self.cell.send_request(
                    channel=defs.CHANNEL,
                    topic=defs.TOPIC_SUBMIT_RESULT,
                    target=self.peer_fqcn,
                    request=msg,
                    timeout=self.submit_result_timeout,
                )
                last_send_time = time.time()
                if reply:
                    rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
                    peer = reply.get_header(MessageHeaderKey.ORIGIN)
                    if rc == ReturnCode.OK:
                        return True
                    elif rc == ReturnCode.INVALID_REQUEST:
                        self.logger.error(f"received return code from {peer}: {rc}")
                        return False
                    else:
                        self.logger.info(f"failed to send to {self.peer_fqcn}: {rc} - will retry")
            time.sleep(_SHORT_SLEEP_TIME)