Source code for nvflare.fuel.utils.pipe.cell_pipe

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import queue
import threading
import time
from typing import Tuple, Union

from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.attributes_exportable import ExportMode
from nvflare.fuel.utils.config_service import search_file
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.validation_utils import check_object_type, check_str

from .pipe import Message, Pipe, Topic

# File name of the root CA certificate looked up in the workspace when secure_mode is on.
SSL_ROOT_CERT = "rootCA.pem"

# Namespace prefix shared by every header key (and cell channel name) this module defines.
_PREFIX = "cell_pipe."

# Cell message header keys carried alongside each pipe Message.
_HEADER_MSG_TYPE = f"{_PREFIX}msg_type"
_HEADER_MSG_ID = f"{_PREFIX}msg_id"
_HEADER_REQ_ID = f"{_PREFIX}req_id"
_HEADER_START_TIME = f"{_PREFIX}start"  # sender-side time.time() stamp
_HEADER_HB_SEQ = f"{_PREFIX}hb_seq"  # heartbeat sequence number, for debugging


def _cell_fqcn(mode, site_name, token):
    """Derive the fully qualified cell name (FQCN) of one pipe endpoint.

    The FQCN has to be unique across the whole cellnet. Since the token is usually
    shared by all sites, site_name is what tells one site's cell apart from another's;
    the two peer pipes on the same site share site_name and token and are told apart
    only by their mode.

    Returns:
        the string "<site_name>_<token>_<mode>"
    """
    return "{}_{}_{}".format(site_name, token, mode)


def _to_cell_message(msg: Message, extra=None) -> CellMessage:
    """Wrap a pipe Message into a cellnet message.

    The pipe-level fields (type, id, optional request id) travel as cell headers, together
    with a send timestamp; the message data becomes the cell payload.

    Args:
        msg: the pipe message to convert
        extra: optional dict of additional headers, applied on top of the standard ones

    Returns:
        a CellMessage carrying msg's metadata in its headers and msg.data as payload
    """
    sent_at = time.time()
    headers = {_HEADER_MSG_TYPE: msg.msg_type, _HEADER_MSG_ID: msg.msg_id, _HEADER_START_TIME: sent_at}
    # a falsy `extra` (None / empty) contributes nothing
    headers.update(extra or {})
    if msg.req_id:
        # only replies carry the id of the request they answer
        headers[_HEADER_REQ_ID] = msg.req_id
    return CellMessage(payload=msg.data, headers=headers)


def _from_cell_message(cm: CellMessage) -> Message:
    """Rebuild a pipe Message from a received cellnet message.

    This is the inverse of the header mapping done when sending: the pipe fields are read
    back from the cell headers, the topic from the cellnet TOPIC header, and the data
    from the cell payload.
    """
    fields = {
        "msg_id": cm.get_header(_HEADER_MSG_ID),
        "msg_type": cm.get_header(_HEADER_MSG_TYPE),
        "topic": cm.get_header(MessageHeaderKey.TOPIC),
        "req_id": cm.get_header(_HEADER_REQ_ID),
        "data": cm.payload,
    }
    return Message(**fields)


class _CellInfo:

    """
    A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging).

    This object tracks the pipes that share one cell, starts the cell at most once, and closes
    the cell's net agent and stops the cell when the last pipe is closed.
    """

    def __init__(self, cell, net_agent):
        self.cell = cell
        self.net_agent = net_agent
        self.started = False
        self.pipes = []  # pipes currently sharing this cell
        self.lock = threading.Lock()  # guards `started` and `pipes`

    def start(self):
        """Start the cell if it has not been started yet; later calls are no-ops."""
        with self.lock:
            if not self.started:
                self.cell.start()
                self.started = True

    def add_pipe(self, p):
        """Register pipe `p` as a user of this cell."""
        with self.lock:
            self.pipes.append(p)

    def close_pipe(self, p):
        """Unregister pipe `p`; when no pipes remain, close the net agent and stop the cell.

        Closing a pipe that is not registered (e.g. already closed) is a no-op.
        Shutdown is best-effort: a failure to close the agent or stop the cell is logged,
        not raised, and a failure in one step does not prevent the other.
        """
        with self.lock:
            try:
                self.pipes.remove(p)
            except ValueError:
                # p is not (or no longer) registered with this cell - nothing to do
                return

            if self.pipes:
                # other pipes still use this cell
                return

            # all pipes are closed - close agent and cell.
            # NOTE(review): the owner's registry of cells may still reference this object
            # after the cell is stopped - confirm a stopped cell is never handed out again.
            logger = logging.getLogger(self.__class__.__name__)
            try:
                self.net_agent.close()
            except Exception:
                logger.exception("failed to close net agent")
            try:
                self.cell.stop()
            except Exception:
                logger.exception("failed to stop cell")


class CellPipe(Pipe):
    """
    CellPipe is an implementation of `Pipe` that utilizes the `Cell` from NVFlare's foundation layer (f3) to do the
    communication.
    """

    _lock = threading.Lock()
    _cells_info = {}  # (root_url, site_name, token) => _CellInfo

    @classmethod
    def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir):
        """Build a cell if necessary.
        The combination of (root_url, site_name, token) uniquely determine one cell.
        There can be multiple pipes on the same cell.

        Args:
            root_url: root url of the cell net
            mode: mode (passive or active) of the pipe
            site_name: name of the site
            token: the unique token
            secure_mode: whether cellnet is in secure mode
            workspace_dir: workspace that contains startup kit for connecting to server. Needed only if secure_mode

        Returns:
            the _CellInfo shared by all pipes with the same (root_url, site_name, token)

        Raises:
            ValueError: secure_mode is set but the root CA cert cannot be found in workspace_dir
        """
        with cls._lock:
            cell_key = f"{root_url}.{site_name}.{token}"
            ci = cls._cells_info.get(cell_key)
            if not ci:
                credentials = {}
                if secure_mode:
                    root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir)
                    if not root_cert_path:
                        raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {workspace_dir}")

                    credentials = {
                        DriverParams.CA_CERT.value: root_cert_path,
                    }

                cell = Cell(
                    fqcn=_cell_fqcn(mode, site_name, token),
                    root_url=root_url,
                    secure=secure_mode,
                    credentials=credentials,
                    create_internal_listener=False,
                )
                net_agent = NetAgent(cell)
                ci = _CellInfo(cell, net_agent)
                cls._cells_info[cell_key] = ci
            return ci

    def __init__(
        self,
        mode: Mode,
        site_name: str,
        token: str,
        root_url: str = "",
        secure_mode: bool = True,
        workspace_dir: str = "",
    ):
        """The constructor of the CellPipe.

        Args:
            mode: passive or active mode
            site_name (str): name of the FLARE site
            token (str): unique id to guarantee the uniqueness of cell's FQCN.
            root_url (str): the root url of the cellnet that the pipe's cell will join
            secure_mode (bool): whether connection to the root is secure (TLS)
            workspace_dir (str): the directory that contains startup for joining the cellnet. Required only in secure_mode

        Raises:
            ValueError: mode is neither 'active' nor 'passive', or the root CA cert is missing in secure_mode
        """
        super().__init__(mode)
        self.logger = logging.getLogger(self.__class__.__name__)

        check_str("root_url", root_url)
        check_object_type("secure_mode", secure_mode, bool)
        check_str("token", token)
        check_str("site_name", site_name)
        check_str("workspace_dir", workspace_dir)

        self.root_url = root_url
        self.secure_mode = secure_mode
        self.workspace_dir = workspace_dir
        self.site_name = site_name
        self.token = token

        mode = f"{mode}".strip().lower()  # convert to lower case string

        # Validate the mode BEFORE touching the shared cell registry: otherwise an invalid
        # pipe would be created and registered on a shared cell and never removed from it,
        # which keeps the cell from ever being closed.
        if mode == "active":
            peer_mode = "passive"
        elif mode == "passive":
            peer_mode = "active"
        else:
            raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'")

        self.ci = self._build_cell(mode, root_url, site_name, token, secure_mode, workspace_dir)
        self.cell = self.ci.cell
        self.ci.add_pipe(self)

        self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token)
        self.received_msgs = queue.Queue()  # contains Message(s), not CellMessage(s)!
        self.channel = None  # the cellnet message channel
        self.pipe_lock = threading.Lock()  # used to ensure no msg to be sent after closed
        self.closed = False
        self.last_peer_active_time = 0.0
        self.hb_seq = 1

    def _update_peer_active_time(self, msg: CellMessage, ch_name: str, msg_type: str):
        """Incoming request/reply filter: record the time of any message that originates from the peer.

        Registered on all channels (see set_cell_cb), so this runs for every incoming message;
        the debug log therefore uses lazy %-formatting to avoid formatting when debug is off.
        """
        origin = msg.get_header(MessageHeaderKey.ORIGIN)
        if origin == self.peer_fqcn:
            now = time.time()
            self.logger.debug(
                "%s: _update_peer_active_time: ch_name=%r msg_type=%r %s", now, ch_name, msg_type, msg.headers
            )
            self.last_peer_active_time = now

    def get_last_peer_active_time(self):
        """Return the time.time() of the last message seen from the peer (0.0 if none yet)."""
        return self.last_peer_active_time

    def set_cell_cb(self, channel_name: str):
        """Register this pipe's request callback and the peer-activity filters on the cell.

        Args:
            channel_name: name of the pipe; the cell channel used is "cell_pipe.<channel_name>"
        """
        # This allows multiple pipes over the same cell (e.g. one channel for tasks, another for metrics),
        # as long as different pipes use different cell message channels
        self.channel = f"{_PREFIX}{channel_name}"
        self.cell.register_request_cb(channel=self.channel, topic="*", cb=self._receive_message)
        self.cell.core_cell.add_incoming_request_filter(
            channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="req"
        )
        self.cell.core_cell.add_incoming_reply_filter(
            channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="reply"
        )
        self.logger.info(f"registered CellPipe request CB for {self.channel}")

    def send(self, msg: Message, timeout=None) -> bool:
        """Sends the specified message to the peer.

        END, ABORT and HEARTBEAT messages are sent as optional; END/ABORT default to a 5-second
        timeout when none is given, and HEARTBEAT is fire-and-forget (always returns True).

        Args:
            msg: the message to be sent
            timeout: if specified, number of secs to wait for the peer to read the message.
                If not specified, wait indefinitely.

        Returns:
            Whether the message is read by the peer.

        Raises:
            BrokenPipeError: the pipe has been closed
        """
        with self.pipe_lock:
            if self.closed:
                raise BrokenPipeError("pipe closed")

        # Note: the following code must not be within the lock scope
        # Otherwise only one message can be sent at a time!
        optional = False
        if msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]:
            optional = True

        if not timeout and msg.topic in [Topic.END, Topic.ABORT]:
            timeout = 5.0  # need to keep the connection for some time; otherwise the msg may not go out

        if msg.topic == Topic.HEARTBEAT:
            # for debugging purpose
            extra_headers = {_HEADER_HB_SEQ: self.hb_seq}
            self.hb_seq += 1

            # don't need to wait for reply!
            self.cell.fire_and_forget(
                channel=self.channel,
                topic=msg.topic,
                targets=[self.peer_fqcn],
                message=_to_cell_message(msg, extra_headers),
                optional=optional,
            )
            return True

        reply = self.cell.send_request(
            channel=self.channel,
            topic=msg.topic,
            target=self.peer_fqcn,
            request=_to_cell_message(msg),
            timeout=timeout,
            optional=optional,
        )
        if reply:
            rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
            if rc == ReturnCode.OK:
                return True
            else:
                err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}"
                if optional:
                    self.logger.debug(err)
                else:
                    self.logger.error(err)
                return False
        else:
            return False

    def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]:
        """Request callback: queue an incoming message from the peer and acknowledge it with OK.

        Raises:
            RuntimeError: the message does not originate from the expected peer cell
        """
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        topic = request.get_header(MessageHeaderKey.TOPIC)
        self.logger.debug(f"got msg from peer {sender}: {topic}")

        if self.peer_fqcn != sender:
            raise RuntimeError(f"peer FQCN mismatch: expect {self.peer_fqcn} but got {sender}")

        msg = _from_cell_message(request)
        self.received_msgs.put_nowait(msg)
        return make_reply(ReturnCode.OK)

    def receive(self, timeout=None) -> Union[None, Message]:
        """Return the next received message, or None if there is none.

        Args:
            timeout: if truthy, block up to this many seconds; otherwise do not block
        """
        try:
            if timeout:
                return self.received_msgs.get(block=True, timeout=timeout)
            else:
                return self.received_msgs.get_nowait()
        except queue.Empty:
            return None

    def clear(self):
        """Discard all messages received but not yet consumed.

        Drains with get_nowait() until the queue reports Empty, so a concurrent reader
        emptying the queue cannot make this raise (unlike an empty()-then-get check).
        """
        while True:
            try:
                self.received_msgs.get_nowait()
            except queue.Empty:
                break

    def can_resend(self) -> bool:
        """CellPipe supports resending messages."""
        return True

    def open(self, name: str):
        """Start the underlying cell (once) and register this pipe's channel `name`.

        Raises:
            BrokenPipeError: the pipe has already been closed
        """
        with self.pipe_lock:
            if self.closed:
                raise BrokenPipeError("pipe already closed")
            self.ci.start()
            self.set_cell_cb(name)

    def close(self):
        """Close the pipe; the shared cell is shut down when its last pipe is closed. Idempotent."""
        with self.pipe_lock:
            if self.closed:
                return
            self.ci.close_pipe(self)
            self.closed = True

    def export(self, export_mode: str) -> Tuple[str, dict]:
        """Export the class path and constructor args needed to recreate this pipe or its peer.

        Args:
            export_mode: ExportMode.SELF to describe this pipe; otherwise describe the peer
                (opposite mode, joining through this cell's root url for children)

        Returns:
            tuple of (fully qualified class name, constructor kwargs)
        """
        if export_mode == ExportMode.SELF:
            mode = self.mode
            root_url = self.root_url
        else:
            mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE
            root_url = self.cell.get_root_url_for_child()

        export_args = {
            "mode": mode,
            "site_name": self.site_name,
            "token": self.token,
            "root_url": root_url,
            "secure_mode": self.cell.core_cell.secure,
            "workspace_dir": self.workspace_dir,
        }
        return f"{self.__module__}.{self.__class__.__name__}", export_args