Source code for nvflare.fuel.f3.drivers.socket_conn

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from socketserver import BaseRequestHandler
from typing import Any, Union

from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers.driver import ConnectorInfo
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE
from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix
from nvflare.fuel.hci.security import get_certificate_common_name
from nvflare.security.logging import secure_format_exception

log = logging.getLogger(__name__)


class SocketConnection(Connection):
    """Connection implementation on top of a connected stream socket.

    Incoming data is framed: each frame starts with a fixed-size prefix of PREFIX_LEN bytes,
    decoded with ``Prefix.from_bytes`` to obtain the total frame length (prefix included),
    followed by the remaining payload bytes.
    """

    def __init__(self, sock: Any, connector: ConnectorInfo, secure: bool = False):
        """Wrap an already-connected socket.

        Args:
            sock: the connected socket. When ``secure`` is truthy it must support
                ``getpeercert()`` (i.e. presumably an ``ssl.SSLSocket`` -- confirm with callers).
            connector: connector info forwarded to the ``Connection`` base class.
            secure: when truthy, the peer certificate's common name is recorded in the
                connection properties under ``DriverParams.PEER_CN``.
        """
        super().__init__(connector)
        self.sock = sock
        self.secure = secure
        # Set by close(); lets send_frame/read_loop tell a deliberate shutdown from a real failure.
        self.closing = False
        self.conn_props = self._get_socket_properties()

    def get_conn_properties(self) -> dict:
        """Return the address/certificate properties captured when the connection was created."""
        return self.conn_props

    def close(self):
        """Mark the connection as closing, then shut down and close the socket.

        A failing ``shutdown`` (e.g. the peer already disconnected) is only logged at debug
        level; the socket is closed regardless.
        """
        self.closing = True
        if self.sock:
            try:
                self.sock.shutdown(socket.SHUT_RDWR)
            except OSError as error:
                log.debug(f"Connection {self} is already closed: {error}")
            self.sock.close()

    def send_frame(self, frame: BytesAlike):
        """Send one complete frame over the socket.

        Raises:
            CommError: with code ``CommError.ERROR`` if sending fails while the connection is
                not closing. Send failures after ``close()`` has been called are ignored.
        """
        try:
            self.sock.sendall(frame)
        except Exception as ex:
            if not self.closing:
                raise CommError(
                    CommError.ERROR, f"Error sending frame on conn {self}: {secure_format_exception(ex)}"
                ) from ex

    def read_loop(self):
        """Read and process frames until the connection ends.

        ``CommError`` and other ``Exception`` subclasses raised by the frame loop are caught and
        logged at debug level instead of being propagated to the caller.
        """
        try:
            self.read_frame_loop()
        except CommError as error:
            if error.code == CommError.CLOSED:
                log.debug(f"Connection {self.name} is closed by peer")
            else:
                log.debug(f"Connection {self.name} is closed due to error: {error}")
        except Exception as ex:
            if self.closing:
                log.debug(f"Connection {self.name} is closed")
            else:
                log.debug(f"Connection {self.name} is closed due to error: {secure_format_exception(ex)}")

    def read_frame_loop(self):
        """Repeatedly read a frame and hand it to ``process_frame`` until ``closing`` is set.

        ``process_frame`` is not defined in this class; presumably it is provided by the
        ``Connection`` base class.
        """
        # read_frame throws exception on stale/bad connection so this is not a dead loop
        while not self.closing:
            frame = self.read_frame()
            self.process_frame(frame)

    def read_frame(self) -> BytesAlike:
        """Read one complete frame (prefix included) from the socket.

        Returns:
            A bytearray containing the whole frame.

        Raises:
            CommError: ``BAD_DATA`` if the decoded length is smaller than the prefix itself or
                larger than MAX_FRAME_SIZE; ``CLOSED`` if the peer closes mid-read.
        """
        prefix_buf = bytearray(PREFIX_LEN)
        self.read_into(prefix_buf, 0, PREFIX_LEN)
        prefix = Prefix.from_bytes(prefix_buf)

        # A frame that consists of only the prefix needs no further reads.
        if prefix.length == PREFIX_LEN:
            return prefix_buf

        # The declared length includes the prefix, so anything shorter is corrupt; without this
        # check the slice assignment below would silently grow the buffer and the payload read
        # would be asked for a negative number of bytes.
        if prefix.length < PREFIX_LEN:
            raise CommError(
                CommError.BAD_DATA, f"Frame length {prefix.length} is smaller than prefix length {PREFIX_LEN}"
            )

        if prefix.length > MAX_FRAME_SIZE:
            raise CommError(CommError.BAD_DATA, f"Frame exceeds limit ({prefix.length} > {MAX_FRAME_SIZE})")

        frame = bytearray(prefix.length)
        frame[0:PREFIX_LEN] = prefix_buf
        self.read_into(frame, PREFIX_LEN, prefix.length - PREFIX_LEN)
        return frame

    def read_into(self, buffer: BytesAlike, offset: int, length: int):
        """Fill ``buffer[offset:offset + length]`` with exactly ``length`` bytes from the socket.

        Uses a memoryview so each ``recv_into`` writes directly into the buffer without copying.

        Raises:
            CommError: with code ``CommError.CLOSED`` if ``recv_into`` returns 0 (peer closed).
        """
        if isinstance(buffer, memoryview):
            view = buffer
        else:
            view = memoryview(buffer)

        if offset:
            view = view[offset:]

        remaining = length
        while remaining:
            n = self.sock.recv_into(view, remaining)
            if n == 0:
                raise CommError(CommError.CLOSED, f"Connection {self.name} is closed by peer")
            view = view[n:]
            remaining -= n

    @staticmethod
    def _format_address(addr: Union[str, tuple], fileno: int) -> str:
        """Render a socket address as ``host:port``, or ``addr:fileno`` for non-tuple addresses."""
        if isinstance(addr, tuple):
            result = f"{addr[0]}:{addr[1]}"
        else:
            result = f"{addr}:{fileno}"
        return result

    def _get_socket_properties(self) -> dict:
        """Collect peer/local addresses (and peer CN when secure) for this connection.

        Address lookups are best-effort: an ``OSError`` from ``getpeername``/``getsockname`` is
        logged at debug level and the address is recorded as "N/A".
        """
        conn_props = {}

        try:
            peer = self.sock.getpeername()
            fileno = self.sock.fileno()
        except OSError as ex:
            peer = "N/A"
            fileno = 0
            log.debug(f"getpeername() error: {secure_format_exception(ex)}")

        conn_props[DriverParams.PEER_ADDR.value] = self._format_address(peer, fileno)

        # Guarded like getpeername() above so a bad socket does not make the constructor raise.
        try:
            local = self.sock.getsockname()
        except OSError as ex:
            local = "N/A"
            log.debug(f"getsockname() error: {secure_format_exception(ex)}")

        conn_props[DriverParams.LOCAL_ADDR.value] = self._format_address(local, fileno)

        if self.secure:
            cert = self.sock.getpeercert()
            if cert:
                cn = get_certificate_common_name(cert)
            else:
                cn = "N/A"
            conn_props[DriverParams.PEER_CN.value] = cn

        return conn_props
class ConnectionHandler(BaseRequestHandler):
    """``socketserver`` request handler that serves one accepted socket as a SocketConnection."""

    def handle(self):
        """Wrap the accepted socket, register it with the server's driver, pump frames, then unregister it.

        The server's ``ssl_context`` is passed positionally as SocketConnection's ``secure`` argument,
        so its truthiness decides whether the peer certificate CN is recorded.
        """
        srv = self.server
        # noinspection PyUnresolvedReferences
        conn = SocketConnection(self.request, srv.connector, srv.ssl_context)
        # noinspection PyUnresolvedReferences
        owner = srv.driver

        owner.add_connection(conn)
        # Returns once the frame loop ends (peer closed, error, or close() was called).
        conn.read_loop()
        owner.close_connection(conn)