# Source code for nvflare.client.flare_agent

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading
import time
import traceback
from typing import Any, Optional

from nvflare.apis.dxo import DXO, MetaKey, from_shareable
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_constant import ReturnCode as RC
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.decomposers import common_decomposers
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.cell_pipe import CellPipe
from nvflare.fuel.utils.pipe.pipe import Message, Mode, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler


class FlareAgentException(Exception):
    """Base class for the exceptions defined by the flare agent module."""
class AgentClosed(FlareAgentException):
    """Raised when the agent is used after it has been asked to stop (e.g. by ``get_task``)."""
class CallStateError(FlareAgentException):
    """Raised when an agent method is called in the wrong state.

    For example, ``get_task`` raises it while a previously received task is still outstanding.
    """
class Task:
    """A task handed to the application: its name, its id and the converted task data."""

    def __init__(self, task_name: str, task_id: str, data):
        self.task_name, self.task_id, self.data = task_name, task_id, data

    def __str__(self):
        # Rendered as the quoted pair "'<name> <id>'".
        return "'{} {}'".format(self.task_name, self.task_id)
class _TaskContext:
    """Bookkeeping for the task currently being processed by the agent.

    ``msg_id`` is the id of the request message, used later to address the reply.
    """

    def __init__(self, task_id, task_name: str, msg_id):
        self.task_id, self.task_name, self.msg_id = task_id, task_name, msg_id
class FlareAgent:
    """Agent that talks to the FLARE Client Job cell (CJ) over pipes.

    Tasks are received on the task pipe and their results are replied on the same pipe.
    Metric records are sent over an optional, separate metric pipe.
    """

    # Upper bound on how long get_task sleeps between polls of the task pipe handler.
    _TASK_POLL_INTERVAL = 0.5

    def __init__(
        self,
        pipe: Pipe,
        read_interval=0.1,
        heartbeat_interval=5.0,
        heartbeat_timeout=30.0,
        resend_interval=2.0,
        max_resends=None,
        submit_result_timeout=30.0,
        metric_pipe=None,
        task_channel_name: str = PipeChannelName.TASK,
        metric_channel_name: str = PipeChannelName.METRIC,
        close_pipe: bool = True,
        close_metric_pipe: bool = True,
    ):
        """Constructor of Flare Agent.

        The agent is responsible for communicating with the Flare Client Job cell (CJ)
        to get task and to submit task result.

        Args:
            pipe (Pipe): pipe for task communication.
            read_interval (float): how often to read from the pipe. Defaults to 0.1.
            heartbeat_interval (float): how often to send a heartbeat to the peer. Defaults to 5.0.
            heartbeat_timeout (float): how long to wait for a heartbeat from the peer before treating the peer as dead,
                0 means DO NOT check for heartbeat. Defaults to 30.0.
            resend_interval (float): how often to resend a message if failing to send. None means no resend.
                Note that if the pipe does not support resending, then no resend. Defaults to 2.0.
            max_resends (int, optional): max number of resend. None means no limit. Defaults to None.
            submit_result_timeout (float): when submitting task result,
                how long to wait for response from the CJ. Defaults to 30.0.
            metric_pipe (Pipe, optional): pipe for metric communication. Defaults to None.
            task_channel_name (str): channel name for task. Defaults to ``task``.
            metric_channel_name (str): channel name for metric. Defaults to ``metric``.
            close_pipe (bool): whether to close the task pipe when stopped. Defaults to True.
                Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True.
            close_metric_pipe (bool): whether to close the metric pipe when stopped. Defaults to True.
                Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True.
        """
        # Register the FLARE and app-common decomposers
        # (presumably required to (de)serialize Shareable/DXO payloads on the pipe).
        flare_decomposers.register()
        common_decomposers.register()

        self.logger = logging.getLogger(self.__class__.__name__)
        self.pipe = pipe
        self.pipe_handler = PipeHandler(
            pipe=self.pipe,
            read_interval=read_interval,
            heartbeat_interval=heartbeat_interval,
            heartbeat_timeout=heartbeat_timeout,
            resend_interval=resend_interval,
            max_resends=max_resends,
        )
        self.submit_result_timeout = submit_result_timeout
        self.task_channel_name = task_channel_name
        self.metric_channel_name = metric_channel_name

        # The metric pipe is optional; its handler shares the task pipe's timing settings.
        self.metric_pipe = metric_pipe
        self.metric_pipe_handler = None
        if self.metric_pipe:
            self.metric_pipe_handler = PipeHandler(
                pipe=self.metric_pipe,
                read_interval=read_interval,
                heartbeat_interval=heartbeat_interval,
                heartbeat_timeout=heartbeat_timeout,
                resend_interval=resend_interval,
                max_resends=max_resends,
            )

        # _TaskContext of the task handed out by get_task and not yet submitted.
        self.current_task = None
        self.task_lock = threading.Lock()
        self.asked_to_stop = False
        self._close_pipe = close_pipe
        self._close_metric_pipe = close_metric_pipe

    def start(self):
        """Start the agent.

        This method must be called to enable CJ/Agent communication.

        Returns: None

        """
        self.pipe.open(self.task_channel_name)
        self.pipe_handler.set_status_cb(self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name)
        self.pipe_handler.start()

        if self.metric_pipe:
            self.metric_pipe.open(self.metric_channel_name)
            self.metric_pipe_handler.set_status_cb(
                self._status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name
            )
            self.metric_pipe_handler.start()

    def _status_cb(self, msg: Message, pipe_handler: PipeHandler, channel):
        """Status callback for a pipe handler: log the change, mark the agent stopped, stop that handler."""
        self.logger.info(f"{channel} pipe status changed to {msg.topic}: {msg.data}")
        self.asked_to_stop = True
        # Use the close flag that belongs to the pipe whose status changed. Previously the
        # task pipe's flag was also applied to the metric pipe, ignoring close_metric_pipe.
        if pipe_handler is self.metric_pipe_handler:
            close_pipe = self._close_metric_pipe
        else:
            close_pipe = self._close_pipe
        pipe_handler.stop(close_pipe)

    def stop(self):
        """Stop the agent.

        After this is called, there will be no more communications between CJ and agent.

        Returns: None

        """
        self.logger.info("Calling flare agent stop")
        self.asked_to_stop = True
        self.pipe_handler.stop(self._close_pipe)
        if self.metric_pipe_handler:
            self.metric_pipe_handler.stop(self._close_metric_pipe)

    def shareable_to_task_data(self, shareable: Shareable) -> Any:
        """Convert the Shareable object received from the TaskExchanger to an app-friendly format.

        Subclass can override this method to convert to its own app-friendly task data.
        By default, we convert to DXO object.

        Args:
            shareable: the Shareable object received from the TaskExchanger.

        Returns:
            task data.

        Raises:
            Any exception raised while extracting the DXO; it is logged and re-raised unchanged.
        """
        try:
            dxo = from_shareable(shareable)

            # add training-related headers carried in the Shareable header to the DXO meta.
            total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS)
            if total_rounds is not None:
                dxo.set_meta_prop(MetaKey.TOTAL_ROUNDS, total_rounds)

            current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
            if current_round is not None:
                dxo.set_meta_prop(MetaKey.CURRENT_ROUND, current_round)
            return dxo
        except Exception as ex:
            self.logger.error(f"failed to extract DXO from shareable object: {ex}")
            # Bare raise keeps the original traceback.
            raise

    def get_task(self, timeout: Optional[float] = None) -> Optional[Task]:
        """Get a task from FLARE. This is a blocking call.

        Args:
            timeout (float, optional): If specified, this call is blocked only for the specified amount of time.
                If not specified, this call is blocked forever until a task has been received or agent has been closed.
                The pipe is always checked at least once, so ``timeout=0`` performs a single non-blocking poll.

        Returns:
            None if no task is available before timeout; or a Task object if task is available.

        Raises:
            AgentClosed exception if the agent has been closed before timeout.
            CallStateError exception if the call has not been made properly.
            RuntimeError if the received request data is not a Shareable.
            Any exception raised by ``PipeHandler.get_next`` or ``shareable_to_task_data`` propagates unchanged
            (NOTE(review): earlier docs listed AgentAbortException, AgentEndException and AgentPeerGoneException
            here; confirm whether the pipe handler raises them).

        Note: the application must make the call only when it is just started or after a previous task's result
        has been submitted.

        """
        # monotonic clock: elapsed-time measurement must not be affected by wall-clock adjustments.
        start_time = time.monotonic()
        while True:
            if self.asked_to_stop:
                raise AgentClosed("agent closed")

            if self.current_task:
                raise CallStateError("application called get_task while the current task is not processed")

            # Read before checking the timeout, so a waiting task is never dropped
            # (including with timeout=0 and on the final pass after the last sleep).
            req: Optional[Message] = self.pipe_handler.get_next()
            if req is not None:
                if not isinstance(req.data, Shareable):
                    self.logger.error(f"bad task: expect request data to be Shareable but got {type(req.data)}")
                    raise RuntimeError("bad request data")

                shareable = req.data
                task_data = self.shareable_to_task_data(shareable)

                task_id = shareable.get_header(FLContextKey.TASK_ID)
                task_name = shareable.get_header(FLContextKey.TASK_NAME)

                tc = _TaskContext(
                    task_id=task_id,
                    task_name=task_name,
                    msg_id=req.msg_id,
                )
                self.current_task = tc
                return Task(task_name=tc.task_name, task_id=tc.task_id, data=task_data)

            if timeout is None:
                time.sleep(self._TASK_POLL_INTERVAL)
                continue

            remaining = timeout - (time.monotonic() - start_time)
            if remaining <= 0:
                self.logger.debug("get request timeout")
                return None
            # Never sleep past the caller's deadline.
            time.sleep(min(self._TASK_POLL_INTERVAL, remaining))

    def submit_result(self, result, rc=RC.OK) -> bool:
        """Submit the result of the current task.

        This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or
        the task is aborted or the agent is closed.

        Args:
            result: result to be submitted
            rc: return code

        Returns:
            whether the result is submitted successfully. False is also returned (and an error logged)
            when there is no current task, or when sending raised an exception.

        Notes:
            the application must only make this call after the received task is processed. The call can only be
            made a single time regardless whether the submission is successful.

        """
        with self.task_lock:
            current_task = self.current_task

        if not current_task:
            self.logger.error("submit_result is called but there is no current task!")
            return False

        try:
            return self._do_submit_result(current_task, result, rc)
        except Exception as ex:
            # _TaskContext has no "sender" attribute; the old message referenced it, which raised
            # AttributeError from inside this handler and skipped clearing the current task.
            self.logger.error(
                f"exception submitting result for task '{current_task.task_name}' ({current_task.task_id}): {ex}"
            )
            traceback.print_exc()
            return False
        finally:
            # A task may be submitted only once, whatever the outcome.
            with self.task_lock:
                self.current_task = None

    def task_result_to_shareable(self, result: Any, rc) -> Shareable:
        """Convert the result object to Shareable object before sending back to the TaskExchanger.

        Subclass can override this method to convert its app-friendly result type to Shareable.
        By default, we expect the result to be DXO object.

        Args:
            result: the result object to be converted to Shareable. If None, an empty Shareable object will be created
                with the rc only.
            rc: the return code.

        Returns:
            A Shareable object

        Raises:
            RuntimeError if ``result`` is neither None nor a DXO.
        """
        if result is not None:
            if not isinstance(result, DXO):
                self.logger.error(f"expect result to be DXO but got {type(result)}")
                raise RuntimeError("bad result data")
            result = result.to_shareable()
        else:
            result = Shareable()
        result.set_return_code(rc)
        return result

    def _do_submit_result(self, current_task: _TaskContext, result, rc):
        """Convert ``result`` and send it as the reply to the current task's request message."""
        result = self.task_result_to_shareable(result, rc)
        reply = Message.new_reply(topic=current_task.task_name, req_msg_id=current_task.msg_id, data=result)
        return self.pipe_handler.send_to_peer(reply, self.submit_result_timeout)

    def log(self, record: DXO) -> bool:
        """Logs a metric record.

        Args:
            record (DXO): A metric record.

        Returns:
            whether the metric record is submitted successfully

        Raises:
            RuntimeError if the agent was created without a metric pipe.
        """
        if not self.metric_pipe_handler:
            raise RuntimeError("metric pipe is not available")

        msg = Message.new_request(topic="metric", data=record)
        return self.metric_pipe_handler.send_to_peer(msg, self.submit_result_timeout)
class FlareAgentWithCellPipe(FlareAgent):
    """A FlareAgent whose task pipe (and optional metric pipe) are active-mode CellPipes."""

    def __init__(
        self,
        agent_id: str,
        site_name: str,
        root_url: str,
        secure_mode: bool,
        workspace_dir: str,
        read_interval=0.1,
        heartbeat_interval=5.0,
        heartbeat_timeout=30.0,
        resend_interval=2.0,
        max_resends=None,
        submit_result_timeout=30.0,
        has_metrics=False,
    ):
        """Constructor of Flare Agent with Cell Pipe. This is a convenient class.

        Args:
            agent_id (str): unique id to guarantee the uniqueness of cell's FQCN.
            site_name (str): name of the FLARE site
            root_url (str): the root url of the cellnet that the pipe's cell will join
            secure_mode (bool): whether connection to the root is secure (TLS)
            workspace_dir (str): the directory that contains startup for joining the cellnet. Required only in secure mode
            read_interval (float): how often to read from the pipe.
            heartbeat_interval (float): how often to send a heartbeat to the peer.
            heartbeat_timeout (float): how long to wait for a heartbeat from the peer before treating the peer as gone,
                0 means DO NOT check for heartbeat.
            resend_interval (float): how often to resend a message if failing to send. None means no resend.
                Note that if the pipe does not support resending, then no resend.
            max_resends (int, optional): max number of resend. None means no limit.
            submit_result_timeout (float): when submitting task result, how long to wait for response from the CJ.
            has_metrics (bool): has metric pipe or not.
        """
        # The task and metric pipes connect with exactly the same settings.
        cell_pipe_kwargs = dict(
            mode=Mode.ACTIVE,
            token=agent_id,
            site_name=site_name,
            root_url=root_url,
            secure_mode=secure_mode,
            workspace_dir=workspace_dir,
        )
        task_pipe = CellPipe(**cell_pipe_kwargs)
        metric_pipe = CellPipe(**cell_pipe_kwargs) if has_metrics else None

        super().__init__(
            task_pipe,
            read_interval=read_interval,
            heartbeat_interval=heartbeat_interval,
            heartbeat_timeout=heartbeat_timeout,
            resend_interval=resend_interval,
            max_resends=max_resends,
            submit_result_timeout=submit_result_timeout,
            metric_pipe=metric_pipe,
        )