# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import threading
import time
from typing import Tuple, Union
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey, SystemVarName
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.fuel.utils.attributes_exportable import ExportMode
from nvflare.fuel.utils.config_service import search_file
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_object_type, check_str
from .pipe import Message, Pipe, Topic
# File name of the root CA certificate that is searched for in the workspace when a cell is
# built in secure mode.
SSL_ROOT_CERT = "rootCA.pem"
# Common prefix of the pipe-specific header keys below and of the cell message channel names
# used by the pipes.
_PREFIX = "cell_pipe."
# Pipe-specific header keys stamped on outgoing cell messages.
_HEADER_MSG_TYPE = _PREFIX + "msg_type"  # carries Message.msg_type
_HEADER_MSG_ID = _PREFIX + "msg_id"  # carries Message.msg_id
_HEADER_REQ_ID = _PREFIX + "req_id"  # carries Message.req_id; only set when the message has one
_HEADER_START_TIME = _PREFIX + "start"  # time.time() taken when the cell message is created
_HEADER_HB_SEQ = _PREFIX + "hb_seq"  # sequence number attached to heartbeat messages
def _cell_fqcn(mode, site_name, token, parent_fqcn):
    """Derive the FQCN of a pipe's cell.

    The FQCN of the cell must be unique in the whole cellnet, so it is derived from the
    combination of mode, site_name, and token:

    - the token is usually used across all sites, so "site_name" differentiates the cell on
      one site from the cells on other sites;
    - the two peer pipes on the same site share the same site_name and token, but are
      differentiated by their modes.

    Args:
        mode: mode of the pipe that owns the cell
        site_name: name of the site
        token: token shared by the two peer pipes
        parent_fqcn: FQCN of the parent that the cell connects to

    Returns:
        the bare cell name when the parent is the root server; otherwise the cell name
        joined under the parent's FQCN.
    """
    leaf_name = f"{site_name}_{token}_{mode}"
    if parent_fqcn != FQCN.ROOT_SERVER:
        return FQCN.join([parent_fqcn, leaf_name])
    return leaf_name
def _to_cell_message(msg: Message, extra=None) -> CellMessage:
    """Wrap a pipe Message into a CellMessage.

    The message type, message id and the current time are always stamped as headers.
    Entries of ``extra`` (if any) are merged next, and the request id is stamped last and
    only when the message has one.

    Args:
        msg: the pipe message to be converted
        extra: optional additional headers

    Returns:
        a CellMessage whose payload is the data of the pipe message
    """
    cell_headers = {
        _HEADER_MSG_TYPE: msg.msg_type,
        _HEADER_MSG_ID: msg.msg_id,
        _HEADER_START_TIME: time.time(),
    }
    if extra:
        cell_headers.update(extra)

    request_id = msg.req_id
    if request_id:
        cell_headers[_HEADER_REQ_ID] = request_id

    return CellMessage(headers=cell_headers, payload=msg.data)
def _from_cell_message(cm: CellMessage) -> Message:
    """Rebuild a pipe Message from a received CellMessage.

    Message id, type and request id are read from the pipe-specific headers, the topic from
    the cell message's topic header, and the data is the cell message's payload.

    Args:
        cm: the received cell message

    Returns:
        the corresponding pipe Message
    """
    read_header = cm.get_header
    fields = {
        "msg_id": read_header(_HEADER_MSG_ID),
        "msg_type": read_header(_HEADER_MSG_TYPE),
        "topic": read_header(MessageHeaderKey.TOPIC),
        "req_id": read_header(_HEADER_REQ_ID),
        "data": cm.payload,
    }
    return Message(**fields)
class _CellInfo:
    """Bookkeeping for one cell and the pipes that use it.

    A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging).
    The cell is started at most once, and the net agent and the cell are shut down when the last
    registered pipe is closed.
    """

    def __init__(self, site_name, cell, net_agent, auth_token, token_signature):
        """
        Args:
            site_name: name of the site that the cell belongs to
            cell: the cell shared by the pipes
            net_agent: the net agent created for the cell
            auth_token: auth token of the site
            token_signature: signature of the auth token
        """
        self.site_name = site_name
        self.cell = cell
        self.auth_token = auth_token
        self.token_signature = token_signature
        self.net_agent = net_agent
        self.started = False
        self.pipes = []
        self.lock = threading.Lock()  # guards "started" and "pipes"

    def start(self):
        """Start the cell if it has not been started yet."""
        with self.lock:
            if not self.started:
                self.cell.start()
                self.started = True

    def add_pipe(self, p):
        """Register a pipe as a user of the cell."""
        with self.lock:
            self.pipes.append(p)

    def close_pipe(self, p):
        """Unregister a pipe; shut down the net agent and the cell once no pipe is left.

        Closing a pipe that is not registered (e.g. closed twice) is a no-op.
        The shutdown is best-effort: a failure is logged and never propagated to the caller, and a
        failure to close the net agent does not prevent the cell from being stopped.
        """
        with self.lock:
            try:
                self.pipes.remove(p)
            except ValueError:
                # unknown or already closed pipe - nothing to do
                return

            if self.pipes:
                # the cell is still used by other pipes
                return

            # all pipes are closed - close cell and agent
            import logging

            logger = logging.getLogger(__name__)
            try:
                self.net_agent.close()
            except Exception as ex:
                logger.warning(f"failed to close net agent of site {self.site_name}: {type(ex).__name__}: {ex}")

            try:
                self.cell.stop()
            except Exception as ex:
                logger.warning(f"failed to stop cell of site {self.site_name}: {type(ex).__name__}: {ex}")
class CellPipe(Pipe):
    """
    CellPipe is an implementation of `Pipe` that utilizes the `Cell` from NVFlare's foundation layer (f3) to
    do the communication.

    Cells are cached per FQCN at class level: pipes whose cells resolve to the same FQCN share one
    `_CellInfo` (and therefore one `Cell`), and must use different cell message channels.
    """

    _lock = threading.Lock()  # guards _cells_info
    _cells_info = {}  # cell FQCN => _CellInfo

    @classmethod
    def _build_cell(cls, site_name, fqcn, parent_conn_props, secure_mode, workspace_dir, logger):
        """Build a cell if necessary.

        The FQCN uniquely determines one cell: if a cell has already been built for the FQCN, it is
        reused. There can be multiple pipes on the same cell.

        Args:
            site_name: name of the site; also the scope from which the auth token and its signature are read
            fqcn: FQCN of the cell
            parent_conn_props: connection properties (URL, connection security) of the parent for this cell
            secure_mode: whether cellnet is in secure mode
            workspace_dir: workspace that contains startup kit for connecting to server. Needed only if secure_mode
            logger: not used by this method

        Returns:
            the _CellInfo of the (new or existing) cell

        Raises:
            ValueError: if in secure mode and the root certificate cannot be found in workspace_dir
        """
        with cls._lock:
            ci = cls._cells_info.get(fqcn)
            if ci:
                return ci

            credentials = {}
            if secure_mode:
                root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir)
                if not root_cert_path:
                    raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {workspace_dir}")
                credentials[DriverParams.CA_CERT.value] = root_cert_path

            conn_sec = parent_conn_props.get(ConnPropKey.CONNECTION_SECURITY)
            if conn_sec:
                credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec

            parent_url = parent_conn_props.get(ConnPropKey.URL)
            if FQCN.get_parent(fqcn):
                # the cell has a parent: connect to the parent
                cell_root = None
                cell_parent_url = parent_url
            else:
                # the cell has no parent: the parent_url is the root of the cellnet
                cell_root = parent_url
                cell_parent_url = None

            cell = Cell(
                fqcn=fqcn,
                root_url=cell_root,
                secure=secure_mode,
                credentials=credentials,
                parent_url=cell_parent_url,
                create_internal_listener=False,
            )

            auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
            token_signature = get_scope_property(site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

            net_agent = NetAgent(cell)
            ci = _CellInfo(site_name, cell, net_agent, auth_token, token_signature)
            cls._cells_info[fqcn] = ci
            set_add_auth_headers_filters(cell, ci.site_name, ci.auth_token, ci.token_signature)
            return ci

    def __init__(
        self,
        mode: Mode,
        site_name: str,
        token: str,
        root_url: str = "",
        secure_mode: bool = True,
        workspace_dir: str = "",
    ):
        """The constructor of the CellPipe.

        Args:
            mode: passive or active mode
            site_name (str): name of the FLARE site
            token (str): unique id to guarantee the uniqueness of cell's FQCN.
            root_url (str): the root url of the cellnet that the pipe's cell will join
            secure_mode (bool): whether connection to the root is secure (TLS)
            workspace_dir (str): the directory that contains startup for joining the cellnet. Required only in secure_mode

        Raises:
            RuntimeError: if the connection properties of the site are missing or malformed, or no
                connection properties can be determined for the root url
            ValueError: if mode is neither 'active' nor 'passive'
        """
        super().__init__(mode)
        self.logger = get_obj_logger(self)
        self.site_name = site_name
        self.token = token
        self.secure_mode = secure_mode
        self.workspace_dir = workspace_dir
        self.root_url = root_url

        # this section is needed by job config to prevent building cell when using SystemVarName arguments
        # NOTE: when returning here, only the attributes assigned above exist on the pipe.
        # TODO: enhance this part
        sysvarname_placeholders = ["{" + varname + "}" for varname in dir(SystemVarName)]
        if any(arg in sysvarname_placeholders for arg in (site_name, token, root_url, secure_mode, workspace_dir)):
            return

        check_str("root_url", root_url)
        check_object_type("secure_mode", secure_mode, bool)
        check_str("token", token)
        check_str("site_name", site_name)
        check_str("workspace_dir", workspace_dir)

        # determine the endpoint for this pipe to connect to
        root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS)
        if root_conn_props:
            # Not in simulator
            if not isinstance(root_conn_props, dict):
                raise RuntimeError(f"expect root_conn_props for {site_name} to be dict but got {type(root_conn_props)}")

            cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS)
            if cp_conn_props is None:
                # fail with a clear message rather than an AttributeError on None below
                raise RuntimeError(f"cannot find cp_conn_props for {site_name}")
            if not isinstance(cp_conn_props, dict):
                raise RuntimeError(f"expect cp_conn_props to be dict but got {type(cp_conn_props)}")

            url_to_conns = {
                root_conn_props.get(ConnPropKey.URL): root_conn_props,
                cp_conn_props.get(ConnPropKey.URL): cp_conn_props,
            }

            relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS)
            if relay_conn_props:
                if not isinstance(relay_conn_props, dict):
                    raise RuntimeError(f"expect relay_conn_props to be dict but got {type(relay_conn_props)}")
                url_to_conns[relay_conn_props.get(ConnPropKey.URL)] = relay_conn_props

            if not root_url:
                # root_url not specified - use CP!
                root_url = cp_conn_props.get(ConnPropKey.URL)
                self.root_url = root_url

            conn_props = url_to_conns.get(self.root_url)
            if not conn_props:
                raise RuntimeError(f"cannot determine conn props for '{root_url}'")
        else:
            # this is running in simulator
            conn_props = {
                ConnPropKey.URL: root_url,
                ConnPropKey.FQCN: FQCN.ROOT_SERVER,
            }

        # Validate the mode BEFORE building the cell: an invalid mode must not leave a cached cell
        # and a registered pipe behind.
        mode = f"{mode}".strip().lower()  # convert to lower case string
        if mode == "active":
            peer_mode = "passive"
        elif mode == "passive":
            peer_mode = "active"
        else:
            raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'")

        parent_fqcn = conn_props.get(ConnPropKey.FQCN)
        fqcn = _cell_fqcn(mode, site_name, token, parent_fqcn)
        self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token, parent_fqcn)

        self.ci = self._build_cell(site_name, fqcn, conn_props, secure_mode, workspace_dir, self.logger)
        self.cell = self.ci.cell
        self.ci.add_pipe(self)

        self.received_msgs = queue.Queue()  # contains raw CellMessage objects
        self.channel = None  # the cellnet message channel
        self.pipe_lock = threading.Lock()  # used to ensure no msg to be sent after closed
        self.closed = False
        self.last_peer_active_time = 0.0
        self.hb_seq = 1

        # When True, every non-heartbeat outgoing cell message gets
        # MessageHeaderKey.PASS_THROUGH=True stamped on it. The peer's
        # Adapter.call() reads this header and builds a per-call FOBS decode
        # context with FOBSContextKey.PASS_THROUGH=True so that tensors arrive
        # as LazyDownloadRef placeholders rather than being downloaded inline.
        #
        # Set by ExProcessClientAPI.init() (subprocess→CJ reverse direction)
        # when CellPipe is in use. Has no effect for FilePipe (which is not a
        # CellPipe and never has this attribute set).
        #
        # Note: the forward direction (Fix 18, CJ→subprocess) does not use this
        # flag; it is implemented via ReservedHeaderKey.PASS_THROUGH stamped on
        # the shareable in SwarmClientController._scatter() and propagated by
        # aux_runner.py — not through pipe.pass_through_on_send.
        self.pass_through_on_send: bool = False

    def _update_peer_active_time(self, msg: CellMessage, ch_name: str, msg_type: str):
        """Message filter: record the current time whenever a message originated from the peer is seen.

        Args:
            msg: the incoming cell message
            ch_name: channel name that the filter was registered with (used for logging only)
            msg_type: "req" or "reply", as given at registration (used for logging only)
        """
        origin = msg.get_header(MessageHeaderKey.ORIGIN)
        if origin == self.peer_fqcn:
            self.logger.debug(f"{time.time()}: _update_peer_active_time: {ch_name=} {msg_type=} {msg.headers}")
            self.last_peer_active_time = time.time()

    def get_last_peer_active_time(self):
        """Return the time when a message from the peer was last seen (0.0 if never)."""
        return self.last_peer_active_time

    def set_cell_cb(self, channel_name: str):
        """Bind the pipe to a cell message channel and register its callbacks on the cell.

        Args:
            channel_name: name of the channel; the actual cell channel is this name prefixed with "cell_pipe."
        """
        # This allows multiple pipes over the same cell (e.g. one channel for tasks, another for metrics),
        # as long as different pipes use different cell message channels
        self.channel = f"{_PREFIX}{channel_name}"
        self.cell.register_request_cb(channel=self.channel, topic="*", cb=self._receive_message)

        # The activity filters are registered for all channels and topics: they identify the peer
        # by the origin of the message.
        self.cell.core_cell.add_incoming_request_filter(
            channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="req"
        )
        self.cell.core_cell.add_incoming_reply_filter(
            channel="*", topic="*", cb=self._update_peer_active_time, ch_name=channel_name, msg_type="reply"
        )
        self.logger.info(f"registered CellPipe request CB for {self.channel}")

    def send(self, msg: Message, timeout=None) -> bool:
        """Sends the specified message to the peer.

        Args:
            msg: the message to be sent
            timeout: if specified, number of secs to wait for the peer to read the message.
                If not specified, wait indefinitely.

        Returns:
            Whether the message is read by the peer. Heartbeats are not waited for and always
            yield True.

        Raises:
            BrokenPipeError: if the pipe is already closed
        """
        with self.pipe_lock:
            if self.closed:
                raise BrokenPipeError("pipe closed")

        # Note: the following code must not be within the lock scope
        # Otherwise only one message can be sent at a time!
        optional = msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]

        if not timeout and msg.topic in [Topic.END, Topic.ABORT]:
            timeout = 5.0  # need to keep the connection for some time; otherwise the msg may not go out

        if msg.topic == Topic.HEARTBEAT:
            # Heartbeats are fire-and-forget; always create a fresh CellMessage so
            # the timestamp header reflects the actual send time.
            extra_headers = {_HEADER_HB_SEQ: self.hb_seq}
            self.hb_seq += 1

            # don't need to wait for reply!
            self.cell.fire_and_forget(
                channel=self.channel,
                topic=msg.topic,
                targets=[self.peer_fqcn],
                message=_to_cell_message(msg, extra_headers),
                optional=optional,
            )
            return True

        # Serialize the message ONCE and cache the result on the Message object.
        #
        # Why: PipeHandler retries the same `msg` object on every failed send.
        # Without caching, each retry calls _to_cell_message(msg) → FOBS encodes
        # msg.data (the Shareable/numpy result) → creates a new ArrayDownloadable
        # transaction in DownloadService. With a 5 GiB model and 14+ retries this
        # produces 70–135 GiB of live transactions simultaneously (OOM crash).
        #
        # How it works: cell.send_request() calls encode_payload() which checks
        # whether the CellMessage's encoding header is already set. On the first
        # call it encodes (FOBS) and mutates request.payload to bytes, then sets the
        # header. On every subsequent call with the same CellMessage object,
        # encode_payload() sees the header is already set and skips re-serialization,
        # so no new ArrayDownloadable is created.
        if not hasattr(msg, "_cached_cell_msg"):
            msg._cached_cell_msg = _to_cell_message(msg)
        request = msg._cached_cell_msg
        request.set_header(MessageHeaderKey.MSG_ROOT_ID, msg.msg_id)

        # For REPLY messages (subprocess→CJ result direction), stamp MSG_ROOT_TTL so
        # via_downloader._create_downloader() keeps the subprocess's DownloadService
        # transaction alive long enough for the server to pull tensors directly from
        # the subprocess.
        #
        # When pass_through_on_send is active (reverse PASS_THROUGH path), use
        # _dl_ttl stamped by FlareAgent._do_submit_result() — this is
        # download_complete_timeout (default 1800s), the actual transfer budget.
        # Mirrors the forward direction where the server uses task.timeout.
        #
        # Fall back to `timeout` (= submit_result_timeout, the CJ-ACK timeout)
        # for non-PASS_THROUGH REPLY messages where no tensor transfer occurs.
        if msg.msg_type == Message.REPLY:
            dl_ttl = getattr(msg, "_dl_ttl", None) if self.pass_through_on_send else None
            ttl = dl_ttl if dl_ttl and dl_ttl > 0 else timeout
            if ttl is not None and ttl > 0:
                request.set_header(MessageHeaderKey.MSG_ROOT_TTL, float(ttl))

        # Stamp PASS_THROUGH on every outgoing task/result message when the
        # caller has opted in. Adapter.call() on the receiving side reads this
        # header and builds a per-call FOBS decode context with
        # FOBSContextKey.PASS_THROUGH=True so that large tensors arrive as
        # LazyDownloadRef placeholders rather than being downloaded inline.
        # Heartbeat messages do not carry model data and always skip this path
        # (they use fire_and_forget above).
        if self.pass_through_on_send:
            request.set_header(MessageHeaderKey.PASS_THROUGH, True)

        reply = self.cell.send_request(
            channel=self.channel,
            topic=msg.topic,
            target=self.peer_fqcn,
            request=request,
            timeout=timeout,
            optional=optional,
        )
        if not reply:
            return False

        rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
        if rc == ReturnCode.OK:
            return True

        err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}"
        if optional:
            # failures of optional messages are expected (e.g. the peer is already gone)
            self.logger.debug(err)
        else:
            self.logger.error(err)
        return False

    def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]:
        """Request callback of the pipe's channel: queue the message and acknowledge it.

        Args:
            request: the cell message received from the peer

        Returns:
            a reply with ReturnCode.OK

        Raises:
            RuntimeError: if the message does not originate from the peer of this pipe
        """
        # Return the pipe-level ACK as quickly as possible.
        #
        # The expensive work (FOBS decode / tensor download) has ALREADY been done
        # by Adapter.call() in cell.py BEFORE this callback is invoked. With
        # reverse PASS_THROUGH enabled on the pipe cell, that decode is cheap
        # (creates LazyDownloadRef objects rather than downloading).
        #
        # We queue the raw CellMessage rather than converting it to a Message here.
        # The conversion (_from_cell_message) is deferred to receive() time so that
        # this callback – and therefore the cell-level ACK path – performs the
        # absolute minimum work before returning ReturnCode.OK to the sender.
        sender = request.get_header(MessageHeaderKey.ORIGIN)
        topic = request.get_header(MessageHeaderKey.TOPIC)
        self.logger.debug(f"got msg from peer {sender}: {topic}")

        if self.peer_fqcn != sender:
            raise RuntimeError(f"peer FQCN mismatch: expect {self.peer_fqcn} but got {sender}")

        self.received_msgs.put_nowait(request)
        return make_reply(ReturnCode.OK)

    def receive(self, timeout=None) -> Union[None, Message]:
        """Get the next received message.

        Args:
            timeout: if specified (and not 0), number of secs to wait for a message;
                otherwise return immediately.

        Returns:
            the next message, or None if no message is available (in time)
        """
        try:
            if timeout:
                cm = self.received_msgs.get(block=True, timeout=timeout)
            else:
                cm = self.received_msgs.get_nowait()
        except queue.Empty:
            return None

        # Convert the raw CellMessage to a Message at dequeue time.
        return _from_cell_message(cm)

    def clear(self):
        """Discard all messages that have been received but not yet read."""
        # Drain until the queue reports Empty. Checking empty() first is not reliable: another
        # consumer may take the last item between the check and the get.
        while True:
            try:
                self.received_msgs.get_nowait()
            except queue.Empty:
                return

    def release_send_cache(self, msg: Message):
        """Clear the cached CellMessage that was attached to *msg* by send().

        The cache is created on the first send() call so that retries reuse the
        already-serialized CellMessage.  Once the retry loop exits this
        cache is no longer needed.  Dropping it allows the encoded payload bytes
        and any lingering references to be reclaimed by GC promptly, rather than
        waiting for the Message object itself to go out of scope.
        """
        msg.__dict__.pop("_cached_cell_msg", None)

    def can_resend(self) -> bool:
        """Tell the caller that a message may be sent again with this pipe (always True)."""
        return True

    def open(self, name: str):
        """Open the pipe: start the cell (if not started yet) and register the callbacks for the channel.

        Args:
            name: name of the channel to be used by this pipe

        Raises:
            BrokenPipeError: if the pipe is already closed
        """
        with self.pipe_lock:
            if self.closed:
                raise BrokenPipeError("pipe already closed")
            self.ci.start()
            self.set_cell_cb(name)

    def close(self):
        """Close the pipe. Closing an already closed pipe is a no-op.

        The cell is released through its _CellInfo, which shuts the cell down once its last pipe is closed.
        """
        # NOTE(review): the _CellInfo stays in CellPipe._cells_info after its cell is stopped, so a
        # later pipe that resolves to the same FQCN gets the stopped cell - confirm this is intended.
        with self.pipe_lock:
            if self.closed:
                return
            self.ci.close_pipe(self)
            self.closed = True

    def export(self, export_mode: str) -> Tuple[str, dict]:
        """Export the class path and constructor args for creating this pipe or its peer.

        Args:
            export_mode: ExportMode.SELF to export the args of this pipe; for any other value the
                args of the peer pipe (i.e. with the opposite mode) are exported.

        Returns:
            a tuple of (fully qualified class name, dict of constructor args)
        """
        if export_mode == ExportMode.SELF:
            mode = self.mode
        else:
            mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE

        # A pipe constructed with system variable placeholders has no cell: fall back to the
        # secure_mode value given to the constructor.
        cell = getattr(self, "cell", None)
        secure_mode = cell.core_cell.secure if cell is not None else self.secure_mode

        export_args = {
            "mode": mode,
            "site_name": self.site_name,
            "token": self.token,
            "root_url": self.root_url,
            "secure_mode": secure_mode,
            "workspace_dir": self.workspace_dir,
        }
        return f"{self.__module__}.{self.__class__.__name__}", export_args