# Source code for nvflare.fuel.f3.communicator

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import logging
import os
import weakref
from typing import Optional, Tuple

from nvflare.fuel.f3 import drivers
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.drivers.driver import Driver
from nvflare.fuel.f3.drivers.driver_manager import DriverManager
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import parse_url
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor
from nvflare.fuel.f3.message import Message, MessageReceiver
from nvflare.fuel.f3.sfm.conn_manager import ConnManager, Mode
from nvflare.security.logging import secure_format_exception

log = logging.getLogger(__name__)
_running_instances = weakref.WeakSet()
driver_mgr = DriverManager()
driver_loaded = False


def load_comm_drivers():
    """Load the built-in communication drivers and any configured custom drivers.

    Built-in drivers are discovered in the folder of the ``drivers`` package.
    Custom driver folders come from ``CommConfigurator().get_comm_driver_path(None)``.
    That value is split on ``os.pathsep``, and each resulting folder is searched.

    On completion the module-level ``driver_loaded`` flag is set to True, so callers
    that check it do not repeat the folder scan.
    """
    global driver_loaded

    # Load all the drivers in the drivers module
    driver_mgr.search_folder(os.path.dirname(drivers.__file__), drivers.__package__)

    # Load custom drivers, if any folders are configured
    driver_path = CommConfigurator().get_comm_driver_path(None)
    if driver_path:
        for path in driver_path.split(os.pathsep):
            log.debug(f"Custom driver folder {path} is searched")
            driver_mgr.search_folder(path, None)

    # Set unconditionally: an early return here (when no custom path is configured)
    # would leave the flag False and force the built-in folder to be re-scanned on
    # every subsequent connector/listener creation.
    driver_loaded = True
class Communicator:
    """FCI (Flare Communication Interface) main communication API"""

    def __init__(self, local_endpoint: Endpoint):
        self.local_endpoint = local_endpoint
        self.monitors = []
        self.conn_manager = ConnManager(local_endpoint)
        self.stopped = False

    def start(self):
        """Start the communicator and establish all the connections.

        Raises:
            CommError: If any error encountered while starting up
        """
        self.conn_manager.start()
        # Clear the flag so that stop() is not skipped if this communicator was
        # stopped earlier and has now been started again.
        self.stopped = False
        log.debug(f"Communicator for local endpoint: {self.local_endpoint.name} is started")
        _running_instances.add(self)

    def stop(self):
        """Stop the communicator and shut down all the connections.

        Calling stop() on an already stopped communicator is a no-op.

        Raises:
            CommError: If any error encountered while shutting down
        """
        if self.stopped:
            return

        self.conn_manager.stop()
        self.stopped = True

        try:
            _running_instances.remove(self)
        except KeyError as ex:
            # The instance is only added by start(), so this means stop() without start()
            log.error(
                f"Logical error, communicator {self.local_endpoint.name} is not started: {secure_format_exception(ex)}"
            )

        log.debug(f"Communicator endpoint: {self.local_endpoint.name} has stopped")

    def register_monitor(self, monitor: EndpointMonitor):
        """Register a monitor for endpoint lifecycle changes.

        This monitor is notified of any state change of any endpoint.
        Multiple monitors can be registered.

        Args:
            monitor: The object that receives the endpoint state change notifications

        Raises:
            CommError: If the monitor cannot be registered
        """
        self.conn_manager.add_endpoint_monitor(monitor)

    def find_endpoint(self, name: str) -> Optional[Endpoint]:
        """Find an endpoint by name.

        Args:
            name: Endpoint name

        Returns:
            The endpoint if found. None if not found
        """
        return self.conn_manager.find_endpoint(name)

    def remove_endpoint(self, name: str):
        """Remove an endpoint and close all the connections associated with it.

        Args:
            name: Endpoint name

        Returns:
            Whatever ``ConnManager.remove_endpoint`` returns
        """
        return self.conn_manager.remove_endpoint(name)

    def send(self, endpoint: Endpoint, app_id: int, message: Message):
        """Send a message to an endpoint for app_id; no response is expected.

        Args:
            endpoint: The endpoint to send the message to
            app_id: Application ID
            message: Message to send

        Raises:
            CommError: If any error happens while sending the data
        """
        self.conn_manager.send_message(endpoint, app_id, message.headers, message.payload)

    def register_message_receiver(self, app_id: int, receiver: MessageReceiver):
        """Register a receiver to process FCI messages for the app.

        Args:
            app_id: Application ID
            receiver: The receiver to process the messages

        Raises:
            CommError: If duplicate endpoint/app or receiver is of wrong type
        """
        self.conn_manager.register_message_receiver(app_id, receiver)

    @staticmethod
    def _find_driver_class(key: str):
        """Look up the driver class for a URL or scheme, loading drivers on first use.

        Returns whatever ``DriverManager.find_driver_class`` returns; callers treat a
        falsy result as "no driver available".
        """
        if not driver_loaded:
            load_comm_drivers()
        return driver_mgr.find_driver_class(key)

    def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str:
        """Add a connector. The driver is selected based on the URL.

        The connector is registered but not started.

        Args:
            url: The url to listen on or connect to, like "https://0:443". Use 0 for empty host
            mode: Active for connecting, Passive for listening
            secure: True if SSL is required.

        Returns:
            A handle that can be used to delete the connector

        Raises:
            CommError: If any errors
        """
        driver_class = self._find_driver_class(url)
        if not driver_class:
            raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}")

        params = parse_url(url)
        return self.add_connector_advanced(driver_class(), mode, params, secure, False)

    def start_listener(self, scheme: str, resources: dict) -> Tuple[str, str]:
        """Add and start a connector in passive mode on an address selected by the driver.

        Args:
            scheme: Connection scheme, e.g. http, https
            resources: User specified resources like host and port ranges

        Returns:
            A tuple of (connector handle, connect url)

        Raises:
            CommError: If any errors like invalid host or port not available
        """
        driver_class = self._find_driver_class(scheme)
        if not driver_class:
            raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}")

        # The driver picks the address: one URL for peers to connect to, one to listen on
        connect_url, listening_url = driver_class.get_urls(scheme, resources)
        params = parse_url(listening_url)

        handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True)

        return handle, connect_url

    def add_connector_advanced(
        self, driver: Driver, mode: Mode, params: dict, secure: bool, start: bool = False
    ) -> str:
        """Add a connector using a specific driver instance.

        The caller's ``params`` dict is not modified; a copy is extended with the local
        endpoint's connection properties and, when ``secure`` is set, the SECURE flag.

        Args:
            driver: A transport driver instance
            mode: Active or passive
            params: Driver parameters
            secure: SSL is required if true
            start: Start the connector if true

        Returns:
            A handle that can be used to delete the connector

        Raises:
            CommError: If any errors
        """
        # Work on a copy so the caller's dict is not mutated by the updates below
        params = dict(params)

        if self.local_endpoint.conn_props:
            params.update(self.local_endpoint.conn_props)

        if secure:
            params[DriverParams.SECURE] = secure

        handle = self.conn_manager.add_connector(driver, params, mode)

        if not start:
            return handle

        connector = self.conn_manager.connectors.get(handle, None)
        if not connector:
            log.error(f"Connector {driver.get_name()}:{handle} is not found")
            raise CommError(CommError.ERROR, f"Logic error. Connector {driver.get_name()}:{handle} not found")

        self.conn_manager.start_connector(connector)

        return handle

    def remove_connector(self, handle: str):
        """Remove the connector.

        Args:
            handle: The connector handle

        Raises:
            CommError: If any errors
        """
        self.conn_manager.remove_connector(handle)
def _exit_func():
    """Stop every communicator that is still running when the interpreter exits.

    Registered with ``atexit`` right after this definition.
    """
    # Iterate over a snapshot instead of `while _running_instances`: stop() normally removes
    # the instance from the set, but if it returns early (instance already flagged as
    # stopped) or raises, the instance would remain and such a loop would never finish.
    for c in list(_running_instances):
        try:
            c.stop()
        except Exception as ex:
            # Keep going so one failing communicator doesn't leave the others running
            log.error(f"Failed to stop communicator {c.local_endpoint.name} on exit: {secure_format_exception(ex)}")
            continue
        log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit")


atexit.register(_exit_func)