# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import Any, Dict, List
import websockets
from websockets.exceptions import ConnectionClosedOK
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers import net_utils
from nvflare.fuel.f3.drivers.aio_context import AioContext
from nvflare.fuel.f3.drivers.base_driver import BaseDriver
from nvflare.fuel.f3.drivers.driver import ConnectorInfo
from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams
from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE, get_tcp_urls
from nvflare.fuel.f3.sfm.conn_manager import Mode
from nvflare.fuel.hci.security import get_certificate_common_name
from nvflare.security.logging import secure_format_exception
# NOTE(review): not referenced anywhere within this module — confirm there are no
# external users before removing it.
THREAD_POOL_SIZE: int = 8
# Module-level logger used by the connection and driver classes in this file.
log: logging.Logger = logging.getLogger(__name__)
class WsConnection(Connection):
    """Connection implementation that wraps a ``websockets`` protocol object.

    Websocket I/O (sending frames, closing) is handed to ``aio_context.run_coro``.
    The public methods do not wait on the value that call returns.
    """

    def __init__(self, websocket: Any, aio_context: AioContext, connector: ConnectorInfo, secure: bool):
        """Wrap ``websocket`` and snapshot its address/certificate properties.

        Args:
            websocket: the websockets connection object (client or server side).
            aio_context: context whose ``run_coro`` schedules the async I/O.
            connector: connector info handed to the base ``Connection``.
            secure: whether the transport is expected to be TLS. It is only used to
                decide whether a missing peer certificate is reported as "N/A".
        """
        super().__init__(connector)
        self.aio_context = aio_context
        self.websocket = websocket
        self.secure = secure
        # Flipped to True by close(); readers can poll it to stop reading.
        self.closing = False
        # Computed once here; depends on self.websocket and self.secure set above.
        self.conn_props = self._get_socket_properties()

    def get_conn_properties(self) -> dict:
        """Return the properties captured when the connection was created."""
        return self.conn_props

    def close(self):
        """Mark the connection as closing and schedule the websocket close."""
        self.closing = True
        close_coro = self.websocket.close()
        self.aio_context.run_coro(close_coro)

    def send_frame(self, frame: BytesAlike):
        """Schedule ``frame`` to be sent. Send failures are logged and close the connection."""
        pending_send = self._async_send_frame(frame)
        self.aio_context.run_coro(pending_send)

    def _get_socket_properties(self) -> dict:
        """Build the connection property dict.

        PEER_ADDR and LOCAL_ADDR are formatted as "<addr[0]>:<addr[1]>" when the
        websocket reports an address. PEER_CN comes from the peer certificate.
        When there is no certificate, PEER_CN is "N/A" for secure connections and
        omitted otherwise.
        """
        ws = self.websocket
        props = {}
        address_sources = (
            (DriverParams.PEER_ADDR.value, ws.remote_address),
            (DriverParams.LOCAL_ADDR.value, ws.local_address),
        )
        for prop_key, address in address_sources:
            if address:
                props[prop_key] = f"{address[0]}:{address[1]}"

        # For TLS transports asyncio exposes the peer certificate under "peercert".
        peer_cert = ws.transport.get_extra_info("peercert")
        if peer_cert:
            common_name = get_certificate_common_name(peer_cert)
        else:
            common_name = "N/A" if self.secure else None

        if common_name:
            props[DriverParams.PEER_CN.value] = common_name
        return props

    async def _async_send_frame(self, frame: BytesAlike):
        """Send one frame on the event loop, closing the connection on any error."""
        try:
            await self.websocket.send(frame)
            # This is to yield control. See bug: https://github.com/aaugustin/websockets/issues/865
            await asyncio.sleep(0)
        except Exception as ex:
            log.error(f"Error sending frame for connection {self}, closing: {secure_format_exception(ex)}")
            self.close()
class AioHttpDriver(BaseDriver):
    """Async HTTP driver using websocket extension

    The listener or client coroutine runs via the shared ``AioContext``. Every
    websocket is wrapped in a :class:`WsConnection` and registered with the base
    driver for the lifetime of its read loop.
    """

    def __init__(self):
        super().__init__()
        self.aio_context = AioContext.get_global_context()
        # Future created on the event loop in _async_event_loop; the listener awaits
        # it, and shutdown() resolves it to stop the server.
        self.stop_event = None
        # SSL context for the current connector, or None when TLS is not configured.
        self.ssl_context = None

    @staticmethod
    def supported_transports() -> List[str]:
        return ["http", "https", "ws", "wss"]

    @staticmethod
    def capabilities() -> Dict[str, Any]:
        return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: True}

    def listen(self, connector: ConnectorInfo):
        """Serve websocket connections for ``connector``.

        Blocks until the server coroutine finishes. Normally that happens when
        shutdown() resolves ``stop_event``.
        """
        self._event_loop(Mode.PASSIVE, connector)

    def connect(self, connector: ConnectorInfo):
        """Open a client websocket for ``connector``; blocks until its read loop ends."""
        self._event_loop(Mode.ACTIVE, connector)

    def shutdown(self):
        """Close all connections and stop the listener, if one was started.

        This may be called from any thread. Asyncio futures are not thread-safe,
        and a cross-thread ``set_result`` does not wake a loop blocked waiting for
        I/O. For that reason the stop future is resolved on its own loop via
        ``call_soon_threadsafe``. Repeated calls are harmless.
        """
        self.close_all()

        stop_event = self.stop_event
        if stop_event is None:
            return

        loop = stop_event.get_loop()
        if loop.is_closed():
            # Nothing can still be awaiting the future on a closed loop.
            return
        loop.call_soon_threadsafe(self._signal_stop, stop_event)

    @staticmethod
    def get_urls(scheme: str, resources: dict) -> (str, str):
        secure = resources.get(DriverParams.SECURE)
        if secure:
            scheme = "https"
        return get_tcp_urls(scheme, resources)

    # Internal methods

    @staticmethod
    def _signal_stop(stop_event):
        """Resolve ``stop_event`` unless it already completed. Must run on its event loop."""
        if not stop_event.done():
            stop_event.set_result(None)

    def _event_loop(self, mode: Mode, connector: ConnectorInfo):
        """Validate ``mode`` against the connector, then run the driver coroutine and wait on it.

        Raises:
            CommError: if ``mode`` differs from ``connector.mode``.
        """
        self.connector = connector
        if mode != connector.mode:
            raise CommError(CommError.ERROR, f"Connector mode doesn't match driver mode for {self.connector}")

        self.aio_context.run_coro(self._async_event_loop(mode)).result()

    async def _async_event_loop(self, mode: Mode):
        """Create the stop future, then dispatch to the client or server coroutine."""
        self.stop_event = self.aio_context.get_event_loop().create_future()

        params = self.connector.params
        host = params.get(DriverParams.HOST.value)
        port = params.get(DriverParams.PORT.value)

        if mode == Mode.ACTIVE:
            coroutine = self._async_connect(host, port)
        else:
            coroutine = self._async_listen(host, port)

        await coroutine

    async def _async_connect(self, host, port):
        """Connect as a client (wss when an SSL context is configured) and serve that one socket."""
        self.ssl_context = net_utils.get_ssl_context(self.connector.params, False)
        if self.ssl_context:
            scheme = "wss"
        else:
            scheme = "ws"
        # Library pings are disabled (ping_interval=None); the driver advertises
        # SEND_HEARTBEAT in capabilities().
        async with websockets.connect(
            f"{scheme}://{host}:{port}", ssl=self.ssl_context, ping_interval=None, max_size=MAX_FRAME_SIZE
        ) as ws:
            await self._handler(ws)

    async def _async_listen(self, host, port):
        """Run the websocket server until ``stop_event`` is resolved."""
        self.ssl_context = net_utils.get_ssl_context(self.connector.params, True)
        async with websockets.serve(
            self._handler, host, port, ssl=self.ssl_context, ping_interval=None, max_size=MAX_FRAME_SIZE
        ):
            await self.stop_event

    async def _handler(self, websocket):
        """Register a connection for ``websocket``, read until it ends, then deregister it.

        Deregistration happens in ``finally`` so that peer closes and read errors
        also remove the connection, not just a locally requested close. Exceptions
        other than ``ConnectionClosedOK`` still propagate to the caller.
        """
        conn = None
        try:
            # WsConnection declares ``secure: bool``; truthiness matches the SSL context's.
            conn = WsConnection(websocket, self.aio_context, self.connector, bool(self.ssl_context))
            self.add_connection(conn)
            await self._read_loop(conn)
        except ConnectionClosedOK as ex:
            conn_info = str(conn) if conn else "N/A"
            log.debug(f"Connection {conn_info} is closed by peer: {secure_format_exception(ex)}")
        finally:
            if conn is not None:
                self.close_connection(conn)

    @staticmethod
    async def _read_loop(conn: WsConnection):
        """Read frames from ``conn`` and pass each to ``conn.process_frame`` until it is closing.

        Raises:
            ConnectionClosedOK: re-raised unchanged on a normal websocket close.
            Exception: any other error is logged and re-raised.
        """
        try:
            while not conn.closing:
                frame = await conn.websocket.recv()
                conn.process_frame(frame)
        except ConnectionClosedOK:
            # Normal close: let the caller report it without an error log.
            raise
        except Exception as ex:
            log.error(f"Exception {type(ex)} on connection {conn}: {secure_format_exception(ex)}")
            raise