Source code for nvflare.fuel.f3.drivers.net_utils

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import socket
import ssl
from ssl import SSLContext
from typing import Any, Optional
from urllib.parse import parse_qsl, urlencode, urlparse

from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.argument_utils import str2bool
from nvflare.security.logging import secure_format_exception

log = logging.getLogger(__name__)

LO_PORT = 1025
HI_PORT = 65535
MAX_ITER_SIZE = 10
RANDOM_TRIES = 20
BIND_TIME_OUT = 5
SECURE_SCHEMES = {"https", "wss", "grpcs", "agrpcs", "ngrpcs", "stcp", "satcp"}

# GRPC can't handle frame size over 2G. So the limit is set to (2G-2M)
MAX_FRAME_SIZE = 2 * 1024 * 1024 * 1024 - (2 * 1024 * 1024)
MAX_HEADER_SIZE = 1024 * 1024
MAX_PAYLOAD_SIZE = MAX_FRAME_SIZE - 16 - MAX_HEADER_SIZE

SSL_SERVER_PRIVATE_KEY = "server.key"
SSL_SERVER_CERT = "server.crt"
SSL_CLIENT_PRIVATE_KEY = "client.key"
SSL_CLIENT_CERT = "client.crt"
SSL_ROOT_CERT = "rootCA.pem"


[docs]def ssl_required(params: dict) -> bool: """Check if SSL is required""" scheme = params.get(DriverParams.SCHEME.value, None) return scheme in SECURE_SCHEMES or str2bool(params.get(DriverParams.SECURE.value))
[docs]def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: if not ssl_required(params): return None ca_path = params.get(DriverParams.CA_CERT.value) if ssl_server: cert_path = params.get(DriverParams.SERVER_CERT.value) key_path = params.get(DriverParams.SERVER_KEY.value) else: cert_path = params.get(DriverParams.CLIENT_CERT.value) key_path = params.get(DriverParams.CLIENT_KEY.value) if not all([ca_path, cert_path, key_path]): scheme = params.get(DriverParams.SCHEME.value, "Unknown") role = "Server" if ssl_server else "Client" raise CommError(CommError.BAD_CONFIG, f"{role} certificate parameters are missing for scheme {scheme}") if ssl_server: ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) else: ctx = ssl.create_default_context() ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.verify_mode = ssl.CERT_REQUIRED ctx.check_hostname = False ctx.load_verify_locations(ca_path) ctx.load_cert_chain(certfile=cert_path, keyfile=key_path) return ctx
[docs]def get_address(params: dict) -> str: host = params.get(DriverParams.HOST.value, "0.0.0.0") port = params.get(DriverParams.PORT.value, 0) if not host: host = "0.0.0.0" return f"{host}:{port}"
[docs]def parse_port_range(entry: Any): if isinstance(entry, int): return range(entry, entry + 1) parts = entry.split("-") if len(parts) == 1: num = int(parts[0]) return range(num, num + 1) lo = int(parts[0]) if parts[0] else LO_PORT hi = int(parts[1]) if parts[1] else HI_PORT return range(lo, hi + 1)
[docs]def parse_port_list(ranges: Any) -> list: all_ranges = [] if isinstance(ranges, list): for r in ranges: all_ranges.append(parse_port_range(r)) else: all_ranges.append(parse_port_range(ranges)) return all_ranges
[docs]def check_tcp_port(port) -> bool: result = False s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(BIND_TIME_OUT) try: s.bind(("", port)) result = True except Exception as e: log.debug(f"Port {port} binding error: {secure_format_exception(e)}") finally: s.close() return result
[docs]def get_open_tcp_port(resources: dict) -> Optional[int]: port = resources.get(DriverParams.PORT) if port: return port ports = resources.get(DriverParams.PORTS) if ports: all_ports = parse_port_list(ports) else: all_ports = [range(LO_PORT, HI_PORT + 1)] for port_range in all_ports: if len(port_range) <= MAX_ITER_SIZE: for port in port_range: if check_tcp_port(port): return port else: for i in range(RANDOM_TRIES): port = random.randint(port_range.start, port_range.stop - 1) if check_tcp_port(port): return port return None
[docs]def parse_url(url: str) -> dict: """Parse URL into a dictionary, saving original URL also""" if not url: return {} params = {DriverParams.URL.value: url} parsed_url = urlparse(url) params[DriverParams.SCHEME.value] = parsed_url.scheme parts = parsed_url.netloc.split(":") if len(parts) >= 1: host = parts[0] # Host is required in URL. 0 is used as the placeholder for empty host if host == "0": host = "" params[DriverParams.HOST.value] = host if len(parts) >= 2: params[DriverParams.PORT.value] = parts[1] params[DriverParams.PATH.value] = parsed_url.path params[DriverParams.PARAMS.value] = parsed_url.params params[DriverParams.QUERY.value] = parsed_url.query params[DriverParams.FRAG.value] = parsed_url.fragment if parsed_url.query: for k, v in parse_qsl(parsed_url.query): # Only last one is saved if duplicate keys params[k] = v return params
[docs]def encode_url(params: dict) -> str: temp = params.copy() # Original URL is not needed temp.pop(DriverParams.URL.value, None) scheme = temp.pop(DriverParams.SCHEME.value, None) host = temp.pop(DriverParams.HOST.value, None) if not host: host = "0" port = temp.pop(DriverParams.PORT.value, None) path = temp.pop(DriverParams.PATH.value, None) parameters = temp.pop(DriverParams.PARAMS.value, None) # Encoded query is not needed temp.pop(DriverParams.QUERY.value, None) frag = temp.pop(DriverParams.FRAG.value, None) url = f"{scheme}://{host}" if port: url += ":" + str(port) if path: url += path if parameters: url += ";" + parameters if temp: url += "?" + urlencode(temp) if frag: url += "#" + frag return url
[docs]def short_url(params: dict) -> str: """Get a short url to be used in logs""" url = params.get(DriverParams.URL.value) if url: return url subset = { k: params[k] for k in {DriverParams.SCHEME.value, DriverParams.HOST.value, DriverParams.PORT.value, DriverParams.PATH.value} } return encode_url(subset)
[docs]def get_tcp_urls(scheme: str, resources: dict) -> (str, str): """Generate URL pairs for connecting and listening for TCP-based protocols Args: scheme: The transport scheme resources: The resource restrictions like port ranges Returns: a tuple with connecting and listening URL Raises: CommError: If any error happens while sending the request """ host = resources.get("host") if resources else None if not host: host = "localhost" port = get_open_tcp_port(resources) if not port: raise CommError(CommError.BAD_CONFIG, "Can't find an open port in the specified range") # Always listen on all interfaces listening_url = f"{scheme}://0:{port}" connect_url = f"{scheme}://{host}:{port}" return connect_url, listening_url
[docs]def enhance_credential_info(params: dict): # must have CA ca_path = params.get(DriverParams.CA_CERT.value) if not ca_path: return params # assume all SSL credential files are in the same folder with CA cert cred_folder = os.path.dirname(ca_path) client_cert_path = params.get(DriverParams.CLIENT_CERT.value) if not client_cert_path: # see whether the file client cert file exists client_cert_path = os.path.join(cred_folder, SSL_CLIENT_CERT) if os.path.exists(client_cert_path): params[DriverParams.CLIENT_CERT.value] = client_cert_path client_key_path = params.get(DriverParams.CLIENT_KEY.value) if not client_key_path: # see whether the file client key file exists client_key_path = os.path.join(cred_folder, SSL_CLIENT_PRIVATE_KEY) if os.path.exists(client_key_path): params[DriverParams.CLIENT_KEY.value] = client_key_path server_cert_path = params.get(DriverParams.SERVER_CERT.value) if not server_cert_path: # see whether the file client cert file exists server_cert_path = os.path.join(cred_folder, SSL_SERVER_CERT) if os.path.exists(server_cert_path): params[DriverParams.SERVER_CERT.value] = server_cert_path server_key_path = params.get(DriverParams.SERVER_KEY.value) if not server_key_path: # see whether the file client key file exists server_key_path = os.path.join(cred_folder, SSL_SERVER_PRIVATE_KEY) if os.path.exists(server_key_path): params[DriverParams.SERVER_KEY.value] = server_key_path