# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import random
import threading
import time
import uuid
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager
from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager
from nvflare.fuel.f3.cellnet.defs import (
AbortRun,
AuthenticationError,
CellPropertyKey,
InvalidRequest,
InvalidSession,
MessageHeaderKey,
MessagePropKey,
MessageType,
ReturnCode,
ReturnReason,
ServiceUnavailable,
)
from nvflare.fuel.f3.cellnet.fqcn import FQCN, FqcnInfo, same_family
from nvflare.fuel.f3.cellnet.registry import Callback, Registry
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, format_log_message, make_reply
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import Communicator, MessageReceiver
from nvflare.fuel.f3.connection import Connection
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.mpm import MainProcessMonitor
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.security.logging import secure_format_exception, secure_format_traceback
_CHANNEL = "cellnet.channel"
_TOPIC_BULK = "bulk"
_TOPIC_BYE = "bye"
_SM_CHANNEL = "credential_manager"
_SM_TOPIC = "key_exchange"
_ONE_MB = 1024 * 1024
class TargetMessage:
    """Bundle of a message plus its routing info (destination target, channel, topic).

    Constructing a TargetMessage also stamps the routing info into the message's
    headers so the receiving side can dispatch it.
    """

    def __init__(self, target: str, channel: str, topic: str, message: Message):
        self.target = target
        self.channel = channel
        self.topic = topic
        self.message = message
        routing_headers = {
            MessageHeaderKey.TOPIC: topic,
            MessageHeaderKey.CHANNEL: channel,
            MessageHeaderKey.DESTINATION: target,
        }
        message.add_headers(routing_headers)

    def to_dict(self):
        """Serialize to a plain dict (used when packing messages into a bulk message)."""
        msg_dict = {"headers": dict(self.message.headers), "payload": self.message.payload}
        return {
            "target": self.target,
            "channel": self.channel,
            "topic": self.topic,
            "message": msg_dict,
        }

    @staticmethod
    def from_dict(d: dict):
        """Rebuild a TargetMessage from a dict produced by to_dict()."""
        msg_dict = d.get("message")
        msg = Message(headers=msg_dict.get("headers"), payload=msg_dict.get("payload"))
        return TargetMessage(
            target=d.get("target"),
            channel=d.get("channel"),
            topic=d.get("topic"),
            message=msg,
        )
class CellAgent:
    """A CellAgent represents a cell in another cell."""

    def __init__(self, fqcn: str, endpoint: Endpoint):
        """
        Args:
            fqcn: FQCN of the cell represented
            endpoint: the communication endpoint of that cell
        """
        err = FQCN.validate(fqcn)
        if err:
            raise ValueError(f"Invalid FQCN '{fqcn}': {err}")
        # keep parsed FQCN info rather than the raw string
        self.info = FqcnInfo(FQCN.normalize(fqcn))
        self.endpoint = endpoint

    def get_fqcn(self):
        """Return the normalized FQCN of the represented cell."""
        return self.info.fqcn
class _Waiter(threading.Event):
def __init__(self, targets: List[str]):
super().__init__()
self.targets = [x for x in targets]
self.reply_time = {} # target_id => reply recv timestamp
self.send_time = time.time()
self.id = str(uuid.uuid4())
self.received_replies = {}
def log_messaging_error(
    logger, log_text: str, cell, msg: Union[Message, None], log_except=False, log_level=logging.ERROR
):
    """Log a messaging problem for a cell.

    The level is demoted to DEBUG when the message is marked OPTIONAL or when the
    main process is shutting down, to avoid noisy logs for best-effort traffic.
    """
    if msg:
        demote = msg.get_header(MessageHeaderKey.OPTIONAL, default=False)
        log_text = format_log_message(cell.get_fqcn(), msg, log_text)
    else:
        demote = False
        log_text = f"{cell.get_fqcn()}: {log_text}"
    if MainProcessMonitor.is_stopping() or demote:
        log_level = logging.DEBUG
    logger.log(log_level, log_text)
    if log_except:
        logger.log(log_level, secure_format_traceback())
class _BulkSender:
    """Accumulates queued messages for one target and flushes them as a single bulk message."""

    def __init__(self, cell, target: str, max_queue_size, secure=False):
        self.cell = cell
        self.target = target
        self.max_queue_size = max_queue_size  # max messages per bulk send
        self.secure = secure
        self.messages = []  # pending TargetMessages, protected by self.lock
        self.last_send_time = 0
        self.lock = threading.Lock()
        self.logger = logging.getLogger(self.__class__.__name__)

    def queue_message(self, channel: str, topic: str, message: Message):
        """Encode (and optionally encrypt) the message, then append it to the queue."""
        if self.secure:
            message.add_headers({MessageHeaderKey.SECURE: True})
        encode_payload(message)
        self.cell.encrypt_payload(message)
        tm = TargetMessage(target=self.target, channel=channel, topic=topic, message=message)
        with self.lock:
            self.messages.append(tm)
        self.logger.debug(f"{self.cell.get_fqcn()}: bulk sender {self.target} queue size {len(self.messages)}")

    def send(self):
        """Flush up to max_queue_size queued messages to the target as one bulk message."""
        with self.lock:
            if not self.messages:
                return
            if len(self.messages) <= self.max_queue_size:
                messages_to_send, self.messages = self.messages, []
            else:
                messages_to_send = self.messages[: self.max_queue_size]
                self.messages = self.messages[self.max_queue_size :]
        self.logger.debug(
            f"{self.cell.get_fqcn()}: bulk sender {self.target} sending bulk size {len(messages_to_send)}"
        )
        bulk_msg = Message(None, [m.to_dict() for m in messages_to_send])
        send_errs = self.cell.fire_and_forget(
            channel=_CHANNEL, topic=_TOPIC_BULK, targets=[self.target], message=bulk_msg
        )
        if send_errs[self.target]:
            log_messaging_error(
                logger=self.logger,
                msg=bulk_msg,
                log_text=f"failed to send bulk message: {send_errs[self.target]}",
                cell=self.cell,
            )
        else:
            self.logger.debug(f"{self.cell.get_fqcn()}: sent bulk messages ({len(messages_to_send)}) to {self.target}")
        self.last_send_time = time.time()
def _validate_url(url: str) -> bool:
if not isinstance(url, str) or not url:
return False
result = urlparse(url)
if not result.scheme or not result.netloc:
return False
return True
class _CounterName:
    """Counter names used in the sent/received message counter stats pools."""

    LATE = "late"
    SENT = "sent"
    RETURN = "return"
    FORWARD = "forward"
    RECEIVED = "received"
    REPLIED = "replied"
    REPLY_NONE = "no_reply:none"
    NO_REPLY_LATE = "no_reply:late"
    REPLY_NOT_EXPECTED = "no_reply_expected"
    REQ_FILTER_ERROR = "req_filter_error"
    REP_FILTER_ERROR = "rep_filter_error"
class CertificateExchanger:
    """This class handles cert-exchange messages"""

    def __init__(self, core_cell, credential_manager: CredentialManager):
        self.core_cell = core_cell
        self.credential_manager = credential_manager
        # serve incoming cert-exchange requests from peers
        self.core_cell.register_request_cb(_SM_CHANNEL, _SM_TOPIC, self._handle_cert_request)

    def get_certificate(self, target: str) -> bytes:
        """Return the target's certificate, fetching and caching it on first use."""
        cached = self.credential_manager.get_certificate(target)
        if cached:
            return cached
        cert = self.exchange_certificate(target)
        self.credential_manager.save_certificate(target, cert)
        return cert

    def exchange_certificate(self, target: str) -> bytes:
        """Perform a cert exchange with the root of the target's family and return the cert."""
        root = FQCN.get_root(target)
        req = self.credential_manager.create_request(root)
        response = self.core_cell.send_request(_SM_CHANNEL, _SM_TOPIC, root, Message(None, req))
        reply = response.payload
        if not reply:
            error_code = response.get_header(MessageHeaderKey.RETURN_CODE)
            raise RuntimeError(f"Cert exchanged to {root} failed: {error_code}")
        return self.credential_manager.process_response(reply)

    def _handle_cert_request(self, request: Message):
        # peer asked for our certificate; let the credential manager build the reply
        reply = self.credential_manager.process_request(request.payload)
        return Message(None, reply)
class CoreCell(MessageReceiver, EndpointMonitor):
    # App id used when registering this cell as a message receiver with the Communicator.
    APP_ID = 1

    # Error-type labels used when reporting send failures.
    ERR_TYPE_MSG_TOO_BIG = "MsgTooBig"
    ERR_TYPE_COMM = "CommErr"

    # Registry of all cells created in this process: cell name => Cell.
    ALL_CELLS = {}  # cell name => Cell

    # Sub-cell relationship types.
    SUB_TYPE_CHILD = 1
    SUB_TYPE_CLIENT = 2
    SUB_TYPE_NONE = 0
    def __init__(
        self,
        fqcn: str,
        root_url: str,
        secure: bool,
        credentials: dict,
        create_internal_listener: bool = False,
        parent_url: str = None,
        max_timeout=3600,
        bulk_check_interval=0.5,
        bulk_process_interval=0.5,
        max_bulk_size=100,
    ):
        """
        Args:
            fqcn: the Cell's FQCN (Fully Qualified Cell Name)
            credentials: credentials for secure connections
            root_url: the URL for backbone external connection
            secure: secure mode or not
            max_timeout: default timeout for send_and_receive
            create_internal_listener: whether to create an internal listener for child cells
            parent_url: url for connecting to parent cell

        FQCN is the names of all ancestor, concatenated with dots.

        .. note::

            Internal listener is automatically created for root cells.

        .. code-block:: text

            Example:
                server.J12345       (the cell for job J12345 on the server)
                server              (the root cell of server)
                nih_1.J12345        (the cell for job J12345 on client_1's site)
                client_1.J12345.R0  (the cell for rank R0 of J12345 on client_1 site)
                client_1            (the root cell of client_1)
        """
        # only one cell per FQCN is allowed in a process
        if fqcn in self.ALL_CELLS:
            raise ValueError(f"there is already a cell named {fqcn}")
        comm_configurator = CommConfigurator()
        self._name = self.__class__.__name__
        self.logger = logging.getLogger(self._name)
        self.max_msg_size = comm_configurator.get_max_message_size()
        self.comm_configurator = comm_configurator
        err = FQCN.validate(fqcn)
        if err:
            raise ValueError(f"Invalid FQCN '{fqcn}': {err}")
        self.my_info = FqcnInfo(FQCN.normalize(fqcn))
        self.secure = secure
        self.logger.debug(f"{self.my_info.fqcn}: max_msg_size={self.max_msg_size}")
        if not root_url and not parent_url:
            raise ValueError(f"{self.my_info.fqcn}: neither root_url nor parent_url is provided")
        if self.my_info.is_root and self.my_info.is_on_server:
            # server root: root_url (possibly a list) must be valid; normalize to a list
            if not root_url:
                raise ValueError(f"{self.my_info.fqcn}: root_url is required for server-side cells but not provided")
            if isinstance(root_url, list):
                for url in root_url:
                    if not _validate_url(url):
                        raise ValueError(f"{self.my_info.fqcn}: invalid Root URL '{url}'")
            else:
                if not _validate_url(root_url):
                    raise ValueError(f"{self.my_info.fqcn}: invalid Root URL '{root_url}'")
                root_url = [root_url]
        elif root_url:
            if isinstance(root_url, list):
                # multiple urls are available - randomly pick one
                root_url = random.choice(root_url)
                self.logger.info(f"{self.my_info.fqcn}: use Root URL {root_url}")
            if not _validate_url(root_url):
                raise ValueError(f"{self.my_info.fqcn}: invalid Root URL '{root_url}'")
        self.root_url = root_url
        self.create_internal_listener = create_internal_listener
        self.parent_url = parent_url
        # bulk-messaging state (senders per target, background checker/processor threads)
        self.bulk_check_interval = bulk_check_interval
        self.max_bulk_size = max_bulk_size
        self.bulk_checker = None
        self.bulk_senders = {}
        self.bulk_process_interval = bulk_process_interval
        self.bulk_messages = []
        self.bulk_processor = None
        self.bulk_lock = threading.Lock()
        self.bulk_msg_lock = threading.Lock()
        self.agents = {}  # cell_fqcn => CellAgent
        self.agent_lock = threading.Lock()
        self.logger.debug(f"Creating Cell: {self.my_info.fqcn}")
        if credentials:
            enhance_credential_info(credentials)
        # the local endpoint representing this cell on the communicator
        ep = Endpoint(
            name=fqcn,
            conn_props=credentials,
            properties={
                CellPropertyKey.FQCN: self.my_info.fqcn,
            },
        )
        self.communicator = Communicator(local_endpoint=ep)
        self.endpoint = ep
        self.connector_manager = ConnectorManager(
            communicator=self.communicator, secure=secure, comm_configurator=comm_configurator
        )
        self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self)
        self.communicator.register_monitor(monitor=self)
        # callback registries for requests, filters and error handlers
        self.req_reg = Registry()
        self.in_req_filter_reg = Registry()  # for request received
        self.out_reply_filter_reg = Registry()  # for reply going out
        self.out_req_filter_reg = Registry()  # for request sent
        self.in_reply_filter_reg = Registry()  # for reply received
        self.error_handler_reg = Registry()
        self.cell_connected_cb = None
        self.cell_connected_cb_args = None
        self.cell_connected_cb_kwargs = None
        self.cell_disconnected_cb = None
        self.cell_disconnected_cb_args = None
        self.cell_disconnected_cb_kwargs = None
        self.message_interceptor = None
        self.message_interceptor_args = None
        self.message_interceptor_kwargs = None
        self.waiters = {}  # req_id => req
        self.stats_lock = threading.Lock()
        self.req_hw = 0  # high-water mark of outstanding requests
        self.num_sar_reqs = 0  # send-and-receive
        self.num_faf_reqs = 0  # fire-and-forget
        self.num_timeout_reqs = 0
        # req_expiry specifies how long we keep requests in "reqs" table if they are
        # not answered or picked up
        if not max_timeout or max_timeout <= 0:
            max_timeout = 3600  # one hour
        self.max_timeout = max_timeout
        self.asked_to_stop = False
        self.running = False
        self.stopping = False
        # add appropriate drivers based on roles of the cell
        # a cell can have at most two listeners: one for external, one for internal
        self.ext_listeners = {}  # external listeners: url => connector object
        self.ext_listener_lock = threading.Lock()
        self.ext_listener_impossible = False
        self.int_listener = None  # backbone internal listener - only for cells with child cells
        # a cell could have any number of connectors: some for backbone, some for ad-hoc
        self.bb_ext_connector = None  # backbone external connector - only for Client cells
        self.bb_int_connector = None  # backbone internal connector - only for non-root cells
        # ad-hoc connectors: currently only support ad-hoc external connectors
        self.adhoc_connectors = {}  # target cell fqcn => connector
        self.adhoc_connector_lock = threading.Lock()
        self.root_change_lock = threading.Lock()
        # built-in handlers for bulk messages and peer goodbye notifications
        self.register_request_cb(channel=_CHANNEL, topic=_TOPIC_BULK, cb=self._receive_bulk_message)
        self.register_request_cb(channel=_CHANNEL, topic=_TOPIC_BYE, cb=self._peer_goodbye)
        self.cleanup_waiter = None
        # stats pools for timing, size and result counters
        self.msg_stats_pool = StatsPoolManager.add_time_hist_pool(
            "Request_Response", "Request/response time in secs (sender)", scope=self.my_info.fqcn
        )
        self.req_cb_stats_pool = StatsPoolManager.add_time_hist_pool(
            "Request_Processing",
            "Time spent (secs) by request processing callbacks (receiver)",
            scope=self.my_info.fqcn,
        )
        self.msg_travel_stats_pool = StatsPoolManager.add_time_hist_pool(
            "Msg_Travel", "Time taken (secs) to get here (receiver)", scope=self.my_info.fqcn
        )
        self.sent_msg_size_pool = StatsPoolManager.add_msg_size_pool(
            "Sent_Msg_Sizes", "Sizes of messages sent (MBs)", scope=self.my_info.fqcn
        )
        self.received_msg_size_pool = StatsPoolManager.add_msg_size_pool(
            "Received_Msg_Sizes", "Sizes of messages received (MBs)", scope=self.my_info.fqcn
        )
        counter_names = [_CounterName.SENT]
        self.sent_msg_counter_pool = StatsPoolManager.add_counter_pool(
            name="Sent_Msg_Counters",
            description="Result counters of sent messages",
            counter_names=counter_names,
            scope=self.my_info.fqcn,
        )
        counter_names = [_CounterName.RECEIVED]
        self.received_msg_counter_pool = StatsPoolManager.add_counter_pool(
            name="Received_Msg_Counters",
            description="Result counters of received messages",
            counter_names=counter_names,
            scope=self.my_info.fqcn,
        )
        self.ALL_CELLS[fqcn] = self
        self.credential_manager = CredentialManager(self.endpoint)
        self.cert_ex = CertificateExchanger(self, self.credential_manager)
[docs] def log_error(self, log_text: str, msg: Union[None, Message], log_except=False):
log_messaging_error(
logger=self.logger, log_text=log_text, cell=self, msg=msg, log_except=log_except, log_level=logging.ERROR
)
[docs] def log_warning(self, log_text: str, msg: Union[None, Message], log_except=False):
log_messaging_error(
logger=self.logger, log_text=log_text, cell=self, msg=msg, log_except=log_except, log_level=logging.WARNING
)
[docs] def get_root_url_for_child(self):
if isinstance(self.root_url, list):
return self.root_url[0]
else:
return self.root_url
    def get_fqcn(self) -> str:
        """Return this cell's fully qualified cell name."""
        return self.my_info.fqcn
[docs] def is_cell_reachable(self, target_fqcn: str, for_msg=None) -> bool:
if target_fqcn in self.ALL_CELLS:
return True
_, ep = self._find_endpoint(target_fqcn, for_msg)
return ep is not None
[docs] def is_cell_connected(self, target_fqcn: str) -> bool:
if target_fqcn in self.ALL_CELLS:
return True
agent = self.agents.get(target_fqcn)
return agent is not None
[docs] def is_backbone_ready(self):
"""Check if backbone is ready.
Backbone is the preconfigured network connections, like all the connections from clients to server.
Adhoc connections are not part of the backbone.
"""
if not self.running:
return False
if self.my_info.is_root:
if self.my_info.is_on_server:
# server root - make sure listener is created
return len(self.ext_listeners) > 0
else:
# client root - must be connected to server root
if FQCN.ROOT_SERVER in self.ALL_CELLS:
return True
else:
return self.agents.get(FQCN.ROOT_SERVER) is not None
else:
# child cell - must be connected to parent
parent_fqcn = FQCN.get_parent(self.my_info.fqcn)
if parent_fqcn in self.ALL_CELLS:
return True
else:
return self.agents.get(parent_fqcn) is not None
    def _set_bb_for_client_root(self):
        # Backbone for a client root: connect externally to the server root and,
        # if configured, listen internally for child cells.
        self._create_bb_external_connector()
        if self.create_internal_listener:
            self._create_internal_listener()
    def _set_bb_for_client_child(self, parent_url: str, create_internal_listener: bool):
        # Backbone for a client child cell: connect to parent; optionally listen for
        # its own children; connect to the server directly if the config says so.
        if parent_url:
            self._create_internal_connector(parent_url)
        if create_internal_listener:
            self._create_internal_listener()
        if self.connector_manager.should_connect_to_server(self.my_info):
            self._create_bb_external_connector()
    def _set_bb_for_server_root(self):
        # Backbone for the server root: listen externally on every configured root URL
        # and, if configured, listen internally for child cells.
        if isinstance(self.root_url, list):
            for url in self.root_url:
                self.logger.info(f"{self.my_info.fqcn}: creating listener on {url}")
                self._create_external_listener(url)
        else:
            self.logger.info(f"{self.my_info.fqcn}: creating listener on {self.root_url}")
            if self.root_url:
                self._create_external_listener(self.root_url)
        if self.create_internal_listener:
            self._create_internal_listener()
    def _set_bb_for_server_child(self, parent_url: str, create_internal_listener: bool):
        # Backbone for a server-side child cell. If the server root is registered in this
        # process (ALL_CELLS), no connector is needed at all.
        if FQCN.ROOT_SERVER in self.ALL_CELLS:
            return
        if parent_url:
            self._create_internal_connector(parent_url)
        if create_internal_listener:
            self._create_internal_listener()
    def change_server_root(self, to_url: str):
        """Change to a different server url

        Args:
            to_url: the new url of the server root

        Returns:

        """
        self.logger.debug(f"{self.my_info.fqcn}: changing server root to {to_url}")
        with self.root_change_lock:
            if self.my_info.is_on_server:
                # only affect clients
                self.logger.debug(f"{self.my_info.fqcn}: no change - on server side")
                return
            if to_url == self.root_url:
                # already changed
                self.logger.debug(f"{self.my_info.fqcn}: no change - same url")
                return
            self.root_url = to_url
            # drop existing server-side connectors/agents before reconnecting
            self.drop_connectors()
            self.drop_agents()
            # recreate backbone connector to the root
            if self.my_info.gen <= 2:
                self.logger.debug(f"{self.my_info.fqcn}: recreating bb_external_connector ...")
                self._create_bb_external_connector()
[docs] def drop_connectors(self):
# drop connections to all cells on server and their agents
# drop the backbone connector
if self.bb_ext_connector:
self.logger.debug(f"{self.my_info.fqcn}: removing bb_ext_connector ...")
try:
self.communicator.remove_connector(self.bb_ext_connector.handle)
self.communicator.remove_endpoint(FQCN.ROOT_SERVER)
except Exception as ex:
self.log_error(
msg=None,
log_text=f"{self.my_info.fqcn}: error removing bb_ext_connector {secure_format_exception(ex)}",
)
self.bb_ext_connector = None
# drop ad-hoc connectors to cells on server
with self.adhoc_connector_lock:
cells_to_delete = []
for to_cell in self.adhoc_connectors.keys():
to_cell_info = FqcnInfo(to_cell)
if to_cell_info.is_on_server:
cells_to_delete.append(to_cell)
for cell_name in cells_to_delete:
self.logger.debug(f"{self.my_info.fqcn}: removing adhoc connector to {cell_name}")
connector = self.adhoc_connectors.pop(cell_name, None)
if connector:
try:
self.communicator.remove_connector(connector.handle)
self.communicator.remove_endpoint(cell_name)
except:
self.log_error(
msg=None, log_text=f"error removing adhoc connector to {cell_name}", log_except=True
)
[docs] def drop_agents(self):
# drop agents
with self.agent_lock:
agents_to_delete = []
for fqcn, agent in self.agents.items():
assert isinstance(agent, CellAgent)
if agent.info.is_on_server:
agents_to_delete.append(fqcn)
for a in agents_to_delete:
self.logger.debug(f"{self.my_info.fqcn}: removing agent {a}")
self.agents.pop(a, None)
    def make_internal_listener(self):
        """
        Create the internal listener for child cells of this cell to connect to.

        Returns:

        """
        self._create_internal_listener()
[docs] def get_internal_listener_url(self) -> Union[None, str]:
"""Get the cell's internal listener url.
This method should only be used for cells that need to have child cells.
The url returned is to be passed to child of this cell to create connection
Returns: url for child cells to connect
"""
if not self.int_listener:
return None
return self.int_listener.get_connection_url()
    def _add_adhoc_connector(self, to_cell: str, url: str):
        # Create (or reuse) an ad-hoc external connector to the given cell at the given url.
        # Returns the connector, or None if one was not (or could not be) created.
        if self.bb_ext_connector:
            # it is possible that the server root offers connect url after the bb_ext_connector is created
            # but the actual connection has not been established.
            # Do not create another adhoc connection to the server!
            if isinstance(self.root_url, str) and url == self.root_url:
                return None
            if isinstance(self.root_url, list) and url in self.root_url:
                return None
        with self.adhoc_connector_lock:
            if to_cell in self.adhoc_connectors:
                return self.adhoc_connectors[to_cell]
            connector = self.connector_manager.get_external_connector(url, adhoc=True)
            # NOTE(review): a failed (None) result is cached too, preventing retries to
            # this cell - confirm this is intended
            self.adhoc_connectors[to_cell] = connector
            if connector:
                self.logger.info(
                    f"{self.my_info.fqcn}: created adhoc connector {connector.handle} to {url} on {to_cell}"
                )
            else:
                self.logger.info(f"{self.my_info.fqcn}: cannot create adhoc connector to {url} on {to_cell}")
            return connector
    def _create_internal_listener(self):
        # Create the backbone internal listener (idempotent); raises if it cannot be created.
        # internal listener is always backbone
        if not self.int_listener:
            self.int_listener = self.connector_manager.get_internal_listener()
            if self.int_listener:
                self.logger.info(
                    f"{self.my_info.fqcn}: created backbone internal listener "
                    f"for {self.int_listener.get_connection_url()}"
                )
            else:
                raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone internal listener")
        return self.int_listener
    def _create_external_listener(self, url: str):
        # Create (or reuse) an external listener. An empty url means an ad-hoc listener,
        # which may be disallowed by the connector manager's config.
        adhoc = len(url) == 0
        if adhoc and not self.connector_manager.adhoc_allowed:
            return None
        with self.ext_listener_lock:
            if url:
                # reuse an existing listener on the same url if there is one
                listener = self.ext_listeners.get(url)
                if listener:
                    return listener
            elif len(self.ext_listeners) > 0:
                # no url specified - just pick one if any
                k = random.choice(list(self.ext_listeners))
                return self.ext_listeners[k]
            listener = None
            if not self.ext_listener_impossible:
                self.logger.debug(f"{self.my_info.fqcn}: trying create ext listener: url={url}")
                listener = self.connector_manager.get_external_listener(url, adhoc)
                if listener:
                    if not adhoc:
                        self.logger.info(f"{self.my_info.fqcn}: created backbone external listener for {url}")
                    else:
                        self.logger.info(
                            f"{self.my_info.fqcn}: created adhoc external listener {listener.handle} "
                            f"for {listener.get_connection_url()}"
                        )
                    self.ext_listeners[listener.get_connection_url()] = listener
                else:
                    if not adhoc:
                        # backbone listener is mandatory - fail hard
                        raise RuntimeError(
                            f"{os.getpid()}: {self.my_info.fqcn}: "
                            f"cannot create backbone external listener for {url}"
                        )
                    else:
                        self.logger.warning(f"{self.my_info.fqcn}: cannot create adhoc external listener")
                        # remember the failure so we don't keep retrying ad-hoc listeners
                        self.ext_listener_impossible = True
            return listener
    def _create_bb_external_connector(self):
        # Create the backbone external connector to the server root; raises on failure.
        if not self.root_url:
            return
        # is the root server a local cell (in this process)?
        if self.ALL_CELLS.get(FQCN.ROOT_SERVER):
            # no need to connect
            return
        self.logger.debug(f"{self.my_info.fqcn}: creating connector to {self.root_url}")
        self.bb_ext_connector = self.connector_manager.get_external_connector(self.root_url, False)
        if self.bb_ext_connector:
            self.logger.info(f"{self.my_info.fqcn}: created backbone external connector to {self.root_url}")
        else:
            raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone external connector to {self.root_url}")
    def _create_internal_connector(self, url: str):
        # Create the backbone internal connector to the parent cell at url; raises on failure.
        self.bb_int_connector = self.connector_manager.get_internal_connector(url)
        if self.bb_int_connector:
            self.logger.info(f"{self.my_info.fqcn}: created backbone internal connector to {url} on parent")
        else:
            raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone internal connector to {url} on parent")
[docs] def set_cell_connected_cb(self, cb, *args, **kwargs):
"""
Set a callback that is called when an external cell is connected.
Args:
cb: the callback function. It must follow the signature of cell_connected_cb_signature.
*args: args to be passed to the cb.
**kwargs: kwargs to be passed to the cb
Returns: None
"""
if not callable(cb):
raise ValueError(f"specified cell_connected_cb {type(cb)} is not callable")
self.cell_connected_cb = cb
self.cell_connected_cb_args = args
self.cell_connected_cb_kwargs = kwargs
[docs] def set_cell_disconnected_cb(self, cb, *args, **kwargs):
"""
Set a callback that is called when an external cell is disconnected.
Args:
cb: the callback function. It must follow the signature of cell_disconnected_cb_signature.
*args: args to be passed to the cb.
**kwargs: kwargs to be passed to the cb
Returns: None
"""
if not callable(cb):
raise ValueError(f"specified cell_disconnected_cb {type(cb)} is not callable")
self.cell_disconnected_cb = cb
self.cell_disconnected_cb_args = args
self.cell_disconnected_cb_kwargs = kwargs
[docs] def set_message_interceptor(self, cb, *args, **kwargs):
"""
Set a callback that is called when a message is received or forwarded.
Args:
cb: the callback function. It must follow the signature of message_interceptor_signature.
*args: args to be passed to the cb.
**kwargs: kwargs to be passed to the cb
Returns: None
"""
if not callable(cb):
raise ValueError(f"specified message_interceptor {type(cb)} is not callable")
self.message_interceptor = cb
self.message_interceptor_args = args
self.message_interceptor_kwargs = kwargs
    def start(self):
        """
        Start the cell after it is fully set up (connectors and listeners are added, CBs are set up)

        Returns:

        """
        # set up the backbone according to the cell's role (server/client, root/child)
        if self.my_info.is_on_server:
            if self.my_info.is_root:
                self._set_bb_for_server_root()
            else:
                self._set_bb_for_server_child(self.parent_url, self.create_internal_listener)
        else:
            # client side
            if self.my_info.is_root:
                self._set_bb_for_client_root()
            else:
                self._set_bb_for_client_child(self.parent_url, self.create_internal_listener)
        self.communicator.start()
        self.running = True
    def stop(self):
        """
        Cleanup the cell. Once the cell is stopped, it won't be able to send/receive messages.

        Returns:

        """
        if not self.running:
            return
        if self.stopping:
            # already being stopped by another caller
            return
        self.stopping = True
        self.logger.debug(f"{self.my_info.fqcn}: Stopping Cell")
        # notify peers that I am gone
        with self.agent_lock:
            if self.agents:
                targets = [peer_name for peer_name in self.agents.keys()]
                self.logger.debug(f"broadcasting goodbye to {targets}")
                # best-effort (optional=True) with a short timeout - peers may already be gone
                self.broadcast_request(
                    channel=_CHANNEL,
                    topic=_TOPIC_BYE,
                    targets=targets,
                    request=Message(),
                    timeout=0.5,
                    optional=True,
                )
        self.running = False
        self.asked_to_stop = True
        # wait for the bulk worker threads to observe asked_to_stop and exit
        if self.bulk_checker is not None and self.bulk_checker.is_alive():
            self.bulk_checker.join()
        if self.bulk_processor is not None and self.bulk_processor.is_alive():
            self.bulk_processor.join()
        try:
            # we can now stop the communicator
            self.communicator.stop()
        except Exception as ex:
            self.log_error(
                msg=None, log_text=f"error stopping Communicator: {secure_format_exception(ex)}", log_except=True
            )
        self.logger.debug(f"{self.my_info.fqcn}: cell stopped!")
    def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs):
        """
        Register a callback for handling request. The CB must follow request_cb_signature.

        Args:
            channel: the channel of the request
            topic: topic of the request
            cb: the callback function
            *args: args to be passed to the cb
            **kwargs: kwargs to be passed to the cb

        Returns:

        """
        if not callable(cb):
            raise ValueError(f"specified request_cb {type(cb)} is not callable")
        self.req_reg.set(channel, topic, Callback(cb, args, kwargs))
    def encrypt_payload(self, message: Message):
        """Encrypt the message payload in place, if the message is marked SECURE.

        No-op for non-secure or already-encrypted messages. The clear payload length is
        recorded in a header so decrypt_payload can verify the round trip.
        """
        if not message.get_header(MessageHeaderKey.SECURE, False):
            return
        encrypted = message.get_header(MessageHeaderKey.ENCRYPTED, False)
        if encrypted:
            # Prevent double encryption
            return
        target = message.get_header(MessageHeaderKey.DESTINATION)
        if not target:
            raise RuntimeError("Message destination missing")
        if message.payload is None:
            # normalize None payload to empty bytes before encryption
            message.payload = bytes(0)
        payload_len = len(message.payload)
        message.add_headers(
            {
                MessageHeaderKey.CLEAR_PAYLOAD_LEN: payload_len,
                MessageHeaderKey.ENCRYPTED: True,
            }
        )
        # encrypt using the destination cell's certificate
        target_cert = self.cert_ex.get_certificate(target)
        message.payload = self.credential_manager.encrypt(target_cert, message.payload)
        self.logger.debug(f"Payload ({payload_len} bytes) is encrypted ({len(message.payload)} bytes)")
    def decrypt_payload(self, message: Message):
        """Decrypt the message payload in place, if the message is marked SECURE and ENCRYPTED.

        Verifies that the decrypted size matches the CLEAR_PAYLOAD_LEN header recorded
        by encrypt_payload; raises RuntimeError on mismatch or missing origin.
        """
        if not message.get_header(MessageHeaderKey.SECURE, False):
            return
        encrypted = message.get_header(MessageHeaderKey.ENCRYPTED, False)
        if not encrypted:
            # Message is already decrypted
            return
        message.remove_header(MessageHeaderKey.ENCRYPTED)
        origin = message.get_header(MessageHeaderKey.ORIGIN)
        if not origin:
            raise RuntimeError("Message origin missing")
        payload_len = message.get_header(MessageHeaderKey.CLEAR_PAYLOAD_LEN)
        # decrypt using the origin cell's certificate
        origin_cert = self.cert_ex.get_certificate(origin)
        message.payload = self.credential_manager.decrypt(origin_cert, message.payload)
        if len(message.payload) != payload_len:
            raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}")
    def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
        """Add a filter callback applied to requests received on the given channel/topic."""
        if not callable(cb):
            raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable")
        self.in_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs))
    def add_outgoing_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs):
        """Add a filter callback applied to replies sent out on the given channel/topic."""
        if not callable(cb):
            raise ValueError(f"specified outgoing_reply_filter {type(cb)} is not callable")
        self.out_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs))
    def add_outgoing_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
        """Add a filter callback applied to requests sent out on the given channel/topic."""
        if not callable(cb):
            raise ValueError(f"specified outgoing_request_filter {type(cb)} is not callable")
        self.out_req_filter_reg.append(channel, topic, Callback(cb, args, kwargs))
    def add_incoming_reply_filter(self, channel: str, topic: str, cb, *args, **kwargs):
        """Add a filter callback applied to replies received on the given channel/topic."""
        if not callable(cb):
            raise ValueError(f"specified incoming_reply_filter {type(cb)} is not callable")
        self.in_reply_filter_reg.append(channel, topic, Callback(cb, args, kwargs))
    def add_error_handler(self, channel: str, topic: str, cb, *args, **kwargs):
        """Set the error handler for the given channel/topic (set, not append)."""
        if not callable(cb):
            raise ValueError(f"specified error_handler {type(cb)} is not callable")
        self.error_handler_reg.set(channel, topic, Callback(cb, args, kwargs))
def _filter_outgoing_request(self, channel: str, topic: str, request: Message) -> Union[None, Message]:
cbs = self.out_req_filter_reg.find(channel, topic)
if not cbs:
return None
for _cb in cbs:
assert isinstance(_cb, Callback)
reply = self._try_cb(request, _cb.cb, *_cb.args, **_cb.kwargs)
if reply:
return reply
def _try_path(self, fqcn_path: List[str]) -> Union[None, Endpoint]:
self.logger.debug(f"{self.my_info.fqcn}: trying path {fqcn_path} ...")
target = FQCN.join(fqcn_path)
agent = self.agents.get(target, None)
if agent:
# there is a direct path to the target call
self.logger.debug(f"{self.my_info.fqcn}: got cell agent for {target}")
return agent.endpoint
else:
self.logger.debug(f"{self.my_info.fqcn}: no CellAgent for {target}")
if len(fqcn_path) == 1:
return None
return self._try_path(fqcn_path[:-1])
def _find_endpoint(self, target_fqcn: str, for_msg: Message) -> Tuple[str, Union[None, Endpoint]]:
err = FQCN.validate(target_fqcn)
if err:
self.log_error(msg=None, log_text=f"invalid target FQCN '{target_fqcn}': {err}")
return ReturnCode.INVALID_TARGET, None
try:
ep = self._try_find_ep(target_fqcn, for_msg)
if not ep:
return ReturnCode.TARGET_UNREACHABLE, None
return "", ep
except:
self.log_error(msg=for_msg, log_text=f"Error when finding {target_fqcn}", log_except=True)
return ReturnCode.TARGET_UNREACHABLE, None
def _try_find_ep(self, target_fqcn: str, for_msg: Message) -> Union[None, Endpoint]:
    """Locate the next-leg endpoint toward target_fqcn, or None if unreachable.

    Resolution order: self, same-process cell (ALL_CELLS), direct agent,
    same-family routing (child path or parent), cross-family path, server root,
    then my parent as a last resort.
    """
    self.logger.debug(f"{self.my_info.fqcn}: finding path to {target_fqcn}")
    if target_fqcn == self.my_info.fqcn:
        # the target is me
        return self.endpoint
    target_info = FqcnInfo(target_fqcn)
    # is there a direct path to the target?
    if target_fqcn in self.ALL_CELLS:
        # same-process cell: a bare Endpoint is enough for direct delivery
        return Endpoint(target_fqcn)
    agent = self.agents.get(target_fqcn)
    if agent:
        return agent.endpoint
    if same_family(self.my_info, target_info):
        if FQCN.is_parent(self.my_info.fqcn, target_fqcn):
            # target is my direct child but has no agent - nothing we can do
            self.log_warning(msg=for_msg, log_text=f"no connection to child {target_fqcn}")
            return None
        elif FQCN.is_parent(target_fqcn, self.my_info.fqcn):
            # NOTE(review): no "return None" here, unlike the child branch above;
            # execution falls through to the ancestor checks below, which end up
            # failing on the same missing parent agent. Confirm this is intended.
            self.log_warning(f"no connection to parent {target_fqcn}", for_msg)
        self.logger.debug(f"{self.my_info.fqcn}: find path in the same family")
        if FQCN.is_ancestor(self.my_info.fqcn, target_fqcn):
            # I am the ancestor of the target
            self.logger.debug(f"{self.my_info.fqcn}: I'm ancestor of the target {target_fqcn}")
            return self._try_path(target_info.path)
        else:
            # target is my ancestor, or we share the same ancestor - go to my parent!
            self.logger.debug(f"{self.my_info.fqcn}: target {target_fqcn} is or share my ancestor")
            parent_fqcn = FQCN.get_parent(self.my_info.fqcn)
            agent = self.agents.get(parent_fqcn)
            if not agent:
                self.log_warning(f"no connection to parent {parent_fqcn}", for_msg)
                return None
            return agent.endpoint
    # not the same family
    ep = self._try_path(target_info.path)
    if ep:
        return ep
    # cannot find path to the target
    # try the server root
    # we assume that all client roots connect to the server root.
    # Do so only if I'm not the server root
    if not self.my_info.is_root or not self.my_info.is_on_server:
        if FQCN.ROOT_SERVER in self.ALL_CELLS:
            return Endpoint(FQCN.ROOT_SERVER)
        root_agent = self.agents.get(FQCN.ROOT_SERVER)
        if root_agent:
            return root_agent.endpoint
    # no direct path to the server root
    # let my parent handle it if I have a parent
    if self.my_info.gen > 1:
        parent_fqcn = FQCN.get_parent(self.my_info.fqcn)
        agent = self.agents.get(parent_fqcn)
        if not agent:
            self.log_warning(f"no connection to parent {parent_fqcn}", for_msg)
            return None
        return agent.endpoint
    self.log_warning(f"no connection to {target_fqcn}", for_msg)
    return None
def _send_to_endpoint(self, to_endpoint: Endpoint, message: Message) -> str:
    """Encode, encrypt, size-check and send a message to one endpoint.

    Args:
        to_endpoint: the next-leg endpoint
        message: the message to send (mutated: payload encoded/encrypted, SEND_TIME set)

    Returns: "" on success, otherwise a ReturnCode value (MSG_TOO_BIG / COMM_ERROR).
    """
    err = ""
    try:
        encode_payload(message)
        self.encrypt_payload(message)
        message.set_header(MessageHeaderKey.SEND_TIME, time.time())
        if not message.payload:
            msg_size = 0
        else:
            msg_size = len(message.payload)
        if msg_size > self.max_msg_size:
            # fixed: error text previously had an unbalanced "(" — closing paren added
            err_text = f"message is too big ({msg_size} > {self.max_msg_size})"
            self.log_error(err_text, message)
            err = ReturnCode.MSG_TOO_BIG
        else:
            direct_cell = self.ALL_CELLS.get(to_endpoint.name)
            msg_size_mbs = self._msg_size_mbs(message)
            if direct_cell:
                # same-process target: deliver via direct call instead of the communicator
                self._send_direct_message(direct_cell, message)
            else:
                self.communicator.send(to_endpoint, CoreCell.APP_ID, message)
            self.sent_msg_size_pool.record_value(category=self._stats_category(message), value=msg_size_mbs)
    except Exception as ex:
        err_text = f"Failed to send message to {to_endpoint.name}: {secure_format_exception(ex)}"
        self.log_error(err_text, message)
        self.logger.debug(secure_format_traceback())
        err = ReturnCode.COMM_ERROR
    return err
def _send_direct_message(self, target_cell, message):
    """Deliver a message to a same-process cell by calling its receiver directly."""
    sender_ep = Endpoint(self.my_info.fqcn)
    target_cell.process_message(endpoint=sender_ep, connection=None, app_id=self.APP_ID, message=message)
def _send_target_messages(
    self,
    target_msgs: Dict[str, TargetMessage],
) -> Dict[str, str]:
    """Send one message per target, resolving routes and applying outgoing filters.

    Args:
        target_msgs: target FQCN => TargetMessage to send to it

    Returns: target FQCN => error string ("" means sent OK).
    """
    if not self.running:
        raise RuntimeError("Messenger is not running")
    send_errs = {}
    reachable_targets = {}  # target fqcn => endpoint
    # first pass: resolve a next-leg endpoint for every target
    for t, tm in target_msgs.items():
        err, ep = self._find_endpoint(t, tm.message)
        if ep:
            reachable_targets[t] = ep
        else:
            # unreachable: build a headered copy only for error logging
            msg = Message(headers=copy.copy(tm.message.headers), payload=tm.message.payload)
            msg.add_headers(
                {
                    MessageHeaderKey.CHANNEL: tm.channel,
                    MessageHeaderKey.TOPIC: tm.topic,
                    MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
                    MessageHeaderKey.TO_CELL: t,
                    MessageHeaderKey.ORIGIN: self.my_info.fqcn,
                    MessageHeaderKey.DESTINATION: t,
                }
            )
            self.log_error(f"cannot send to '{t}': {err}", msg)
            send_errs[t] = err
    # second pass: send to each reachable target
    for t, ep in reachable_targets.items():
        tm = target_msgs[t]
        # shallow-copy headers so per-target headers don't leak across targets
        req = Message(headers=copy.copy(tm.message.headers), payload=tm.message.payload)
        req.add_headers(
            {
                MessageHeaderKey.CHANNEL: tm.channel,
                MessageHeaderKey.TOPIC: tm.topic,
                MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
                MessageHeaderKey.TO_CELL: ep.name,
                MessageHeaderKey.ORIGIN: self.my_info.fqcn,
                MessageHeaderKey.DESTINATION: t,
                MessageHeaderKey.MSG_TYPE: MessageType.REQ,
                MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())],
            }
        )
        # invoke outgoing req filters
        req_filters = self.out_req_filter_reg.find(tm.channel, tm.topic)
        if req_filters:
            self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing request filters")
            assert isinstance(req_filters, list)
            for f in req_filters:
                assert isinstance(f, Callback)
                r = self._try_cb(req, f.cb, *f.args, **f.kwargs)
                if r:
                    # any non-empty filter result aborts the send to this target
                    send_errs[t] = ReturnCode.FILTER_ERROR
                    break
        if send_errs.get(t):
            # process next target
            continue
        # is this a direct path?
        ti = FqcnInfo(t)
        allow_adhoc = self.connector_manager.is_adhoc_allowed(ti, self.my_info)
        if allow_adhoc and t != ep.name:
            # Not a direct path since the destination and the next leg are not the same
            if not ti.is_on_server and (self.my_info.is_on_server or self.my_info.fqcn > t):
                # try to get or create a listener and let the peer know the endpoint
                listener = self._create_external_listener("")
                if listener:
                    conn_url = listener.get_connection_url()
                    req.set_header(MessageHeaderKey.CONN_URL, conn_url)
        err = self._send_to_endpoint(ep, req)
        if err:
            self.log_error(f"failed to send to endpoint {ep.name}: {err}", req)
        else:
            self.sent_msg_counter_pool.increment(category=self._stats_category(req), counter_name=_CounterName.SENT)
        send_errs[t] = err
    return send_errs
def _send_to_targets(
    self,
    channel: str,
    topic: str,
    targets: Union[str, List[str]],
    message: Message,
) -> Dict[str, str]:
    """Wrap a single message into per-target TargetMessages and send them all.

    Returns: target FQCN => error string ("" means sent OK).
    """
    if isinstance(targets, str):
        targets = [targets]
    target_msgs = {t: TargetMessage(t, channel, topic, message) for t in targets}
    return self._send_target_messages(target_msgs)
def send_request(
    self, channel: str, topic: str, target: str, request: Message, timeout=None, secure=False, optional=False
) -> Message:
    """Send a request to one target cell and wait for its reply (single-target broadcast)."""
    self.logger.debug(f"{self.my_info.fqcn}: sending request {channel}:{topic} to {target}")
    replies = self.broadcast_request(channel, topic, [target], request, timeout, secure, optional)
    assert isinstance(replies, dict)
    return replies.get(target)
def broadcast_multi_requests(
    self, target_msgs: Dict[str, TargetMessage], timeout=None, secure=False, optional=False
) -> Dict[str, Message]:
    """
    This is the core of the request/response handling. Be extremely careful when making any changes!
    To maximize the communication efficiency, we avoid the use of locks.
    We use a waiter implemented as a Python threading.Event object.
    We create the waiter, send out messages, set up default responses, and set it up to wait for response.
    Once the waiter is triggered from a reply-receiving thread, we process received results.

    HOWEVER, if the network is extremely fast, the response may already be received even before we finish setting
    up the waiter in this thread!
    We had a very mysterious bug that caused a request to be treated as timeout even though the reply is received.
    It was both threads try to set values to "waiter.replies". In case of extremely fast network, the reply
    processing thread set the reply to "waiter.replies", and then overwritten by this thread with a default timeout
    reply.

    To avoid this kind of problems, we now use two sets of values in the waiter object.
    One set is for this thread: targets
    Another set is for the reply processing thread: received_replies, reply_time

    Args:
        target_msgs: messages to be sent
        timeout: timeout value
        secure: End-end encryption
        optional: whether the message is optional

    Returns: a dict of: target name => reply message
    """
    targets = [t for t in target_msgs]
    self.logger.debug(f"{self.my_info.fqcn}: broadcasting to {targets} ...")
    waiter = _Waiter(targets)
    if waiter.id in self.waiters:
        raise RuntimeError("waiter not unique!")
    self.waiters[waiter.id] = waiter
    now = time.time()
    if not timeout:
        timeout = self.max_timeout
    result = {}
    try:
        # stamp every outgoing request with the shared waiter id
        for _, tm in target_msgs.items():
            request = tm.message
            request.add_headers(
                {
                    MessageHeaderKey.REQ_ID: waiter.id,
                    MessageHeaderKey.REPLY_EXPECTED: True,
                    MessageHeaderKey.SECURE: secure,
                    MessageHeaderKey.OPTIONAL: optional,
                }
            )
        send_errs = self._send_target_messages(target_msgs)
        send_count = 0
        timeout_reply = make_reply(ReturnCode.TIMEOUT)
        # NOTE: it is possible that reply is already received and the waiter is triggered by now!
        # if waiter.received_replies:
        #     self.logger.info(f"{self.my_info.fqcn}: the network is extremely fast - response already received!")
        topics = []
        for_msg = None
        for t, err in send_errs.items():
            if not err:
                # sent OK: pre-fill a TIMEOUT default; real replies overwrite it later
                send_count += 1
                result[t] = timeout_reply
                tm = target_msgs[t]
                topic = tm.message.get_header(MessageHeaderKey.TOPIC, "?")
                if topic not in topics:
                    topics.append(topic)
                if not for_msg:
                    for_msg = tm.message
            else:
                # send failed: record the error as this target's final reply
                result[t] = make_reply(rc=err)
                waiter.reply_time[t] = now
        if send_count > 0:
            self.num_sar_reqs += 1
            # track the high-water mark of concurrent outstanding requests
            num_reqs = len(self.waiters)
            if self.req_hw < num_reqs:
                self.req_hw = num_reqs
            # wait for reply
            self.logger.debug(f"{self.my_info.fqcn}: set up waiter {waiter.id} to wait for {timeout} secs")
            if not waiter.wait(timeout=timeout):
                # timeout
                self.log_error(f"timeout on Request {waiter.id} for {topics} after {timeout} secs", for_msg)
                with self.stats_lock:
                    self.num_timeout_reqs += 1
    except Exception as ex:
        raise ex
    finally:
        # always unregister the waiter so late replies don't find a stale entry
        self.waiters.pop(waiter.id, None)
        self.logger.debug(f"released waiter on REQ {waiter.id}")
    if waiter.received_replies:
        # replies from the receiving thread override the pre-filled defaults
        result.update(waiter.received_replies)
    for t, reply in result.items():
        rc = reply.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
        self.sent_msg_counter_pool.increment(category=self._stats_category(reply), counter_name=rc)
    return result
def broadcast_request(
    self,
    channel: str,
    topic: str,
    targets: Union[str, List[str]],
    request: Message,
    timeout=None,
    secure=False,
    optional=False,
) -> Dict[str, Message]:
    """Send one request to multiple destination cells and wait for all replies.

    Args:
        channel: channel for the message
        topic: topic of the message
        targets: FQCN(s) of the destination cell(s)
        request: message to be sent
        timeout: how long to wait for replies
        secure: End-end encryption
        optional: whether the message is optional

    Returns: a dict of: cell_id => reply message
    """
    if isinstance(targets, str):
        targets = [targets]
    target_msgs = {t: TargetMessage(t, channel, topic, request) for t in targets}
    return self.broadcast_multi_requests(target_msgs, timeout, secure=secure, optional=optional)
def fire_and_forget(
    self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, secure=False, optional=False
) -> Dict[str, str]:
    """Send a message to destination cell(s) without waiting for replies.

    Args:
        channel: channel for the message
        topic: topic of the message
        targets: one or more destination cell IDs. None means all.
        message: message to be sent
        secure: End-end encryption of the message
        optional: whether the message is optional

    Returns: a dict of: target => error string ("" means sent OK)
    """
    extra_headers = {
        MessageHeaderKey.REPLY_EXPECTED: False,
        MessageHeaderKey.OPTIONAL: optional,
        MessageHeaderKey.SECURE: secure,
    }
    message.add_headers(extra_headers)
    return self._send_to_targets(channel, topic, targets, message)
def queue_message(self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, optional=False):
    """Queue a message for batched (bulk) delivery to one or more targets.

    Raises:
        RuntimeError: if bulk messaging is disabled (max_bulk_size <= 0).
    """
    if self.max_bulk_size <= 0:
        raise RuntimeError(f"{self.get_fqcn()}: bulk message is not enabled!")
    if isinstance(targets, str):
        targets = [targets]
    message.set_header(MessageHeaderKey.OPTIONAL, optional)
    with self.bulk_lock:
        # lazily start the background thread that periodically flushes bulk senders
        if self.bulk_checker is None:
            self.logger.info(f"{self.my_info.fqcn}: starting bulk_checker")
            self.bulk_checker = threading.Thread(target=self._check_bulk, name="check_bulk_msg")
            self.bulk_checker.start()
            self.logger.info(f"{self.my_info.fqcn}: started bulk_checker")
        # one _BulkSender per target, created on demand
        for t in targets:
            sender = self.bulk_senders.get(t)
            if not sender:
                sender = _BulkSender(cell=self, target=t, max_queue_size=self.max_bulk_size)
                self.bulk_senders[t] = sender
            sender.queue_message(channel=channel, topic=topic, message=message)
            self.logger.info(f"{self.get_fqcn()}: queued msg for {t}")
def _peer_goodbye(self, request: Message):
    """Handle a 'goodbye' notification from a peer: drop its agent and ack with an empty Message."""
    peer_ep = request.get_prop(MessagePropKey.ENDPOINT)
    if not peer_ep:
        self.log_error("no endpoint prop in message", request)
        return
    assert isinstance(peer_ep, Endpoint)
    with self.agent_lock:
        self.logger.debug(f"{self.my_info.fqcn}: got goodbye from cell {peer_ep.name}")
        removed = self.agents.pop(peer_ep.name, None)
        if removed:
            self.logger.debug(f"{self.my_info.fqcn}: removed agent for {peer_ep.name}")
        else:
            self.logger.debug(f"{self.my_info.fqcn}: agent for {peer_ep.name} is already gone")
    # ack back
    return Message()
def _receive_bulk_message(self, request: Message):
    """Accept an incoming bulk message and queue it for background processing."""
    payload_items = request.payload
    assert isinstance(payload_items, list)
    with self.bulk_msg_lock:
        # lazily start the processor thread on first bulk message
        if self.bulk_processor is None:
            self.logger.debug(f"{self.my_info.fqcn}: starting bulk message processor")
            self.bulk_processor = threading.Thread(target=self._process_bulk_messages, name="process_bulk_msg")
            self.bulk_processor.start()
            self.logger.debug(f"{self.my_info.fqcn}: started bulk message processor")
        self.bulk_messages.append(request)
        self.logger.debug(f"{self.get_fqcn()}: received bulk msg. Pending size {len(self.bulk_messages)}")
def _process_bulk_messages(self):
    """Background loop: drain pending bulk messages until asked to stop, then flush once more."""
    self.logger.debug(f"{self.get_fqcn()}: processing bulks ...")
    while not self.asked_to_stop:
        self._process_pending_bulks()
        time.sleep(self.bulk_process_interval)
    # final drain: handle anything queued after the loop exited
    self._process_pending_bulks()
def _process_pending_bulks(self):
    """Drain the pending bulk-message queue, one bulk at a time."""
    while True:
        with self.bulk_msg_lock:
            if not self.bulk_messages:
                return
            bulk = self.bulk_messages.pop(0)
        # NOTE(review): processing happens outside the lock so new bulks can be
        # queued concurrently — indentation reconstructed from flattened source; confirm.
        self._process_one_bulk(bulk)
def _process_one_bulk(self, bulk_request: Message):
    """Unpack a bulk request and dispatch each contained message as a regular request."""
    target_msgs = bulk_request.payload
    assert isinstance(target_msgs, list)
    self.logger.debug(f"{self.get_fqcn()}: processing one bulk size {len(target_msgs)}")
    # origin is the same for every item in the bulk - read it once
    origin = bulk_request.get_header(MessageHeaderKey.ORIGIN, "")
    for tmd in target_msgs:
        assert isinstance(tmd, dict)
        tm = TargetMessage.from_dict(tmd)
        assert isinstance(tm, TargetMessage)
        req = tm.message
        # bulk-level headers first; per-item topic/channel then override them
        req.add_headers(bulk_request.headers)
        req.add_headers({MessageHeaderKey.TOPIC: tm.topic, MessageHeaderKey.CHANNEL: tm.channel})
        self.logger.debug(f"{self.get_fqcn()}: bulk item: {req.headers}")
        self._process_request(origin=origin, message=req)
def fire_multi_requests_and_forget(self, target_msgs: Dict[str, TargetMessage], optional=False) -> Dict[str, str]:
    """Send per-target messages without waiting for replies.

    Returns: target FQCN => error string ("" means sent OK).
    """
    extra_headers = {MessageHeaderKey.REPLY_EXPECTED: False, MessageHeaderKey.OPTIONAL: optional}
    for tm in target_msgs.values():
        tm.message.add_headers(extra_headers)
    return self._send_target_messages(target_msgs)
def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], secure=False, optional=False) -> str:
    """Send one reply that answers one or more earlier requests.

    This is useful if the request receiver needs to delay its reply as follows:
        - When a request is received, if it's not ready to reply (e.g. waiting for additional requests from
          other cells), simply remember the REQ_ID and returns None;
        - The receiver may queue up multiple such requests
        - When ready, call this method to send the reply for all the queued requests

    Args:
        reply: the reply message
        to_cell: the target cell
        for_req_ids: the list of req IDs that the reply is for
        secure: End-end encryption
        optional: whether the message is optional

    Returns: an error message if any
    """
    route_entry = (self.my_info.fqcn, time.time())
    reply.add_headers(
        {
            MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
            MessageHeaderKey.TO_CELL: to_cell,
            MessageHeaderKey.ORIGIN: self.my_info.fqcn,
            MessageHeaderKey.DESTINATION: to_cell,
            MessageHeaderKey.REQ_ID: for_req_ids,
            MessageHeaderKey.MSG_TYPE: MessageType.REPLY,
            MessageHeaderKey.ROUTE: [route_entry],
            MessageHeaderKey.SECURE: secure,
            MessageHeaderKey.OPTIONAL: optional,
        }
    )
    err, ep = self._find_endpoint(to_cell, reply)
    if err:
        return err
    # the next leg may differ from the final destination
    reply.set_header(MessageHeaderKey.TO_CELL, ep.name)
    return self._send_to_endpoint(ep, reply)
def _try_cb(self, message, cb, *args, **kwargs):
try:
self.logger.debug(f"{self.my_info.fqcn}: calling CB {cb.__name__}")
return cb(message, *args, **kwargs)
except ServiceUnavailable:
return make_reply(ReturnCode.SERVICE_UNAVAILABLE)
except InvalidSession:
return make_reply(ReturnCode.INVALID_SESSION)
except InvalidRequest:
return make_reply(ReturnCode.INVALID_REQUEST)
except AuthenticationError:
return make_reply(ReturnCode.AUTHENTICATION_ERROR)
except AbortRun:
return make_reply(ReturnCode.ABORT_RUN)
except Exception as ex:
self.log_error(
f"exception from CB {cb.__name__}: {secure_format_exception(ex)}", msg=message, log_except=True
)
return make_reply(ReturnCode.PROCESS_EXCEPTION)
def process_message(self, endpoint: Endpoint, connection: Connection, app_id: int, message: Message):
    """Receiver entry point: delegate to _process_received_msg, logging any failure."""
    try:
        self._process_received_msg(endpoint, connection, message)
    except Exception as ex:
        err_text = f"Error processing received message: {secure_format_exception(ex)}"
        self.log_error(err_text, msg=message, log_except=True)
def _process_request(self, origin: str, message: Message) -> Union[None, Message]:
    """Dispatch an incoming request to its registered callback.

    Decrypts/decodes in place, runs incoming-request filters (any non-empty
    filter result short-circuits the dispatch), then invokes the registered CB.

    Returns: the reply Message, or None if the CB has nothing to reply.
    """
    self.logger.debug(f"{self.my_info.fqcn}: processing incoming request")
    self.decrypt_payload(message)
    decode_payload(message)
    # this is a request for me - dispatch to the right CB
    channel = message.get_header(MessageHeaderKey.CHANNEL, "")
    topic = message.get_header(MessageHeaderKey.TOPIC, "")
    _cb = self.req_reg.find(channel, topic)
    if not _cb:
        self.log_error(f"no callback for request ({topic}@{channel}) from cell '{origin}'", message)
        return make_reply(ReturnCode.PROCESS_EXCEPTION, error="no callback")
    # invoke incoming request filters
    req_filters = self.in_req_filter_reg.find(channel, topic)
    if req_filters:
        self.logger.debug(f"{self.my_info.fqcn}: invoking incoming request filters")
        assert isinstance(req_filters, list)
        for f in req_filters:
            assert isinstance(f, Callback)
            reply = self._try_cb(message, f.cb, *f.args, **f.kwargs)
            if reply:
                # filter rejected the request: its reply is returned without calling the CB
                return reply
    assert isinstance(_cb, Callback)
    self.logger.debug(f"{self.my_info.fqcn}: calling registered request CB")
    # time the CB for stats
    cb_start = time.perf_counter()
    reply = self._try_cb(message, _cb.cb, *_cb.args, **_cb.kwargs)
    cb_end = time.perf_counter()
    self.req_cb_stats_pool.record_value(category=self._stats_category(message), value=cb_end - cb_start)
    if not reply:
        # the CB doesn't have anything to reply
        return None
    if not isinstance(reply, Message):
        self.log_error(
            f"bad result from request CB for topic {topic} on channel {channel}: "
            f"expect Message but got {type(reply)}",
            msg=message,
        )
        reply = make_reply(ReturnCode.PROCESS_EXCEPTION, error="bad cb result")
    # Reply must be secure if request is
    reply.add_headers({MessageHeaderKey.SECURE: message.get_header(MessageHeaderKey.SECURE, False)})
    return reply
def _add_to_route(self, message: Message):
    """Append (my_fqcn, now) to the message's ROUTE header, creating the header if absent."""
    route = message.get_header(MessageHeaderKey.ROUTE, None)
    if not route:
        route = []
        message.set_header(MessageHeaderKey.ROUTE, route)
    if isinstance(route, list):
        # appending to the list mutates the header in place
        route.append((self.my_info.fqcn, time.time()))
    else:
        self.log_error(f"bad route header: expect list but got {type(route)}", msg=message)
def _forward(self, endpoint: Endpoint, origin: str, destination: str, msg_type: str, message: Message):
    """Forward a message not addressed to me toward its destination.

    On failure: for a REQ that expects a reply, send a RETURN message back to the
    sender (endpoint); otherwise drop the message.
    """
    # not for me - need to forward it
    self.logger.debug(f"{self.my_info.fqcn}: forwarding for {origin} to {destination}")
    err, ep = self._find_endpoint(destination, message)
    if ep:
        self.logger.debug(f"{self.my_info.fqcn}: found next leg {ep.name}")
        message.add_headers({MessageHeaderKey.FROM_CELL: self.my_info.fqcn, MessageHeaderKey.TO_CELL: ep.name})
        self._add_to_route(message)
        err = self._send_to_endpoint(to_endpoint=ep, message=message)
        if not err:
            self.logger.debug(f"{self.my_info.fqcn}: forwarded successfully!")
            return
        else:
            self.log_error(f"failed to forward {msg_type}: {err}", msg=message)
    else:
        # cannot find next leg endpoint
        self.log_error(f"cannot forward {msg_type}: no path", message)
    if msg_type == MessageType.REQ:
        reply_expected = message.get_header(MessageHeaderKey.REPLY_EXPECTED, False)
        if not reply_expected:
            self.logger.debug(f"{self.my_info.fqcn}: can't forward: drop the message since reply is not expected")
            return
        # tell the requester that message couldn't be delivered
        req_id = message.get_header(MessageHeaderKey.REQ_ID, "")
        reply = make_reply(ReturnCode.COMM_ERROR, error="cannot forward")
        reply.add_headers(
            {
                # original headers preserved so the requester can identify the failed request
                MessageHeaderKey.ORIGINAL_HEADERS: message.headers,
                MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
                MessageHeaderKey.TO_CELL: endpoint.name,
                MessageHeaderKey.ORIGIN: self.my_info.fqcn,
                MessageHeaderKey.DESTINATION: origin,
                MessageHeaderKey.REQ_ID: [req_id],
                MessageHeaderKey.MSG_TYPE: MessageType.RETURN,
                MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())],
                MessageHeaderKey.RETURN_REASON: ReturnReason.CANT_FORWARD,
            }
        )
        self._send_to_endpoint(endpoint, reply)
        self.logger.debug(f"{self.my_info.fqcn}: sent RETURN message back to {endpoint.name}")
    else:
        # msg_type is either RETURN or REPLY - drop it.
        self.logger.debug(format_log_message(self.my_info.fqcn, message, "dropped forwarded message"))
def _stats_category(self, message: Message):
    """Build the stats-pool category string 'type_tag:channel:topic' for a message."""
    hdr = message.get_header
    channel = hdr(MessageHeaderKey.CHANNEL, "?")
    topic = hdr(MessageHeaderKey.TOPIC, "?")
    msg_type = hdr(MessageHeaderKey.MSG_TYPE, "?")
    dest = hdr(MessageHeaderKey.DESTINATION, "")
    origin = hdr(MessageHeaderKey.ORIGIN, "")
    to_cell = hdr(MessageHeaderKey.TO_CELL, "")
    type_tag = msg_type
    if dest and origin and dest != self.my_info.fqcn and origin != self.my_info.fqcn:
        # neither endpoint is me: I'm forwarding on behalf of another cell
        type_tag = "fwd." + msg_type
    if msg_type == MessageType.RETURN:
        # a RETURN carries the failed request's headers; categorize by those
        orig_headers = hdr(MessageHeaderKey.ORIGINAL_HEADERS, None)
        if orig_headers:
            channel = orig_headers.get(MessageHeaderKey.CHANNEL, "??")
            topic = orig_headers.get(MessageHeaderKey.TOPIC, "??")
        else:
            channel = "???"
            topic = "???"
    return f"{type_tag}:{channel}:{topic}"
def _process_reply(self, origin: str, message: Message, msg_type: str):
    """Match an incoming REPLY/RETURN to its waiter(s) and trigger them when complete.

    Runs on the reply-receiving thread; writes only to waiter.received_replies and
    waiter.reply_time (see broadcast_multi_requests for the two-set protocol).
    """
    channel = message.get_header(MessageHeaderKey.CHANNEL, "")
    topic = message.get_header(MessageHeaderKey.TOPIC, "")
    now = time.time()
    self.logger.debug(f"{self.my_info.fqcn}: processing reply from {origin} for type {msg_type}")
    self.decrypt_payload(message)
    decode_payload(message)
    req_ids = message.get_header(MessageHeaderKey.REQ_ID)
    if not req_ids:
        raise RuntimeError(format_log_message(self.my_info.fqcn, message, "reply does not have REQ_ID header"))
    if isinstance(req_ids, str):
        req_ids = [req_ids]
    if not isinstance(req_ids, list):
        raise RuntimeError(
            format_log_message(self.my_info.fqcn, message, f"REQ_ID must be list of ids but got {type(req_ids)}")
        )
    req_destination = origin
    if msg_type == MessageType.RETURN:
        # a RETURN means the original request could not be delivered;
        # recover the intended destination from the preserved original headers
        self.logger.debug(format_log_message(self.my_info.fqcn, message, "message is returned"))
        self.sent_msg_counter_pool.increment(
            category=self._stats_category(message), counter_name=_CounterName.RETURN
        )
        original_headers = message.get_header(MessageHeaderKey.ORIGINAL_HEADERS, None)
        if not original_headers:
            raise RuntimeError(
                format_log_message(self.my_info.fqcn, message, "missing ORIGINAL_HEADERS in returned message!")
            )
        req_destination = original_headers.get(MessageHeaderKey.DESTINATION, None)
        if not req_destination:
            raise RuntimeError(
                format_log_message(self.my_info.fqcn, message, "missing DESTINATION header in original headers")
            )
    else:
        # invoking incoming reply filter
        reply_filters = self.in_reply_filter_reg.find(channel, topic)
        if reply_filters:
            self.logger.debug(f"{self.my_info.fqcn}: invoking incoming reply filters")
            assert isinstance(reply_filters, list)
            for f in reply_filters:
                assert isinstance(f, Callback)
                # filter results are ignored for replies (unlike request filters)
                self._try_cb(message, f.cb, *f.args, **f.kwargs)
    for rid in req_ids:
        waiter = self.waiters.get(rid, None)
        if waiter:
            assert isinstance(waiter, _Waiter)
            if req_destination not in waiter.targets:
                # NOTE(review): the two f-string parts concatenate without a
                # separator, producing "...from <dest>req_destination=..." — confirm
                # a space/comma was intended between them.
                self.log_error(
                    f"unexpected reply for {rid} from {req_destination}"
                    f"req_destination='{req_destination}', expecting={waiter.targets}",
                    message,
                )
                return
            waiter.received_replies[req_destination] = message
            waiter.reply_time[req_destination] = now
            time_taken = now - waiter.send_time
            self.msg_stats_pool.record_value(category=self._stats_category(message), value=time_taken)
            # all targets replied?
            all_targets_replied = True
            for t in waiter.targets:
                if not waiter.reply_time.get(t):
                    all_targets_replied = False
                    break
            if all_targets_replied:
                self.logger.debug(
                    format_log_message(
                        self.my_info.fqcn,
                        message,
                        f"trigger waiter - replies received from {len(waiter.targets)} targets for {rid}",
                    )
                )
                waiter.set()  # trigger the waiting requests!
            else:
                self.logger.debug(
                    format_log_message(
                        self.my_info.fqcn,
                        message,
                        f"waiting - replies not received from {len(waiter.targets)} targets for req {rid}",
                    )
                )
        else:
            # the requesting thread already timed out and removed the waiter
            self.log_warning(f"no waiter for req {rid} - the reply is too late", None)
            self.sent_msg_counter_pool.increment(
                category=self._stats_category(message), counter_name=_CounterName.LATE
            )
@staticmethod
def _msg_size_mbs(message: Message):
    """Return the message payload size in megabytes (0 when there is no payload)."""
    payload = message.payload
    size = len(payload) if payload else 0
    return size / _ONE_MB
def _process_received_msg(self, endpoint: Endpoint, connection: Connection, message: Message):
    """Main receive path: record stats, run interceptor, forward or dispatch, and reply.

    Raises:
        RuntimeError: if MSG_TYPE, ORIGIN or DESTINATION header is missing.
    """
    # record end-to-end travel time using the first route entry's timestamp
    route = message.get_header(MessageHeaderKey.ROUTE)
    if route:
        origin_name = route[0][0]
        t0 = route[0][1]
        time_taken = time.time() - t0
        self.msg_travel_stats_pool.record_value(
            category=f"{origin_name}#{self._stats_category(message)}", value=time_taken
        )
    self.logger.debug(f"{self.my_info.fqcn}: received message: {message.headers}")
    message.set_prop(MessagePropKey.ENDPOINT, endpoint)
    if connection:
        # expose the peer's TLS common name (if any) to message handlers
        conn_props = connection.get_conn_properties()
        cn = conn_props.get(DriverParams.PEER_CN.value)
        if cn:
            message.set_prop(MessagePropKey.COMMON_NAME, cn)
    msg_type = message.get_header(MessageHeaderKey.MSG_TYPE)
    if not msg_type:
        raise RuntimeError(format_log_message(self.my_info.fqcn, message, "missing MSG_TYPE in received message"))
    origin = message.get_header(MessageHeaderKey.ORIGIN)
    if not origin:
        raise RuntimeError(
            format_log_message(self.my_info.fqcn, message, "missing ORIGIN header in received message")
        )
    # is this msg for me?
    destination = message.get_header(MessageHeaderKey.DESTINATION)
    if not destination:
        raise RuntimeError(
            format_log_message(self.my_info.fqcn, message, "missing DESTINATION header in received message")
        )
    self.received_msg_counter_pool.increment(
        category=self._stats_category(message), counter_name=_CounterName.RECEIVED
    )
    # the interceptor may stop a REQ before any routing/dispatching
    if msg_type == MessageType.REQ and self.message_interceptor is not None:
        reply = self._try_cb(
            message, self.message_interceptor, *self.message_interceptor_args, **self.message_interceptor_kwargs
        )
        if reply:
            self.logger.debug(f"{self.my_info.fqcn}: interceptor stopped message!")
            reply_expected = message.get_header(MessageHeaderKey.REPLY_EXPECTED)
            if not reply_expected:
                return
            # return the interceptor's reply as a RETURN message to the sender
            req_id = message.get_header(MessageHeaderKey.REQ_ID, "")
            reply.add_headers(
                {
                    MessageHeaderKey.ORIGINAL_HEADERS: message.headers,
                    MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
                    MessageHeaderKey.TO_CELL: endpoint.name,
                    MessageHeaderKey.ORIGIN: self.my_info.fqcn,
                    MessageHeaderKey.DESTINATION: origin,
                    MessageHeaderKey.REQ_ID: [req_id],
                    MessageHeaderKey.MSG_TYPE: MessageType.RETURN,
                    MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())],
                    MessageHeaderKey.RETURN_REASON: ReturnReason.INTERCEPT,
                }
            )
            self._send_reply(reply, endpoint)
            self.logger.debug(f"{self.my_info.fqcn}: returned intercepted message")
            return
    if destination != self.my_info.fqcn:
        # not for me - need to forward it
        self.sent_msg_counter_pool.increment(
            category=self._stats_category(message), counter_name=_CounterName.FORWARD
        )
        self.received_msg_counter_pool.increment(
            category=self._stats_category(message), counter_name=_CounterName.FORWARD
        )
        self._forward(endpoint, origin, destination, msg_type, message)
        return
    self.received_msg_size_pool.record_value(
        category=self._stats_category(message), value=self._msg_size_mbs(message)
    )
    # this message is for me
    self._add_to_route(message)
    # handle ad-hoc
    my_conn_url = None
    if msg_type in [MessageType.REQ, MessageType.REPLY]:
        from_cell = message.get_header(MessageHeaderKey.FROM_CELL)
        oi = FqcnInfo(origin)
        if from_cell != origin and not same_family(oi, self.my_info):
            # this is a forwarded message, so no direct path from the origin to me
            conn_url = message.get_header(MessageHeaderKey.CONN_URL)
            if conn_url:
                # the origin already has a listener
                # create an ad-hoc connector to connect to the origin cell
                self.logger.debug(f"{self.my_info.fqcn}: creating adhoc connector to {origin} at {conn_url}")
                self._add_adhoc_connector(origin, conn_url)
            elif msg_type == MessageType.REQ:
                # see whether we can offer a listener
                allow_adhoc = self.connector_manager.is_adhoc_allowed(oi, self.my_info)
                if (
                    allow_adhoc
                    and (not oi.is_on_server)
                    and (self.my_info.fqcn > origin or self.my_info.is_on_server)
                ):
                    self.logger.debug(f"{self.my_info.fqcn}: trying to offer ad-hoc listener to {origin}")
                    listener = self._create_external_listener("")
                    if listener:
                        # advertise my listener URL on the reply (set below)
                        my_conn_url = listener.get_connection_url()
    if msg_type == MessageType.REQ:
        # this is a request for me - dispatch to the right CB
        channel = message.get_header(MessageHeaderKey.CHANNEL, "")
        topic = message.get_header(MessageHeaderKey.TOPIC, "")
        reply = self._process_request(origin, message)
        if not reply:
            self.received_msg_counter_pool.increment(
                category=self._stats_category(message), counter_name=_CounterName.REPLY_NONE
            )
            return
        # reply inherits the request's OPTIONAL flag
        is_optional = message.get_header(MessageHeaderKey.OPTIONAL, False)
        reply.set_header(MessageHeaderKey.OPTIONAL, is_optional)
        reply_expected = message.get_header(MessageHeaderKey.REPLY_EXPECTED, False)
        if not reply_expected:
            # this is fire and forget
            self.logger.debug(f"{self.my_info.fqcn}: don't send response - request expects no reply")
            self.received_msg_counter_pool.increment(
                category=self._stats_category(message), counter_name=_CounterName.REPLY_NOT_EXPECTED
            )
            return
        # send the reply back
        if not reply.headers.get(MessageHeaderKey.RETURN_CODE):
            self.logger.debug(f"{self.my_info.fqcn}: added return code OK")
            reply.set_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
        req_id = message.get_header(MessageHeaderKey.REQ_ID, "")
        reply.add_headers(
            {
                MessageHeaderKey.CHANNEL: channel,
                MessageHeaderKey.TOPIC: topic,
                MessageHeaderKey.FROM_CELL: self.my_info.fqcn,
                MessageHeaderKey.TO_CELL: endpoint.name,
                MessageHeaderKey.ORIGIN: self.my_info.fqcn,
                MessageHeaderKey.DESTINATION: origin,
                MessageHeaderKey.REQ_ID: req_id,
                MessageHeaderKey.MSG_TYPE: MessageType.REPLY,
                MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())],
            }
        )
        if my_conn_url:
            reply.set_header(MessageHeaderKey.CONN_URL, my_conn_url)
        # invoke outgoing reply filters
        reply_filters = self.out_reply_filter_reg.find(channel, topic)
        if reply_filters:
            self.logger.debug(f"{self.my_info.fqcn}: invoking outgoing reply filters")
            assert isinstance(reply_filters, list)
            for f in reply_filters:
                assert isinstance(f, Callback)
                r = self._try_cb(reply, f.cb, *f.args, **f.kwargs)
                if r:
                    # a filter replaced the reply: stop filtering
                    reply = r
                    break
        self._send_reply(reply, endpoint)
    else:
        # the message is either a reply or a return for a previous request: handle replies
        self._process_reply(origin, message, msg_type)
def _send_reply(self, reply: Message, endpoint: Endpoint):
    """Send a reply back to the given endpoint and record outcome counters."""
    self.logger.debug(f"{self.my_info.fqcn}: sending reply back to {endpoint.name}")
    self.logger.debug(f"Reply message: {reply.headers}")
    err = self._send_to_endpoint(endpoint, reply)
    category = self._stats_category(reply)
    if err:
        self.log_error(f"error sending reply back to {endpoint.name}: {err}", reply)
        self.received_msg_counter_pool.increment(category=category, counter_name=err)
    else:
        self.received_msg_counter_pool.increment(category=category, counter_name=_CounterName.REPLIED)
        rc = reply.get_header(MessageHeaderKey.RETURN_CODE)
        self.received_msg_counter_pool.increment(category=category, counter_name=rc)
def _check_bulk(self):
while not self.asked_to_stop:
with self.bulk_lock:
for _, sender in self.bulk_senders.items():
sender.send()
time.sleep(self.bulk_check_interval)
# force everything to be flushed
with self.bulk_lock:
for _, sender in self.bulk_senders.items():
sender.send()
def state_change(self, endpoint: Endpoint):
    """EndpointMonitor callback: react to an endpoint's state change.

    On READY, ensure a CellAgent exists for the endpoint and invoke
    cell_connected_cb (if configured). On CLOSING/DISCONNECTED/IDLE, remove
    the endpoint's agent and invoke cell_disconnected_cb (only if an agent
    was actually removed). Callback exceptions are logged and swallowed so
    the communication layer keeps running; callbacks are invoked outside
    agent_lock to avoid holding the lock during user code.

    Args:
        endpoint: the endpoint whose state changed.
    """
    self.logger.debug(f"========= {self.my_info.fqcn}: EP {endpoint.name} state changed to {endpoint.state}")
    fqcn = endpoint.name
    if endpoint.state == EndpointState.READY:
        # create the CellAgent for this endpoint.
        # Fix: the original checked self.agents.get(fqcn) OUTSIDE agent_lock
        # and only inserted under the lock, so two concurrent READY events
        # for the same fqcn could each create a CellAgent. Do the whole
        # check-and-create under the lock instead.
        with self.agent_lock:
            agent = self.agents.get(fqcn)
            created = agent is None
            if created:
                agent = CellAgent(fqcn, endpoint)
                self.agents[fqcn] = agent
            else:
                # refresh the endpoint on the existing agent
                agent.endpoint = endpoint
        if created:
            self.logger.debug(f"{self.my_info.fqcn}: created CellAgent for {fqcn}")
        else:
            self.logger.debug(f"{self.my_info.fqcn}: found existing CellAgent for {fqcn} - shouldn't happen")
        if self.cell_connected_cb is not None:
            try:
                self.logger.debug(f"{self.my_info.fqcn}: calling cell_connected_cb")
                self.cell_connected_cb(agent, *self.cell_connected_cb_args, **self.cell_connected_cb_kwargs)
            except Exception as ex:
                self.log_error(
                    f"exception in cell_connected_cb: {secure_format_exception(ex)}", None, log_except=True
                )
    elif endpoint.state in [EndpointState.CLOSING, EndpointState.DISCONNECTED, EndpointState.IDLE]:
        # remove this agent
        with self.agent_lock:
            agent = self.agents.pop(fqcn, None)
            self.logger.debug(f"{self.my_info.fqcn}: removed CellAgent {fqcn}")
        if agent and self.cell_disconnected_cb is not None:
            try:
                self.logger.debug(f"{self.my_info.fqcn}: calling cell_disconnected_cb")
                self.cell_disconnected_cb(
                    agent, *self.cell_disconnected_cb_args, **self.cell_disconnected_cb_kwargs
                )
            except Exception as ex:
                self.log_error(
                    f"exception in cell_disconnected_cb: {secure_format_exception(ex)}", None, log_except=True
                )
def get_sub_cell_names(self) -> Tuple[List[str], List[str]]:
    """Get cell FQCNs of all subs: children, or top-level client cells (if my cell is server).

    Returns:
        A tuple of (FQCNs of child cells, FQCNs of top-level client cells).
    """
    # dicts are used as insertion-ordered de-dup sets of FQCNs
    children = {}
    clients = {}

    def classify(fqcn, info):
        # route the fqcn into the right bucket based on its relation to me
        kind = self._is_my_sub(info)
        if kind == self.SUB_TYPE_CHILD:
            children[fqcn] = True
        elif kind == self.SUB_TYPE_CLIENT:
            clients[fqcn] = True

    with self.agent_lock:
        for fqcn, agent in self.agents.items():
            classify(fqcn, agent.info)

    # check local cells
    for fqcn in self.ALL_CELLS.keys():
        classify(fqcn, FqcnInfo(fqcn))

    return list(children), list(clients)
def _is_my_sub(self, candidate_info: FqcnInfo) -> int:
    """Classify how the candidate cell relates to this cell.

    Returns:
        SUB_TYPE_CHILD if the candidate is a direct child of this cell;
        SUB_TYPE_CLIENT if this cell is the server root and the candidate
        is a client root cell; SUB_TYPE_NONE otherwise.
    """
    if FQCN.is_parent(self.my_info.fqcn, candidate_info.fqcn):
        return self.SUB_TYPE_CHILD
    # only the server root treats client root cells as its subs
    i_am_server_root = self.my_info.is_root and self.my_info.is_on_server
    candidate_is_client_root = candidate_info.is_root and not candidate_info.is_on_server
    if i_am_server_root and candidate_is_client_root:
        return self.SUB_TYPE_CLIENT
    return self.SUB_TYPE_NONE
def is_secure(self):
    """Return this cell's `secure` flag as set at construction time.

    NOTE(review): presumably indicates whether internal connections use
    secure (TLS) mode — confirm against the cell's constructor.
    """
    return self.secure