# Source code for nvflare.fuel.f3.drivers.aio_conn

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from asyncio import CancelledError, IncompleteReadError, StreamReader, StreamWriter

from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers.aio_context import AioContext
from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE
from nvflare.fuel.f3.sfm.prefix import PREFIX_LEN, Prefix
from nvflare.fuel.hci.security import get_certificate_common_name
from nvflare.security.logging import secure_format_exception

log = logging.getLogger(__name__)


class AioConnection(Connection):
    """Connection implemented on top of an asyncio stream reader/writer pair.

    Outgoing frames are written to ``writer`` and incoming frames are read from
    ``reader``. Synchronous entry points (``close``, ``send_frame``) hand their
    coroutines to ``aio_ctx.run_coro``.
    """

    def __init__(
        self,
        connector: ConnectorInfo,
        aio_ctx: AioContext,
        reader: StreamReader,
        writer: StreamWriter,
        secure: bool = False,
    ):
        """Store the stream endpoints and capture the connection properties.

        Args:
            connector: Connector info forwarded to the ``Connection`` base class.
            aio_ctx: Context whose ``run_coro`` is used to run coroutines.
            reader: Stream that incoming frames are read from.
            writer: Stream that outgoing frames are written to. May be falsy, in
                which case no connection properties are collected.
            secure: If true and the writer exposes no peer certificate, the peer
                common name is reported as ``"N/A"``.
        """
        super().__init__(connector)
        self.reader = reader
        self.writer = writer
        self.aio_ctx = aio_ctx
        self.closing = False
        self.secure = secure
        # Computed once here; relies on self.writer and self.secure set above.
        self.conn_props = self._get_aio_properties()

    def get_conn_properties(self) -> dict:
        """Return the connection properties collected at construction time."""
        return self.conn_props

    def close(self):
        """Flag the connection as closing and close the writer, if there is one."""
        # Set first so read_loop and the send path can tell a local close apart
        # from an unexpected failure.
        self.closing = True

        if not self.writer:
            return

        self.writer.close()
        self.aio_ctx.run_coro(self.writer.wait_closed())

    def send_frame(self, frame: BytesAlike):
        """Submit a frame for sending through ``aio_ctx.run_coro``.

        Any exception raised while submitting the coroutine is logged and not
        propagated to the caller.

        Args:
            frame: The frame to write to the stream.
        """
        try:
            self.aio_ctx.run_coro(self._async_send_frame(frame))
        except Exception as ex:
            log.error(f"Error calling send coroutine for connection {self}: {secure_format_exception(ex)}")

    async def read_loop(self):
        """Read frames and pass each to ``process_frame`` until closing or an error.

        All exceptions are handled here: ``IncompleteReadError`` and
        ``CancelledError`` are logged at debug level, anything else is logged
        as an error. The coroutine then returns normally.
        """
        try:
            while not self.closing:
                frame = await self._async_read_frame()
                self.process_frame(frame)

        except IncompleteReadError:
            if log.isEnabledFor(logging.DEBUG):
                closer = "locally" if self.closing else "by peer"
                log.debug(f"Connection {self} is closed {closer}")
        except CancelledError as error:
            log.debug(f"Connection {self} is closed by peer: {error}")
        except Exception as ex:
            log.error(f"Read error for connection {self}: {secure_format_exception(ex)}")

    # Internal methods

    async def _async_send_frame(self, frame: BytesAlike):
        """Write a frame to the writer and wait for it to drain.

        Errors are logged unless the connection is already closing; they are
        never re-raised.
        """
        try:
            self.writer.write(frame)
            await self.writer.drain()
        except Exception as ex:
            if not self.closing:
                log.error(f"Error sending frame for connection {self}: {secure_format_exception(ex)}")

    async def _async_read_frame(self):
        """Read one complete frame (prefix plus payload) from the reader.

        Returns:
            The prefix bytes followed by the remaining ``prefix.length - PREFIX_LEN``
            bytes, or just the prefix bytes for a prefix-only message.

        Raises:
            CommError: With ``BAD_DATA`` if the declared length is smaller than
                the prefix or larger than ``MAX_FRAME_SIZE``.
            IncompleteReadError: Propagated from ``readexactly`` when the stream
                ends before the requested bytes arrive.
        """
        prefix_buf = await self.reader.readexactly(PREFIX_LEN)
        prefix = Prefix.from_bytes(prefix_buf)

        # Prefix only message
        if prefix.length == PREFIX_LEN:
            return prefix_buf

        # A declared length shorter than the prefix itself would otherwise ask
        # readexactly() for a negative number of bytes.
        if prefix.length < PREFIX_LEN:
            raise CommError(CommError.BAD_DATA, f"Frame is shorter than prefix ({prefix.length} < {PREFIX_LEN})")

        if prefix.length > MAX_FRAME_SIZE:
            raise CommError(CommError.BAD_DATA, f"Frame exceeds limit ({prefix.length} > {MAX_FRAME_SIZE})")

        remaining = await self.reader.readexactly(prefix.length - PREFIX_LEN)

        return prefix_buf + remaining

    def _get_aio_properties(self) -> dict:
        """Build the connection property dict from the writer's extra info.

        Returns:
            An empty dict when there is no writer. Otherwise a dict holding the
            local and peer addresses and, when available, the peer common name.
        """
        conn_props = {}

        if not self.writer:
            return conn_props

        fileno = 0
        local_addr = self.writer.get_extra_info("sockname", "")
        if isinstance(local_addr, tuple):
            local_addr = f"{local_addr[0]}:{local_addr[1]}"
        else:
            # Non-tuple address: append the socket's file number (0 if the
            # writer exposes no socket) in place of a port.
            sock = self.writer.get_extra_info("socket", None)
            if sock:
                fileno = sock.fileno()
            local_addr = f"{local_addr}:{fileno}"

        peer_addr = self.writer.get_extra_info("peername", "")
        if isinstance(peer_addr, tuple):
            peer_addr = f"{peer_addr[0]}:{peer_addr[1]}"
        else:
            # fileno is only looked up in the non-tuple local address branch
            # above; otherwise it is still 0 here.
            peer_addr = f"{peer_addr}:{fileno}"

        conn_props[DriverParams.LOCAL_ADDR.value] = local_addr
        conn_props[DriverParams.PEER_ADDR.value] = peer_addr

        peer_cert = self.writer.get_extra_info("peercert")
        if peer_cert:
            cn = get_certificate_common_name(peer_cert)
        else:
            # No certificate: report "N/A" on a secure connection, otherwise
            # leave the CN out of the properties.
            cn = "N/A" if self.secure else None

        if cn:
            conn_props[DriverParams.PEER_CN.value] = cn

        return conn_props