Source code for nvflare.fuel.f3.drivers.aio_tcp_driver

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from concurrent.futures import CancelledError
from typing import Any, Dict, List

from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.drivers.aio_conn import AioConnection
from nvflare.fuel.f3.drivers.aio_context import AioContext
from nvflare.fuel.f3.drivers.base_driver import BaseDriver
from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo, Mode
from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams
from nvflare.fuel.f3.drivers.net_utils import get_ssl_context
from nvflare.fuel.f3.drivers.tcp_driver import TcpDriver

log = logging.getLogger(__name__)


[docs]class AioTcpDriver(BaseDriver): def __init__(self): super().__init__() self.aio_ctx = AioContext.get_global_context() self.server = None self.ssl_context = None
[docs] @staticmethod def supported_transports() -> List[str]: return ["atcp", "satcp"]
[docs] @staticmethod def capabilities() -> Dict[str, Any]: return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: True}
[docs] def listen(self, connector: ConnectorInfo): self._run(connector, Mode.PASSIVE)
[docs] def connect(self, connector: ConnectorInfo): self._run(connector, Mode.ACTIVE)
[docs] def shutdown(self): self.close_all() if self.server: self.server.close() # This will wake up the event loop to end the server self.aio_ctx.run_coro(asyncio.sleep(0))
[docs] @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): return TcpDriver.get_urls(scheme, resources)
# Internal methods def _run(self, connector: ConnectorInfo, mode: Mode): self.connector = connector if mode != self.connector.mode: raise CommError(CommError.ERROR, f"Connector mode doesn't match driver mode for {self.connector}") try: self.aio_ctx.run_coro(self._async_run(mode)).result() except CancelledError: log.debug(f"Connector {self.connector} is cancelled") async def _async_run(self, mode: Mode): params = self.connector.params host = params.get(DriverParams.HOST.value) port = params.get(DriverParams.PORT.value) if mode == Mode.ACTIVE: coroutine = self._tcp_connect(host, port) else: coroutine = self._tcp_listen(host, port) await coroutine async def _tcp_connect(self, host, port): self.ssl_context = get_ssl_context(self.connector.params, ssl_server=False) reader, writer = await asyncio.open_connection(host, port, ssl=self.ssl_context) await self._create_connection(reader, writer) async def _tcp_listen(self, host, port): self.ssl_context = get_ssl_context(self.connector.params, ssl_server=True) self.server = await asyncio.start_server(self._create_connection, host, port, ssl=self.ssl_context) async with self.server: await self.server.serve_forever() async def _create_connection(self, reader, writer): conn = AioConnection(self.connector, self.aio_ctx, reader, writer, self.ssl_context is not None) self.add_connection(conn) await conn.read_loop() self.close_connection(conn)