# Source code for nvflare.fuel.f3.sfm.conn_manager

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import msgpack

from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection, ConnState, FrameReceiver
from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo, Mode
from nvflare.fuel.f3.drivers.driver import ConnMonitor, Driver
from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams
from nvflare.fuel.f3.drivers.net_utils import ssl_required
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState
from nvflare.fuel.f3.message import Message, MessageReceiver
from nvflare.fuel.f3.sfm.constants import HandshakeKeys, Types
from nvflare.fuel.f3.sfm.heartbeat_monitor import HeartbeatMonitor
from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix
from nvflare.fuel.f3.sfm.sfm_conn import SfmConnection
from nvflare.fuel.f3.sfm.sfm_endpoint import SfmEndpoint
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_exception, secure_format_traceback

# Max worker threads for processing incoming frames and loopback messages
FRAME_THREAD_POOL_SIZE = 100
# Max worker threads running connector connect/listen loops
CONN_THREAD_POOL_SIZE = 16
# Initial reconnect back-off (seconds)
INIT_WAIT = 1
# Upper bound for the reconnect back-off (seconds); a driver run longer than this resets the back-off
MAX_WAIT = 10
# Back-off values below this (seconds) are logged at debug level only
SILENT_RECONNECT_TIME = 5
# Address reported for both ends of the loopback (null) connection
SELF_ADDR = "0.0.0.0:0"

log = logging.getLogger(__name__)

# Guards handle_count so concurrently created connectors get distinct handles
handle_lock = threading.Lock()
handle_count = 0


def get_handle():
    """Return a new, process-unique connector handle such as ``"CH00001"``.

    The counter is incremented and the handle formatted while the lock is held,
    so two concurrent callers can never observe the same counter value.
    """
    global handle_count
    with handle_lock:
        handle_count += 1
        handle = "CH%05d" % handle_count
    return handle
class ConnManager(ConnMonitor):
    """SFM connection manager

    The class is responsible for maintaining state of SFM connections and pumping data through them
    """

    def __init__(self, local_endpoint: Endpoint):
        self.local_endpoint = local_endpoint

        # Active connectors
        self.connectors: Dict[str, ConnectorInfo] = {}

        # A dict of SFM connections, key is connection name
        self.sfm_conns: Dict[str, SfmConnection] = {}

        # A dict of SfmEndpoint for finding endpoint by name
        self.sfm_endpoints: Dict[str, SfmEndpoint] = {}

        # A list of Endpoint monitors
        self.monitors: List[EndpointMonitor] = []

        # App/receiver mapping
        self.receivers: Dict[int, MessageReceiver] = {}

        # Read and written under self.lock so add_connector() and start() agree on it
        self.started = False
        self.conn_mgr_executor = ThreadPoolExecutor(CONN_THREAD_POOL_SIZE, "conn_mgr")
        self.frame_mgr_executor = ThreadPoolExecutor(FRAME_THREAD_POOL_SIZE, "frame_mgr")
        self.lock = threading.Lock()
        self.null_conn = NullConnection()

        # Reuse the stats pool if one with this name already exists
        stats = StatsPoolManager.get_pool("sfm_send_frame")
        if not stats:
            stats = StatsPoolManager.add_time_hist_pool(
                "sfm_send_frame", "SFM send_frame time in secs", scope=local_endpoint.name
            )
        self.send_frame_stats = stats
        self.heartbeat_monitor = HeartbeatMonitor(self.sfm_conns)

    def add_connector(self, driver: Driver, params: dict, mode: Mode) -> str:
        """Register a connector for ``driver`` and start it if the manager is running.

        Args:
            driver: Driver used to connect (ACTIVE) or listen (PASSIVE)
            params: Driver parameters
            mode: Connector mode

        Returns:
            The handle of the new connector, usable with remove_connector()

        Raises:
            CommError: BAD_CONFIG if params require SSL but the driver doesn't support it
        """
        # Validate parameters
        capabilities = driver.capabilities()
        support_ssl = capabilities.get(DriverCap.SUPPORT_SSL, False)

        if ssl_required(params) and not support_ssl:
            scheme = params.get(DriverParams.SCHEME.value, "Unknown")
            raise CommError(
                CommError.BAD_CONFIG,
                f"Connector with scheme {scheme} requires SSL but " f"driver {driver.get_name()} doesn't support it",
            )

        handle = get_handle()
        connector = ConnectorInfo(handle, driver, params, mode, 0, 0, False, threading.Event())
        driver.register_conn_monitor(self)
        with self.lock:
            self.connectors[handle] = connector
            log.debug(f"Connector {connector} is created")

            # Checked under the lock: start() sets `started` under the same lock, so the
            # connector is started either here or by start(), never by neither.
            if self.started:
                self.start_connector(connector)

        return handle

    def remove_connector(self, handle: str):
        """Stop the connector identified by ``handle``, shut down its driver and forget it."""
        with self.lock:
            connector = self.connectors.pop(handle, None)
            if connector:
                connector.stopped.set()
                connector.driver.shutdown()
                log.debug(f"Connector {connector} is removed")
            else:
                log.error(f"Unknown connector handle: {handle}")

    def start(self):
        """Start all not-yet-started connectors (in handle order) and the heartbeat monitor."""
        with self.lock:
            for handle in sorted(self.connectors.keys()):
                connector = self.connectors[handle]
                if not connector.started:
                    self.start_connector(connector)
            # Set while still holding the lock so a concurrent add_connector() can't slip
            # its connector in after the loop but before the flag flips.
            self.started = True

        self.heartbeat_monitor.start()

    def stop(self):
        """Stop the heartbeat monitor and all connectors, then wait for both executors to drain."""
        self.heartbeat_monitor.stop()
        with self.lock:
            for handle in sorted(self.connectors.keys()):
                connector = self.connectors[handle]
                connector.stopped.set()
                connector.driver.shutdown()

        self.conn_mgr_executor.shutdown(True)
        self.frame_mgr_executor.shutdown(True)

    def find_endpoint(self, name: str) -> Optional[Endpoint]:
        """Return the Endpoint registered under ``name``, or None if unknown."""
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist")
            return None
        return sfm_endpoint.endpoint

    def remove_endpoint(self, name: str):
        """Close every connection of the named endpoint and drop the endpoint."""
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist or already removed")
            return

        # Iterate over a snapshot: closing a connection may (depending on the driver) report
        # ConnState.CLOSED through state_change() -> close_connection(), which removes the
        # connection from sfm_endpoint.connections while this loop is still running.
        for sfm_conn in list(sfm_endpoint.connections):
            sfm_conn.conn.close()

        # The entry may already be gone if another thread removed it; that's not an error.
        self.sfm_endpoints.pop(name, None)
        log.debug(f"Endpoint {name} is removed")

    def get_connections(self, name: str) -> Optional[List[SfmConnection]]:
        """Return the SFM connections of the named endpoint, or None if the endpoint is unknown."""
        sfm_endpoint = self.sfm_endpoints.get(name)
        if not sfm_endpoint:
            log.debug(f"Endpoint {name} doesn't exist")
            return None

        return sfm_endpoint.connections

    def send_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Send a message to endpoint for app

        The message is asynchronous, no response is expected.

        Args:
            endpoint: An endpoint to send the message to
            app_id: Application ID
            headers: headers, optional
            payload: message payload, optional

        Raises:
            CommError: If any error happens while sending the data
        """
        # Flatten buffer list so drivers don't have to deal with it
        if isinstance(payload, list):
            flat_payload = BufferList(payload).flatten()
        else:
            flat_payload = payload

        if endpoint.name == self.local_endpoint.name:
            self.send_loopback_message(endpoint, app_id, headers, flat_payload)
            return

        sfm_endpoint = self.sfm_endpoints.get(endpoint.name)
        if not sfm_endpoint:
            raise CommError(CommError.CLOSED, f"Endpoint {endpoint.name} not available, may be disconnected")

        state = sfm_endpoint.endpoint.state
        if state != EndpointState.READY:
            raise CommError(CommError.NOT_READY, f"Endpoint {endpoint.name} is not ready: {state}")

        stream_id = sfm_endpoint.next_stream_id()

        # When multiple connections, round-robin by stream ID
        sfm_conn = sfm_endpoint.get_connection(stream_id)
        if not sfm_conn:
            log.error("Logic error, ready endpoint has no connections")
            raise CommError(CommError.ERROR, f"Endpoint {endpoint.name} has no connection")

        # TODO: If multiple connections, should retry a diff connection on errors
        start = time.perf_counter()
        sfm_conn.send_data(app_id, stream_id, headers, flat_payload)
        self.send_frame_stats.record_value(
            category=sfm_conn.conn.connector.driver.get_name(), value=time.perf_counter() - start
        )

    def register_message_receiver(self, app_id: int, receiver: MessageReceiver):
        """Register ``receiver`` for ``app_id``; raises CommError(BAD_CONFIG) if one is already registered."""
        if self.receivers.get(app_id):
            raise CommError(CommError.BAD_CONFIG, f"Receiver for app {app_id} is already registered")

        self.receivers[app_id] = receiver

    def add_endpoint_monitor(self, monitor: EndpointMonitor):
        """Add a monitor to be notified of endpoint state changes."""
        self.monitors.append(monitor)

    # Internal methods

    def start_connector(self, connector: ConnectorInfo):
        """Start connector in a new thread"""
        if connector.started:
            return

        log.info(f"Connector {connector} is starting")

        self.conn_mgr_executor.submit(self.start_connector_task, connector)

    @staticmethod
    def start_connector_task(connector: ConnectorInfo):
        """Start connector in a new thread

        This function will loop as long as connector is not stopped
        """

        connector.started = True
        if connector.mode == Mode.ACTIVE:
            starter = connector.driver.connect
        else:
            starter = connector.driver.listen

        wait = INIT_WAIT
        while not connector.stopped.is_set():
            start_time = time.time()
            try:
                starter(connector)
            except Exception as ex:
                fail_msg = (
                    f"Connector {connector} failed with exception {type(ex).__name__}: {secure_format_exception(ex)}"
                )
                if wait < SILENT_RECONNECT_TIME:
                    log.debug(fail_msg)
                else:
                    log.error(fail_msg)

            if connector.stopped.is_set():
                log.debug(f"Connector {connector} has stopped")
                break

            # After a long run, resetting wait
            run_time = time.time() - start_time
            if run_time > MAX_WAIT:
                log.debug(f"Driver for {connector} had a long run ({run_time} sec), resetting wait")
                wait = INIT_WAIT

            reconnect_msg = f"Retrying {connector} in {wait} seconds"
            # First few retries may happen in normal shutdown, show it as debug
            if wait < SILENT_RECONNECT_TIME:
                log.debug(reconnect_msg)
            else:
                log.info(reconnect_msg)

            # Interruptible sleep: returns early once the connector is stopped
            connector.stopped.wait(wait)

            # Exponential backoff, capped at MAX_WAIT
            wait = min(wait * 2, MAX_WAIT)

    def state_change(self, connection: Connection):
        """ConnMonitor callback: track connection creation/closure and update connector counters.

        Errors are logged, not raised.
        """
        try:
            state = connection.state
            connector = connection.connector
            if state == ConnState.CONNECTED:
                log.info(f"Connection {connection} is created: PID: {os.getpid()}")
                self.handle_new_connection(connection)
                with self.lock:
                    connector.total_conns += 1
                    connector.curr_conns += 1
            elif state == ConnState.CLOSED:
                log.info(f"Connection {connection} is closed PID: {os.getpid()}")
                self.close_connection(connection)
                with self.lock:
                    connector.curr_conns -= 1
            else:
                log.error(f"Unknown state: {state}")
        except Exception as ex:
            log.error(f"Error handling state change: {secure_format_exception(ex)}")
            log.debug(secure_format_traceback())

    def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike):
        """Decode one frame and dispatch it by type (handshake, heartbeat or data).

        Runs on the frame executor; errors are logged, not raised.
        """
        try:
            prefix = Prefix.from_bytes(frame)
            log.debug(f"Received frame: {prefix} on {sfm_conn.conn}")

            if prefix.header_len == 0:
                headers = None
            else:
                headers = msgpack.unpackb(frame[PREFIX_LEN : PREFIX_LEN + prefix.header_len])

            if prefix.type in (Types.HELLO, Types.READY):
                # Passive side answers HELLO with READY; both carry the peer's endpoint info
                if prefix.type == Types.HELLO:
                    sfm_conn.send_handshake(Types.READY)

                data = self.get_dict_payload(prefix, frame)
                self.update_endpoint(sfm_conn, data)

            elif prefix.type == Types.PING:
                sfm_conn.send_heartbeat(Types.PONG)

            elif prefix.type == Types.PONG:
                log.debug(f"PONG received for {sfm_conn.conn}")
                # No action is needed for PONG. The last_activity is already updated

            elif prefix.type == Types.DATA:
                if prefix.length > PREFIX_LEN + prefix.header_len:
                    payload = frame[PREFIX_LEN + prefix.header_len :]
                else:
                    payload = None

                message = Message(headers, payload)
                receiver = self.receivers.get(prefix.app_id)
                if receiver:
                    receiver.process_message(sfm_conn.sfm_endpoint.endpoint, sfm_conn.conn, prefix.app_id, message)
                else:
                    log.debug(f"No receiver registered for App ID {prefix.app_id}, message ignored")

            else:
                log.error(f"Received unsupported frame type {prefix.type} on {sfm_conn.get_name()}")

        except Exception as ex:
            log.error(f"Error processing frame: {secure_format_exception(ex)}")
            log.debug(secure_format_traceback())

    def process_frame(self, sfm_conn: SfmConnection, frame: BytesAlike):
        """Queue ``frame`` for processing on the frame executor."""
        self.frame_mgr_executor.submit(self.process_frame_task, sfm_conn, frame)

    def update_endpoint(self, sfm_conn: SfmConnection, data: dict):
        """Create or refresh the peer endpoint described by handshake ``data`` and attach ``sfm_conn`` to it.

        Raises:
            CommError: BAD_DATA if the endpoint name is missing/empty or equals the local endpoint's name
        """
        # Default to None so a missing key produces the descriptive CommError below instead of a bare KeyError
        endpoint_name = data.pop(HandshakeKeys.ENDPOINT_NAME, None)
        if not endpoint_name:
            raise CommError(CommError.BAD_DATA, f"Handshake without endpoint name for connection {sfm_conn.get_name()}")

        if endpoint_name == self.local_endpoint.name:
            raise CommError(
                CommError.BAD_DATA, f"Duplicate endpoint name {endpoint_name} for connection {sfm_conn.get_name()}"
            )

        # Remaining handshake entries become the endpoint's properties
        endpoint = Endpoint(endpoint_name, data)
        endpoint.state = EndpointState.READY
        conn_props = sfm_conn.conn.get_conn_properties()
        if conn_props:
            endpoint.conn_props.update(conn_props)

        sfm_endpoint = self.sfm_endpoints.get(endpoint_name)
        if sfm_endpoint:
            old_state = sfm_endpoint.endpoint.state
            sfm_endpoint.endpoint = endpoint
        else:
            old_state = EndpointState.IDLE
            sfm_endpoint = SfmEndpoint(endpoint)

        sfm_endpoint.add_connection(sfm_conn)
        sfm_conn.sfm_endpoint = sfm_endpoint
        self.sfm_endpoints[endpoint_name] = sfm_endpoint

        if endpoint.state != old_state:
            self.notify_monitors(endpoint)

    def notify_monitors(self, endpoint: Endpoint):
        """Call state_change(endpoint) on every registered endpoint monitor."""
        if not self.monitors:
            log.debug("No endpoint monitor registered")
            return

        for monitor in self.monitors:
            monitor.state_change(endpoint)

    @staticmethod
    def get_dict_payload(prefix, frame):
        """Unpack the msgpack payload that follows the prefix and headers (zero-copy slice)."""
        mv = memoryview(frame)
        return msgpack.unpackb(mv[(PREFIX_LEN + prefix.header_len) :])

    def handle_new_connection(self, connection: Connection):
        """Wrap a new connection in an SfmConnection, hook up frame receiving and, if ACTIVE, send HELLO."""
        sfm_conn = SfmConnection(connection, self.local_endpoint)
        with self.lock:
            self.sfm_conns[sfm_conn.get_name()] = sfm_conn

        connection.register_frame_receiver(SfmFrameReceiver(self, sfm_conn))

        if connection.connector.mode == Mode.ACTIVE:
            sfm_conn.send_handshake(Types.HELLO)

    def close_connection(self, connection: Connection):
        """Forget a closed connection and update its endpoint's state (READY or DISCONNECTED)."""
        with self.lock:
            name = connection.name
            if name not in self.sfm_conns:
                log.debug(f"Connection {name} has closed with no endpoint assigned")
                return

            sfm_conn = self.sfm_conns.pop(name)
            sfm_endpoint = sfm_conn.sfm_endpoint
            if sfm_endpoint is None:
                log.debug(f"Connection {name} is closed before SFM handshake")
                return

            old_state = sfm_endpoint.endpoint.state
            sfm_endpoint.remove_connection(sfm_conn)

            # The endpoint stays READY as long as it has at least one connection left
            state = EndpointState.READY if sfm_endpoint.connections else EndpointState.DISCONNECTED
            sfm_endpoint.endpoint.state = state
            if old_state != state:
                self.notify_monitors(sfm_endpoint.endpoint)

    def send_loopback_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Send message to itself"""

        # Call receiver in a different thread to avoid deadlock
        self.frame_mgr_executor.submit(self.loopback_message_task, endpoint, app_id, headers, payload)

    def loopback_message_task(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
        """Deliver a loopback message to the app's receiver over the null connection; errors are logged."""
        receiver = self.receivers.get(app_id)
        if not receiver:
            log.debug(f"No receiver registered for App ID {app_id}, loopback message ignored")
            return

        try:
            receiver.process_message(endpoint, self.null_conn, app_id, Message(headers, payload))
        except Exception as ex:
            log.error(f"Loopback message error: {secure_format_exception(ex)}")
class SfmFrameReceiver(FrameReceiver):
    """Frame receiver bound to one SFM connection that forwards frames to a ConnManager."""

    def __init__(self, conn_manager: ConnManager, conn: SfmConnection):
        self.conn_manager = conn_manager
        self.conn = conn

    def process_frame(self, frame: BytesAlike):
        """Stamp the connection's last activity time, then hand ``frame`` to the manager.

        Any exception from the manager call is logged instead of propagating to the caller.
        """
        sfm_conn = self.conn
        sfm_conn.last_activity = time.time()
        manager = self.conn_manager
        try:
            manager.process_frame(sfm_conn, frame)
        except Exception as err:
            log.error(f"Error processing frame: {secure_format_exception(err)}")
            log.debug(secure_format_traceback())
class NullConnection(Connection):
    """A mock connection used for loopback messages"""

    def __init__(self):
        # Placeholder connector: no driver, ACTIVE mode, zeroed counters, not started
        null_connector = ConnectorInfo("Null", None, {}, Mode.ACTIVE, 0, 0, False, threading.Event())
        super().__init__(null_connector)

    def get_conn_properties(self) -> dict:
        """Report SELF_ADDR as both the local and the peer address."""
        props = {}
        props[DriverParams.LOCAL_ADDR.value] = SELF_ADDR
        props[DriverParams.PEER_ADDR.value] = SELF_ADDR
        return props

    def close(self):
        """Nothing to release for a loopback connection."""
        pass

    def send_frame(self, frame: BytesAlike):
        """Always fails: loopback messages never go through a real connection."""
        raise CommError(CommError.NOT_SUPPORTED, "Can't send data on Null connection")