# Source code for nvflare.fuel.f3.sfm.conn_manager

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import msgpack

from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.identity import CellIdentityResolver, get_param, is_admin_listener, is_mtls_connection
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection, ConnState, FrameReceiver
from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo, Mode
from nvflare.fuel.f3.drivers.driver import ConnMonitor, Driver
from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams
from nvflare.fuel.f3.drivers.net_utils import ssl_required
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState
from nvflare.fuel.f3.message import Message, MessageReceiver
from nvflare.fuel.f3.sfm.constants import HandshakeKeys, Types
from nvflare.fuel.f3.sfm.heartbeat_monitor import HeartbeatMonitor
from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix
from nvflare.fuel.f3.sfm.sfm_conn import SfmConnection
from nvflare.fuel.f3.sfm.sfm_endpoint import SfmEndpoint
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.utils.admin_name_utils import is_valid_admin_client_name
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_exception, secure_format_traceback

# Worker count of the executor that processes incoming frames and loopback messages
FRAME_THREAD_POOL_SIZE = 100
# Worker count of the executor that runs the connector start/retry loops
CONN_THREAD_POOL_SIZE = 16
# Initial wait (seconds) before a connector is retried; doubled after every retry
INIT_WAIT = 1
# Upper bound (seconds) of the retry wait. A driver run longer than this resets the wait to INIT_WAIT
MAX_WAIT = 10
# While the retry wait is below this value, failures and retries are logged at debug level only
SILENT_RECONNECT_TIME = 5
# Address reported as both local and peer address by NullConnection (loopback)
SELF_ADDR = "0.0.0.0:0"

# Module-level logger
log = logging.getLogger(__name__)

# Counter used by get_handle() to generate connector handles; handle_lock guards its increments
handle_lock = threading.Lock()
handle_count = 0


def get_handle():
    """Generate the next connector handle.

    The module-level counter is incremented while ``handle_lock`` is held and
    the new value is rendered as ``CH`` followed by the count, zero-padded to
    at least five digits (e.g. ``CH00001``).

    Returns:
        The newly generated handle string
    """
    global handle_count

    with handle_lock:
        handle_count += 1
        # Format while still holding the lock so the value just assigned is the one returned
        return f"CH{handle_count:05d}"
class ConnManager(ConnMonitor):
    """SFM connection manager

    The class is responsible for maintaining state of SFM connections and pumping data through them
    """

    def __init__(self, local_endpoint: Endpoint, identity_resolver: Optional[CellIdentityResolver] = None):
        """Create a connection manager for the local endpoint.

        Args:
            local_endpoint: The endpoint representing this process
            identity_resolver: Resolver used to match endpoint names against peer certificates.
                If not provided, a CellIdentityResolver for the local endpoint name is created
        """
        self.local_endpoint = local_endpoint
        self.identity_resolver = identity_resolver if identity_resolver else CellIdentityResolver(local_endpoint.name)

        # Active connectors
        self.connectors: Dict[str, ConnectorInfo] = {}

        # A dict of SFM connections, key is connection name
        self.sfm_conns: Dict[str, SfmConnection] = {}

        # A dict of SfmEndpoint for finding endpoint by name
        self.sfm_endpoints: Dict[str, SfmEndpoint] = {}

        # A list of Endpoint monitors
        self.monitors: List[EndpointMonitor] = []

        # App/receiver mapping
        self.receivers: Dict[int, MessageReceiver] = {}

        self.started = False
        self.stopped = False
        self.conn_mgr_executor = ThreadPoolExecutor(CONN_THREAD_POOL_SIZE, "conn_mgr")
        self.frame_mgr_executor = ThreadPoolExecutor(FRAME_THREAD_POOL_SIZE, "frame_mgr")
        self.lock = threading.Lock()
        self.null_conn = NullConnection()

        # Reuse the stats pool if it already exists, otherwise create it
        stats = StatsPoolManager.get_pool("sfm_send_frame")
        if not stats:
            stats = StatsPoolManager.add_time_hist_pool(
                "sfm_send_frame", "SFM send_frame time in secs", scope=local_endpoint.name
            )
        self.send_frame_stats = stats

        self.heartbeat_monitor = HeartbeatMonitor(self.sfm_conns)

    def add_connector(self, driver: Driver, params: dict, mode: Mode) -> str:
        """Register a connector and start it if the manager is already started.

        Args:
            driver: The driver used by the connector
            params: Driver parameters
            mode: Connector mode (active or passive)

        Returns:
            The handle of the new connector

        Raises:
            CommError: If the parameters require SSL but the driver doesn't support it
        """

        # Validate parameters
        capabilities = driver.capabilities()
        support_ssl = capabilities.get(DriverCap.SUPPORT_SSL, False)

        if ssl_required(params) and not support_ssl:
            scheme = params.get(DriverParams.SCHEME.value, "Unknown")
            raise CommError(
                CommError.BAD_CONFIG,
                f"Connector with scheme {scheme} requires SSL but " f"driver {driver.get_name()} doesn't support it",
            )

        handle = get_handle()
        connector = ConnectorInfo(handle, driver, params, mode, 0, 0, False, threading.Event())
        driver.register_conn_monitor(self)
        with self.lock:
            self.connectors[handle] = connector

        log.debug(f"Connector {connector} is created")

        if self.started:
            self.start_connector(connector)

        return handle

    def remove_connector(self, handle: str):
        """Stop and remove the connector identified by handle.

        An unknown handle is logged as an error, no exception is raised.

        Args:
            handle: The handle returned by add_connector
        """
        with self.lock:
            connector = self.connectors.pop(handle, None)
            if connector:
                connector.stopped.set()
                connector.driver.shutdown()
                log.debug(f"Connector {connector} is removed")
            else:
                log.error(f"Unknown connector handle: {handle}")

    def start(self):
        """Start all connectors not yet started (in handle order) and the heartbeat monitor"""
        with self.lock:
            for handle in sorted(self.connectors):
                connector = self.connectors[handle]
                if not connector.started:
                    self.start_connector(connector)

        self.heartbeat_monitor.start()
        self.started = True

    def stop(self):
        """Stop the heartbeat monitor, shut down all connectors and wait for both executors to finish"""
        self.heartbeat_monitor.stop()

        with self.lock:
            for handle in sorted(self.connectors):
                connector = self.connectors[handle]
                connector.stopped.set()
                connector.driver.shutdown()

        # Mark as stopped before shutting down the executors so pending frame tasks return early
        self.stopped = True
        self.conn_mgr_executor.shutdown(True)
        self.frame_mgr_executor.shutdown(True)

    def find_endpoint(self, name: str) -> Optional[Endpoint]:
        """Find endpoint by name.

        Args:
            name: Endpoint name

        Returns:
            The endpoint, or None if no endpoint with this name is known
        """
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist")
            return None

        return sfm_endpoint.endpoint

    def remove_endpoint(self, name: str):
        """Close all connections of the endpoint and remove it.

        Args:
            name: Endpoint name. Nothing is done if the endpoint doesn't exist
        """
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist or already removed")
            return

        # Iterate over a snapshot: closing a connection may lead to close_connection(),
        # which removes the connection from sfm_endpoint.connections
        for sfm_conn in list(sfm_endpoint.connections):
            sfm_conn.conn.close()

        # The endpoint may have been removed by another thread in the meantime
        self.sfm_endpoints.pop(name, None)
        log.debug(f"Endpoint {name} is removed")

    def get_connections(self, name: str) -> Optional[List[SfmConnection]]:
        """Get all connections of an endpoint.

        Args:
            name: Endpoint name

        Returns:
            The connections of the endpoint, or None if the endpoint doesn't exist
        """
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist")
            return None

        return sfm_endpoint.connections

    def send_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Send a message to endpoint for app

        The message is asynchronous, no response is expected.

        Args:
            endpoint: An endpoint to send the message to
            app_id: Application ID
            headers: headers, optional
            payload: message payload, optional

        Raises:
            CommError: If any error happens while sending the data
        """

        # Flatten buffer list so drivers don't have to deal with it
        if isinstance(payload, list):
            flat_payload = BufferList(payload).flatten()
        else:
            flat_payload = payload

        if endpoint.name == self.local_endpoint.name:
            self.send_loopback_message(endpoint, app_id, headers, flat_payload)
            return

        sfm_endpoint = self.sfm_endpoints.get(endpoint.name)
        if not sfm_endpoint:
            raise CommError(CommError.CLOSED, f"Endpoint {endpoint.name} not available, may be disconnected")

        state = sfm_endpoint.endpoint.state
        if state != EndpointState.READY:
            raise CommError(CommError.NOT_READY, f"Endpoint {endpoint.name} is not ready: {state}")

        stream_id = sfm_endpoint.next_stream_id()

        # When multiple connections, round-robin by stream ID
        sfm_conn = sfm_endpoint.get_connection(stream_id)
        if not sfm_conn:
            log.error("Logic error, ready endpoint has no connections")
            raise CommError(CommError.ERROR, f"Endpoint {endpoint.name} has no connection")

        # TODO: If multiple connections, should retry a diff connection on errors
        start = time.perf_counter()
        sfm_conn.send_data(app_id, stream_id, headers, flat_payload)
        self.send_frame_stats.record_value(
            category=sfm_conn.conn.connector.driver.get_name(), value=time.perf_counter() - start
        )

    def register_message_receiver(self, app_id: int, receiver: MessageReceiver):
        """Register the receiver for an app.

        Args:
            app_id: Application ID
            receiver: The receiver to be called for messages of this app

        Raises:
            CommError: If a receiver is already registered for the app
        """
        if self.receivers.get(app_id):
            raise CommError(CommError.BAD_CONFIG, f"Receiver for app {app_id} is already registered")

        self.receivers[app_id] = receiver

    def add_endpoint_monitor(self, monitor: EndpointMonitor):
        """Add a monitor to be notified of endpoint state changes"""
        self.monitors.append(monitor)

    # Internal methods

    def start_connector(self, connector: ConnectorInfo):
        """Start connector in a new thread"""

        if connector.started:
            return

        log.info(f"Connector {connector} is starting")

        try:
            self.conn_mgr_executor.submit(self.start_connector_task, connector)
        except RuntimeError:
            # submit() raises RuntimeError after the executor is shut down
            log.debug("Connector start skipped — executor already shut down")

    @staticmethod
    def start_connector_task(connector: ConnectorInfo):
        """Start connector in a new thread

        This function will loop as long as connector is not stopped
        """

        connector.started = True
        if connector.mode == Mode.ACTIVE:
            starter = connector.driver.connect
        else:
            starter = connector.driver.listen

        wait = INIT_WAIT
        while not connector.stopped.is_set():
            # Monotonic clock: the run time must not be affected by wall-clock adjustments
            start_time = time.monotonic()
            try:
                starter(connector)
            except Exception as ex:
                fail_msg = (
                    f"Connector {connector} failed with exception {type(ex).__name__}: {secure_format_exception(ex)}"
                )
                if wait < SILENT_RECONNECT_TIME:
                    log.debug(fail_msg)
                else:
                    log.error(fail_msg)

            if connector.stopped.is_set():
                log.debug(f"Connector {connector} has stopped")
                break

            # After a long run, resetting wait
            run_time = time.monotonic() - start_time
            if run_time > MAX_WAIT:
                log.debug(f"Driver for {connector} had a long run ({run_time} sec), resetting wait")
                wait = INIT_WAIT

            reconnect_msg = f"Retrying {connector} in {wait} seconds"
            # First few retries may happen in normal shutdown, show it as debug
            if wait < SILENT_RECONNECT_TIME:
                log.debug(reconnect_msg)
            else:
                log.info(reconnect_msg)

            # Waiting on the event (instead of sleeping) ends the wait as soon as the connector is stopped
            connector.stopped.wait(wait)

            # Exponential backoff
            wait = min(wait * 2, MAX_WAIT)

    def state_change(self, connection: Connection):
        """Handle a connection state change and keep the connector's connection counts up to date.

        All exceptions are logged and suppressed so they don't propagate to the caller.

        Args:
            connection: The connection whose state has changed
        """
        try:
            state = connection.state
            connector = connection.connector
            if state == ConnState.CONNECTED:
                log.info(f"Connection {connection} is created: PID: {os.getpid()}")
                self.handle_new_connection(connection)
                with self.lock:
                    connector.total_conns += 1
                    connector.curr_conns += 1
            elif state == ConnState.CLOSED:
                log.info(f"Connection {connection} is closed PID: {os.getpid()}")
                self.close_connection(connection)
                with self.lock:
                    connector.curr_conns -= 1
            else:
                log.error(f"Unknown state: {state}")
        except Exception as ex:
            log.error(f"Error handling state change: {secure_format_exception(ex)}")
            log.debug(secure_format_traceback())

    def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike):
        """Decode one frame and dispatch it by type.

        HELLO/READY frames update the endpoint (HELLO is answered with READY), PING is answered
        with PONG, DATA is delivered to the receiver registered for the app ID. All exceptions
        are logged and suppressed.

        Args:
            sfm_conn: The connection the frame was received on
            frame: The raw frame, starting with the prefix
        """
        if self.stopped:
            return

        try:
            prefix = Prefix.from_bytes(frame)
            log.debug(f"Received frame: {prefix} on {sfm_conn.conn}")

            if prefix.header_len == 0:
                headers = None
            else:
                headers = msgpack.unpackb(frame[PREFIX_LEN : PREFIX_LEN + prefix.header_len])

            if prefix.type in (Types.HELLO, Types.READY):
                if prefix.type == Types.HELLO:
                    sfm_conn.send_handshake(Types.READY)

                data = self.get_dict_payload(prefix, frame)
                self.update_endpoint(sfm_conn, data)

            elif prefix.type == Types.PING:
                sfm_conn.send_heartbeat(Types.PONG)

            elif prefix.type == Types.PONG:
                log.debug(f"PONG received for {sfm_conn.conn}")
                # No action is needed for PONG. The last_activity is already updated

            elif prefix.type == Types.DATA:
                if prefix.length > PREFIX_LEN + prefix.header_len:
                    payload = frame[PREFIX_LEN + prefix.header_len :]
                else:
                    payload = None

                message = Message(headers, payload)
                receiver = self.receivers.get(prefix.app_id)
                if receiver:
                    receiver.process_message(sfm_conn.sfm_endpoint.endpoint, sfm_conn.conn, prefix.app_id, message)
                else:
                    log.debug(f"No receiver registered for App ID {prefix.app_id}, message ignored")
            else:
                log.error(f"Received unsupported frame type {prefix.type} on {sfm_conn.get_name()}")
        except RuntimeError as ex:
            # A RuntimeError after stop() is treated as a side effect of the shutdown, not as an error
            if self.stopped:
                log.debug(f"Frame processing interrupted by shutdown: {secure_format_exception(ex)}")
            else:
                log.error(f"Error processing frame: {secure_format_exception(ex)}")
                log.debug(secure_format_traceback())
        except Exception as ex:
            log.error(f"Error processing frame: {secure_format_exception(ex)}")
            log.debug(secure_format_traceback())

    def process_frame(self, sfm_conn: SfmConnection, frame: BytesAlike):
        """Queue a received frame for processing on the frame executor.

        Frames received after the manager is stopped are dropped.

        Args:
            sfm_conn: The connection the frame was received on
            frame: The raw frame
        """
        if self.stopped:
            log.debug(f"Frame received after shutdown for connection {sfm_conn.get_name()}")
            return

        try:
            self.frame_mgr_executor.submit(self.process_frame_task, sfm_conn, frame)
        except RuntimeError:
            # The executor was shut down between the check above and submit()
            log.debug(f"Frame received after shutdown for connection {sfm_conn.get_name()}")

    def update_endpoint(self, sfm_conn: SfmConnection, data: dict):
        """Create or update the endpoint described by a handshake and attach the connection to it.

        Args:
            sfm_conn: The connection the handshake was received on
            data: Handshake payload. The endpoint name is removed from it, the remaining
                entries are passed to the new Endpoint

        Raises:
            CommError: If the endpoint name is missing, invalid, identical to the local endpoint name,
                an admin endpoint connects through a non-admin listener, or the name doesn't match
                the peer's identity on an mTLS connection
        """

        # A default is needed here, otherwise a missing key raises KeyError instead of the CommError below
        endpoint_name = data.pop(HandshakeKeys.ENDPOINT_NAME, None)
        if not endpoint_name:
            raise CommError(CommError.BAD_DATA, f"Handshake without endpoint name for connection {sfm_conn.get_name()}")

        endpoint_name = FQCN.normalize(endpoint_name)
        err = FQCN.validate(endpoint_name)
        if err:
            sfm_conn.conn.close()
            raise CommError(
                CommError.BAD_DATA,
                f"Invalid endpoint name '{endpoint_name}' for connection {sfm_conn.get_name()}: {err}",
            )

        if endpoint_name == self.local_endpoint.name:
            raise CommError(
                CommError.BAD_DATA, f"Duplicate endpoint name {endpoint_name} for connection {sfm_conn.get_name()}"
            )

        conn_props = sfm_conn.conn.get_conn_properties()

        # Passive mTLS connections must always present the connecting peer's CN.
        # Active mTLS connections enforce the same binding when the driver exposes
        # a real peer CN; gRPC active-side connections may report "N/A"/None.
        if is_mtls_connection(sfm_conn.conn.connector.params):
            peer_cn = get_param(conn_props, DriverParams.PEER_CN)
            if sfm_conn.conn.connector.mode == Mode.PASSIVE or (peer_cn and peer_cn != "N/A"):
                if (
                    is_valid_admin_client_name(endpoint_name)
                    and sfm_conn.conn.connector.mode == Mode.PASSIVE
                    and not is_admin_listener(sfm_conn.conn.connector.params)
                ):
                    sfm_conn.conn.close()
                    raise CommError(
                        CommError.BAD_DATA,
                        f"Admin endpoint '{endpoint_name}' can only connect through an admin listener",
                    )

                try:
                    self.identity_resolver.require_match(endpoint_name, peer_cn, f"connection {sfm_conn.get_name()}")
                except ValueError as ex:
                    sfm_conn.conn.close()
                    raise CommError(CommError.BAD_DATA, str(ex)) from ex

        endpoint = Endpoint(endpoint_name, data)
        endpoint.state = EndpointState.READY
        if conn_props:
            endpoint.conn_props.update(conn_props)

        sfm_endpoint = self.sfm_endpoints.get(endpoint_name)
        if sfm_endpoint:
            old_state = sfm_endpoint.endpoint.state
            sfm_endpoint.endpoint = endpoint
        else:
            old_state = EndpointState.IDLE
            sfm_endpoint = SfmEndpoint(endpoint)

        sfm_endpoint.add_connection(sfm_conn)
        sfm_conn.sfm_endpoint = sfm_endpoint
        self.sfm_endpoints[endpoint_name] = sfm_endpoint

        if endpoint.state != old_state:
            self.notify_monitors(endpoint)

    def notify_monitors(self, endpoint: Endpoint):
        """Call state_change of every registered endpoint monitor with the endpoint"""
        if not self.monitors:
            log.debug("No endpoint monitor registered")
            return

        for monitor in self.monitors:
            monitor.state_change(endpoint)

    @staticmethod
    def get_dict_payload(prefix, frame):
        """Unpack the msgpack payload that follows the prefix and headers in the frame.

        Args:
            prefix: The decoded prefix of the frame
            frame: The raw frame

        Returns:
            The object decoded by msgpack from the payload
        """
        # Slicing a memoryview avoids copying the payload
        mv = frame if isinstance(frame, memoryview) else memoryview(frame)
        return msgpack.unpackb(mv[(PREFIX_LEN + prefix.header_len) :])

    def handle_new_connection(self, connection: Connection):
        """Wrap a new connection in an SfmConnection, register it and start the handshake if active.

        Args:
            connection: The newly created connection
        """
        sfm_conn = SfmConnection(connection, self.local_endpoint)
        with self.lock:
            self.sfm_conns[sfm_conn.get_name()] = sfm_conn

        connection.register_frame_receiver(SfmFrameReceiver(self, sfm_conn))

        # Only the active side sends HELLO, which is answered with READY in process_frame_task
        if connection.connector.mode == Mode.ACTIVE:
            sfm_conn.send_handshake(Types.HELLO)

    def close_connection(self, connection: Connection):
        """Remove a closed connection and update the state of its endpoint.

        The endpoint stays READY as long as it has other connections, otherwise it becomes
        DISCONNECTED. Monitors are notified if the state has changed.

        Args:
            connection: The closed connection
        """
        with self.lock:
            name = connection.name
            if name not in self.sfm_conns:
                log.debug(f"Connection {name} has closed with no endpoint assigned")
                return

            sfm_conn = self.sfm_conns.pop(name)
            sfm_endpoint = sfm_conn.sfm_endpoint
            if sfm_endpoint is None:
                log.debug(f"Connection {name} is closed before SFM handshake")
                return

            old_state = sfm_endpoint.endpoint.state
            sfm_endpoint.remove_connection(sfm_conn)

            state = EndpointState.READY if sfm_endpoint.connections else EndpointState.DISCONNECTED
            sfm_endpoint.endpoint.state = state
            if old_state != state:
                self.notify_monitors(sfm_endpoint.endpoint)

    def send_loopback_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Send message to itself"""

        if self.stopped:
            return

        # Call receiver in a different thread to avoid deadlock
        try:
            self.frame_mgr_executor.submit(self.loopback_message_task, endpoint, app_id, headers, payload)
        except RuntimeError as e:
            log.debug(f"Loopback submit skipped: {e}")

    def loopback_message_task(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Deliver a loopback message to the receiver registered for the app.

        The receiver is called with the null connection. Exceptions raised by the receiver
        are logged and suppressed.

        Args:
            endpoint: The endpoint the message was sent to (the local endpoint)
            app_id: Application ID
            headers: headers, optional
            payload: message payload
        """
        receiver = self.receivers.get(app_id)
        if not receiver:
            log.debug(f"No receiver registered for App ID {app_id}, loopback message ignored")
            return

        try:
            receiver.process_message(endpoint, self.null_conn, app_id, Message(headers, payload))
        except Exception as ex:
            log.error(f"Loopback message error: {secure_format_exception(ex)}")
class SfmFrameReceiver(FrameReceiver):
    """Frame receiver that forwards the frames of one SFM connection to the connection manager"""

    def __init__(self, conn_manager: ConnManager, conn: SfmConnection):
        """Bind the receiver to a manager and a connection.

        Args:
            conn_manager: The manager the frames are handed to
            conn: The SFM connection the frames are received on
        """
        self.conn = conn
        self.conn_manager = conn_manager

    def process_frame(self, frame: BytesAlike):
        """Record activity on the connection and hand the frame to the manager.

        Exceptions raised by the manager are logged and suppressed.

        Args:
            frame: The received frame
        """
        # The activity timestamp is updated for every frame, before it's processed
        self.conn.last_activity = time.time()

        manager = self.conn_manager
        try:
            manager.process_frame(self.conn, frame)
        except Exception as error:
            log.error(f"Error processing frame: {secure_format_exception(error)}")
            log.debug(secure_format_traceback())
class NullConnection(Connection):
    """A mock connection used for loopback messages"""

    def __init__(self):
        """Create the connection on a placeholder connector named "Null" that has no driver"""
        stop_event = threading.Event()
        null_connector = ConnectorInfo("Null", None, {}, Mode.ACTIVE, 0, 0, False, stop_event)
        super().__init__(null_connector)

    def get_conn_properties(self) -> dict:
        """Return connection properties with both local and peer address set to SELF_ADDR"""
        addr_keys = (DriverParams.LOCAL_ADDR.value, DriverParams.PEER_ADDR.value)
        return dict.fromkeys(addr_keys, SELF_ADDR)

    def close(self):
        """Do nothing, there is nothing to close"""

    def send_frame(self, frame: BytesAlike):
        """Always fail, frames can't be sent on this connection.

        Raises:
            CommError: Always, with code NOT_SUPPORTED
        """
        raise CommError(CommError.NOT_SUPPORTED, "Can't send data on Null connection")