# Source code for nvflare.fuel.f3.connection

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import Union

from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo, Mode
from nvflare.fuel.f3.drivers.driver_params import DriverParams

# Module-level logger used by Connection.process_frame to report a missing receiver.
log = logging.getLogger(__name__)
# Guards updates to conn_count in create_connection_name().
lock = threading.Lock()
# Number of connection names handed out so far; incremented once per new name.
conn_count = 0

# Types accepted as a frame by the send/receive APIs in this module.
BytesAlike = Union[bytes, bytearray, memoryview, list]


def create_connection_name():
    """Generate the next sequential connection name.

    Increments the module-level ``conn_count`` while holding ``lock`` and
    formats the new value as ``CN`` followed by the number zero-padded to at
    least five digits (for example ``CN00001``).

    Returns:
        str: The newly allocated connection name.
    """
    global conn_count
    with lock:
        conn_count += 1
        # Snapshot the counter while the lock is still held. Reading the
        # global after the lock is released could observe another thread's
        # increment and hand out the same name twice.
        seq = conn_count
    return f"CN{seq:05d}"
class ConnState(Enum):
    """States a connection can be in.

    Members:
        IDLE: Initial state.
        CONNECTED: New connection.
        CLOSED: Connection is closed.
    """

    IDLE = 1
    CONNECTED = 2
    CLOSED = 3
class FrameReceiver(ABC):
    """Callback interface for objects that handle received frames."""

    @abstractmethod
    def process_frame(self, frame: BytesAlike):
        """Handle one frame that has been received.

        Args:
            frame: The frame received

        Raises:
            CommError: If any error happens while processing the frame
        """
class Connection(ABC):
    """FCI connection spec. A connection is used to transfer opaque frames"""

    def __init__(self, connector: ConnectorInfo):
        """Set up a new connection in the IDLE state with no frame receiver.

        Args:
            connector: The connector this connection belongs to; its ``mode``
                is consulted when the connection is rendered as a string.
        """
        self.connector = connector
        self.frame_receiver = None
        self.state = ConnState.IDLE
        self.name = create_connection_name()

    @abstractmethod
    def get_conn_properties(self) -> dict:
        """Return properties specific to this connection, such as the peer
        address or TLS certificate details.

        Raises:
            CommError: If any errors
        """

    @abstractmethod
    def close(self):
        """Close this connection.

        Raises:
            CommError: If any errors
        """

    @abstractmethod
    def send_frame(self, frame: BytesAlike):
        """Send a SFM frame through the connection to the remote endpoint.

        Args:
            frame: The frame to be sent

        Raises:
            CommError: If any error happens while sending the frame
        """

    def register_frame_receiver(self, receiver: FrameReceiver):
        """Install the receiver that incoming frames are handed to.

        Args:
            receiver: The frame receiver
        """
        self.frame_receiver = receiver

    def process_frame(self, frame: BytesAlike):
        """Forward a frame to the registered frame receiver.

        If no receiver has been registered, an error is logged and the frame
        is dropped.

        Args:
            frame: The frame to be processed

        Raises:
            CommError: If any error happens while processing the frame
        """
        receiver = self.frame_receiver
        if not receiver:
            log.error(f"Frame receiver not registered for {self}")
            return
        receiver.process_frame(frame)

    def __str__(self):
        """Describe the connection as a short bracketed tag.

        A connection that is not in the CONNECTED state is shown by name only.
        Otherwise the tag contains the name, the local and peer addresses
        ("N/A" when a property is absent), an arrow that is "=>" when the
        connector mode is ACTIVE and "<=" otherwise, and the peer CN prefixed
        with " SSL " when one is present.
        """
        if self.state != ConnState.CONNECTED:
            return f"[{self.name} Not Connected]"

        props = self.get_conn_properties()
        local_addr = props.get(DriverParams.LOCAL_ADDR, "N/A")
        peer_addr = props.get(DriverParams.PEER_ADDR, "N/A")

        if self.connector.mode == Mode.ACTIVE:
            direction = "=>"
        else:
            direction = "<="

        peer_cn = props.get(DriverParams.PEER_CN, None)
        if peer_cn:
            cn = " SSL " + peer_cn
        else:
            cn = ""

        return f"[{self.name} {local_addr} {direction} {peer_addr}{cn}]"