# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import threading
import time
from typing import Tuple, Union
from nvflare.apis.fl_constant import CellMessageAuthHeaderKey, FLMetaKey, SecureTrainConst, SystemVarName
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.cellnet.utils import make_reply
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.attributes_exportable import ExportMode
from nvflare.fuel.utils.config_service import search_file
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_object_type, check_str
from .pipe import Message, Pipe, Topic
# File name of the root CA certificate searched for in the workspace when running in secure mode.
SSL_ROOT_CERT = "rootCA.pem"

# Common prefix of all names owned by the cell pipe (message headers and cell message channels).
_PREFIX = "cell_pipe."

# Headers used to carry the pipe message's metadata inside a cell message.
_HEADER_MSG_TYPE = f"{_PREFIX}msg_type"  # type of the pipe message
_HEADER_MSG_ID = f"{_PREFIX}msg_id"  # id of the pipe message
_HEADER_REQ_ID = f"{_PREFIX}req_id"  # id of the request that the pipe message refers to (only set if present)
_HEADER_START_TIME = f"{_PREFIX}start"  # time at which the pipe message was converted to a cell message
_HEADER_HB_SEQ = f"{_PREFIX}hb_seq"  # sequence number of a heartbeat message
def _cell_fqcn(mode, site_name, token):
    """Derive the FQCN of the cell used by a pipe.

    The FQCN of a cell must be unique in the whole cellnet, so it is composed of site_name, token and mode:
    the token is usually shared by all sites, hence the site_name tells the cells of different sites apart;
    the two peer pipes on one site share site_name and token, and are told apart by their modes.

    Args:
        mode: mode of the pipe
        site_name: name of the site
        token: the token

    Returns:
        the FQCN in the form "<site_name>_<token>_<mode>"
    """
    return "_".join(f"{part}" for part in (site_name, token, mode))
def _to_cell_message(msg: Message, extra=None) -> CellMessage:
    """Wrap a pipe message into a cell message.

    Args:
        msg: the pipe message to be converted
        extra: optional additional headers; they are applied on top of the standard headers

    Returns:
        a cell message whose payload is the pipe message's data and whose headers carry the message's
        type, id, the current time and, if the message has one, the request id
    """
    cell_headers = {
        _HEADER_MSG_TYPE: msg.msg_type,
        _HEADER_MSG_ID: msg.msg_id,
        _HEADER_START_TIME: time.time(),
    }
    if extra:
        cell_headers.update(extra)

    # the request id is set last, so it cannot be overwritten by the extra headers
    req_id = msg.req_id
    if req_id:
        cell_headers[_HEADER_REQ_ID] = req_id

    return CellMessage(headers=cell_headers, payload=msg.data)
def _from_cell_message(cm: CellMessage) -> Message:
    """Rebuild a pipe message from the headers and the payload of a cell message.

    Args:
        cm: the cell message received from the peer

    Returns:
        the pipe message; the cell message's payload becomes the message's data
    """
    header = cm.get_header
    msg_id = header(_HEADER_MSG_ID)
    msg_type = header(_HEADER_MSG_TYPE)
    topic = header(MessageHeaderKey.TOPIC)
    req_id = header(_HEADER_REQ_ID)
    return Message(msg_id=msg_id, msg_type=msg_type, topic=topic, req_id=req_id, data=cm.payload)
class _CellInfo:
    """Bookkeeping for a cell and the pipes that use it.

    A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging).
    The cell is started at most once, and it is shut down together with its net agent when the last
    registered pipe is closed.
    """

    def __init__(self, cell, net_agent):
        """
        Args:
            cell: the cell shared by the pipes
            net_agent: the net agent created for the cell
        """
        self.cell = cell
        self.net_agent = net_agent
        self.started = False
        self.pipes = []
        self.lock = threading.Lock()  # protects "started" and "pipes"

    def start(self):
        """Start the cell, unless it has been started already."""
        with self.lock:
            if not self.started:
                self.cell.start()
                self.started = True

    def add_pipe(self, p):
        """Register a pipe as a user of the cell.

        Args:
            p: the pipe to be registered
        """
        with self.lock:
            self.pipes.append(p)

    def close_pipe(self, p):
        """Unregister a pipe; shut down the net agent and the cell when no pipe is left.

        Closing a pipe that is not registered (e.g. closing it a second time) is a no-op.
        The shutdown is best-effort: a failure of one step is logged and does not prevent the other step,
        and no exception is propagated to the caller.

        Args:
            p: the pipe to be unregistered
        """
        import logging

        with self.lock:
            try:
                self.pipes.remove(p)
            except ValueError:
                # not registered (any more) - nothing to do
                return

            if self.pipes:
                # the cell is still used by other pipes
                return

            # all pipes are closed - close agent and cell
            logger = logging.getLogger(__name__)
            try:
                self.net_agent.close()
            except Exception as ex:
                logger.warning(f"error closing net agent: {type(ex).__name__}: {ex}")

            # must be attempted even if the agent could not be closed; otherwise the cell is never stopped
            try:
                self.cell.stop()
            except Exception as ex:
                logger.warning(f"error stopping cell: {type(ex).__name__}: {ex}")
[docs]
class CellPipe(Pipe):
    """
    A `Pipe` implementation that communicates through a `Cell` of NVFlare's foundation layer (f3).

    The attributes defined here are class-level, hence shared by all CellPipe objects.
    """

    # guards the creation of cells and the access to the cell cache
    _lock = threading.Lock()

    # cache of the cells already built: "<root_url>.<site_name>.<token>" => _CellInfo
    _cells_info = {}

    # auth token and its signature; looked up when the first outgoing message is filtered
    _auth_token = None
    _token_signature = None

    # site name given to the most recent call of _build_cell
    _site_name = None
@classmethod
def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir):
"""Return the cell info for (root_url, site_name, token), building the cell if necessary.

The combination of (root_url, site_name, token) uniquely determine one cell.
There can be multiple pipes on the same cell: the cell info is cached at class level and is returned
as is when it already exists. A newly built cell is not started here.

Args:
    mode: mode (passive or active) of the pipe. Only used to derive the FQCN of a newly built cell
    root_url: root url of the cell net
    site_name: name of the site. Also recorded at class level on every call
    token: the unique token
    secure_mode: whether cellnet is in secure mode
    workspace_dir: workspace that contains startup kit for connecting to server. Needed only if secure_mode

Returns:
    the _CellInfo of the (existing or newly built) cell

Raises:
    ValueError: if the root certificate cannot be found in workspace_dir in secure mode
"""
# the class-level lock serializes lookup and creation, so one key never gets two cells
with cls._lock:
# class-level: the value of the most recent call is what _add_auth_headers will use
cls._site_name = site_name
cell_key = f"{root_url}.{site_name}.{token}"
ci = cls._cells_info.get(cell_key)
if not ci:
credentials = {}
if secure_mode:
# secure mode needs the root CA cert from the workspace
root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir)
if not root_cert_path:
raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {workspace_dir}")
credentials = {
DriverParams.CA_CERT.value: root_cert_path,
}
# NOTE(review): the indentation of this listing is lost - it cannot be told from here whether the
# connection security lookup below is done in secure mode only or always; confirm against the repo.
conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY)
if conn_sec:
credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec
cell = Cell(
fqcn=_cell_fqcn(mode, site_name, token),
root_url=root_url,
secure=secure_mode,
credentials=credentials,
create_internal_listener=False,
)
# set filter to add additional auth headers (to all outgoing replies and requests)
cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers)
cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers)
net_agent = NetAgent(cell)
ci = _CellInfo(cell, net_agent)
# cache the cell info so that other pipes with the same key share the cell
cls._cells_info[cell_key] = ci
return ci
@classmethod
def _add_auth_headers(cls, message: CellMessage):
    """Outgoing message filter that adds the auth headers (client name, token, signature) to a message.

    The token and its signature are looked up from the site's scope properties the first time they are
    needed (with "NA" as default) and are kept at class level afterwards.

    Args:
        message: the outgoing cell message
    """
    if not cls._auth_token:
        site = cls._site_name
        cls._auth_token = get_scope_property(scope_name=site, key=FLMetaKey.AUTH_TOKEN, default="NA")
        cls._token_signature = get_scope_property(site, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

    auth_headers = (
        (CellMessageAuthHeaderKey.CLIENT_NAME, cls._site_name),
        (CellMessageAuthHeaderKey.TOKEN, cls._auth_token),
        (CellMessageAuthHeaderKey.TOKEN_SIGNATURE, cls._token_signature),
    )
    for key, value in auth_headers:
        message.set_header(key, value)
def __init__(
    self,
    mode: Mode,
    site_name: str,
    token: str,
    root_url: str = "",
    secure_mode: bool = True,
    workspace_dir: str = "",
):
    """The constructor of the CellPipe.

    If any of site_name, token, root_url, secure_mode or workspace_dir is a system variable placeholder
    (e.g. "{SITE_NAME}"), only the arguments are recorded: no cell is built and the pipe's runtime
    state is not set up.

    Args:
        mode: passive or active mode
        site_name (str): name of the FLARE site
        token (str): unique id to guarantee the uniqueness of cell's FQCN.
        root_url (str): the root url of the cellnet that the pipe's cell will join
        secure_mode (bool): whether connection to the root is secure (TLS)
        workspace_dir (str): the directory that contains startup for joining the cellnet. Required only in secure_mode

    Raises:
        ValueError: if the mode is neither 'active' nor 'passive' (case-insensitive), or if the root
            certificate cannot be found in secure mode
    """
    super().__init__(mode)
    self.logger = get_obj_logger(self)
    self.site_name = site_name
    self.token = token
    self.root_url = root_url
    self.secure_mode = secure_mode
    self.workspace_dir = workspace_dir

    # this section is needed by job config to prevent building cell when using SystemVarName arguments
    # TODO: enhance this part
    sysvarname_placeholders = ["{" + varname + "}" for varname in dir(SystemVarName)]
    if any(arg in sysvarname_placeholders for arg in (site_name, token, root_url, secure_mode, workspace_dir)):
        return

    check_str("root_url", root_url)
    check_object_type("secure_mode", secure_mode, bool)
    check_str("token", token)
    check_str("site_name", site_name)
    check_str("workspace_dir", workspace_dir)

    # Validate the mode BEFORE building the cell: building registers the cell in the class-level cache,
    # and a failure afterwards would leave an orphaned cell (and pipe registration) behind.
    mode = f"{mode}".strip().lower()  # convert to lower case string
    peer_mode = {"active": "passive", "passive": "active"}.get(mode)
    if peer_mode is None:
        raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'")

    self.ci = self._build_cell(mode, root_url, site_name, token, secure_mode, workspace_dir)
    self.cell = self.ci.cell
    self.ci.add_pipe(self)

    # the peer pipe has the same site_name and token, but the opposite mode
    self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token)
    self.received_msgs = queue.Queue()  # contains Message(s), not CellMessage(s)!
    self.channel = None  # the cellnet message channel
    self.pipe_lock = threading.Lock()  # used to ensure no msg to be sent after closed
    self.closed = False
    self.last_peer_active_time = 0.0
    self.hb_seq = 1
def _update_peer_active_time(self, msg: CellMessage, ch_name: str, msg_type: str):
    """Incoming message filter that records when the peer was last heard from.

    Messages that do not originate from the peer's cell are ignored.

    Args:
        msg: the incoming cell message
        ch_name: name of the pipe's channel (only used for logging)
        msg_type: kind of the incoming message, "req" or "reply" (only used for logging)
    """
    from_peer = msg.get_header(MessageHeaderKey.ORIGIN) == self.peer_fqcn
    if not from_peer:
        return

    self.logger.debug(f"{time.time()}: _update_peer_active_time: {ch_name=} {msg_type=} {msg.headers}")
    self.last_peer_active_time = time.time()
[docs]
def get_last_peer_active_time(self):
    """Get the time at which a message from the peer was last seen.

    Returns:
        the recorded time (seconds, as returned by time.time()); 0.0 if nothing was received yet
    """
    last_seen = self.last_peer_active_time
    return last_seen
[docs]
def set_cell_cb(self, channel_name: str):
    """Set up the pipe's cell message channel and register the callbacks on the cell.

    This allows multiple pipes over the same cell (e.g. one channel for tasks, another for metrics),
    as long as different pipes use different cell message channels.

    Args:
        channel_name: name of the channel; the cell message channel is this name with the pipe prefix
    """
    self.channel = f"{_PREFIX}{channel_name}"
    self.cell.register_request_cb(channel=self.channel, topic="*", cb=self._receive_message)

    # any incoming request or reply, on any channel, is used to track the peer's activity
    core_cell = self.cell.core_cell
    filter_args = {"channel": "*", "topic": "*", "cb": self._update_peer_active_time, "ch_name": channel_name}
    core_cell.add_incoming_request_filter(**filter_args, msg_type="req")
    core_cell.add_incoming_reply_filter(**filter_args, msg_type="reply")

    self.logger.info(f"registered CellPipe request CB for {self.channel}")
[docs]
def send(self, msg: Message, timeout=None) -> bool:
    """Sends the specified message to the peer.

    Heartbeats are sent fire-and-forget: no reply is waited for and True is returned right away.
    All other messages are sent as requests, with the timeout passed on to the cell.

    Args:
        msg: the message to be sent
        timeout: if specified, number of secs to wait for the peer to read the message.
            If not specified, 5 secs are used for the END and ABORT topics.

    Returns:
        Whether the message is read by the peer.

    Raises:
        BrokenPipeError: if the pipe has been closed
    """
    with self.pipe_lock:
        if self.closed:
            raise BrokenPipeError("pipe closed")

    # Note: the following code must not be within the lock scope
    # Otherwise only one message can be sent at a time!
    optional = msg.topic in [Topic.END, Topic.ABORT, Topic.HEARTBEAT]

    if msg.topic == Topic.HEARTBEAT:
        # the sequence number is for debugging purpose
        hb_headers = {_HEADER_HB_SEQ: self.hb_seq}
        self.hb_seq += 1

        # don't need to wait for reply!
        self.cell.fire_and_forget(
            channel=self.channel,
            topic=msg.topic,
            targets=[self.peer_fqcn],
            message=_to_cell_message(msg, hb_headers),
            optional=optional,
        )
        return True

    if not timeout and msg.topic in [Topic.END, Topic.ABORT]:
        timeout = 5.0  # need to keep the connection for some time; otherwise the msg may not go out

    reply = self.cell.send_request(
        channel=self.channel,
        topic=msg.topic,
        target=self.peer_fqcn,
        request=_to_cell_message(msg),
        timeout=timeout,
        optional=optional,
    )
    if not reply:
        return False

    rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
    if rc == ReturnCode.OK:
        return True

    # failures of optional messages are not worth an error
    err = f"failed to send '{msg.topic}' to '{self.peer_fqcn}' in channel '{self.channel}': {rc}"
    report = self.logger.debug if optional else self.logger.error
    report(err)
    return False
def _receive_message(self, request: CellMessage) -> Union[None, CellMessage]:
    """Request callback of the pipe's channel: queue the received message for `receive`.

    Args:
        request: the cell message received

    Returns:
        an OK reply

    Raises:
        RuntimeError: if the message does not originate from the peer's cell
    """
    origin = request.get_header(MessageHeaderKey.ORIGIN)
    msg_topic = request.get_header(MessageHeaderKey.TOPIC)
    self.logger.debug(f"got msg from peer {origin}: {msg_topic}")

    if self.peer_fqcn != origin:
        raise RuntimeError(f"peer FQCN mismatch: expect {self.peer_fqcn} but got {origin}")

    self.received_msgs.put_nowait(_from_cell_message(request))
    return make_reply(ReturnCode.OK)
[docs]
def receive(self, timeout=None) -> Union[None, Message]:
    """Get the next message received from the peer.

    Args:
        timeout: number of secs to wait for a message. If not specified (or 0), do not wait.

    Returns:
        the oldest message not yet read; None if no message is available (in time)
    """
    inbox = self.received_msgs
    try:
        return inbox.get(block=True, timeout=timeout) if timeout else inbox.get_nowait()
    except queue.Empty:
        return None
[docs]
def clear(self):
    """Discard all received messages that have not been read yet.

    The queue is drained until it reports empty, instead of checking `empty()` before each read:
    another thread may take the last message between the check and the read, which would make
    the read raise queue.Empty.
    """
    try:
        while True:
            self.received_msgs.get_nowait()
    except queue.Empty:
        # nothing left
        pass
[docs]
def can_resend(self) -> bool:
    """Whether a message can be sent again through this pipe.

    Returns:
        always True for the CellPipe
    """
    resendable = True
    return resendable
[docs]
def open(self, name: str):
    """Open the pipe: start the cell (if not started yet) and set up the pipe's channel.

    Args:
        name: name of the pipe; used as the name of the cell message channel

    Raises:
        BrokenPipeError: if the pipe has been closed
    """
    with self.pipe_lock:
        already_closed = self.closed
        if already_closed:
            raise BrokenPipeError("pipe already closed")

        # the cell is shared: starting it again is a no-op
        self.ci.start()
        self.set_cell_cb(name)
[docs]
def close(self):
    """Close the pipe and unregister it from its cell. Closing a closed pipe is a no-op."""
    with self.pipe_lock:
        if not self.closed:
            # the cell itself is shut down only when its last pipe is closed
            self.ci.close_pipe(self)
            self.closed = True
[docs]
def export(self, export_mode: str) -> Tuple[str, dict]:
    """Export the class path and the constructor args needed to create this pipe or its peer.

    Args:
        export_mode: ExportMode.SELF to export the args of this pipe; any other value exports the args
            of the peer pipe, i.e. with the opposite mode and the root url for a child of this pipe's cell

    Returns:
        a tuple of (class path of the pipe, dict of constructor args)
    """
    exporting_self = export_mode == ExportMode.SELF
    if exporting_self:
        mode, root_url = self.mode, self.root_url
    else:
        mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE
        root_url = self.cell.get_root_url_for_child()

    class_path = f"{self.__module__}.{self.__class__.__name__}"
    export_args = {
        "mode": mode,
        "site_name": self.site_name,
        "token": self.token,
        "root_url": root_url,
        "secure_mode": self.cell.core_cell.secure,
        "workspace_dir": self.workspace_dir,
    }
    return class_path, export_args