# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shlex
import shutil
import socket
import ssl
import subprocess
import tempfile
import grpc
[docs]
class NVFlareConfig:
SERVER = "fed_server.json"
CLIENT = "fed_client.json"
ADMIN = "fed_admin.json"
[docs]
class NVFlareRole:
SERVER = "server"
CLIENT = "client"
ADMIN = "admin"
[docs]
def try_write_dir(path: str):
try:
created = False
if not os.path.exists(path):
created = True
os.makedirs(path, exist_ok=False)
fd, name = tempfile.mkstemp(dir=path)
with os.fdopen(fd, "w") as fp:
fp.write("dummy")
os.remove(name)
if created:
shutil.rmtree(path)
except OSError as e:
return e
[docs]
def try_bind_address(host: str, port: int):
"""Tries to bind to address."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind((host, port))
except OSError as e:
return e
finally:
sock.close()
return None
def _get_ca_cert_file_name():
return "rootCA.pem"
def _get_cert_file_name(role: str):
if role == NVFlareRole.SERVER:
return "server.crt"
return "client.crt"
def _get_prv_key_file_name(role: str):
if role == NVFlareRole.SERVER:
return "server.key"
return "client.key"
[docs]
def split_by_len(item, max_len):
return [item[ind : ind + max_len] for ind in range(0, len(item), max_len)]
def _get_conn_sec(startup: str):
# get connection security
# first try to see whether this is a client config.
client_config = os.path.join(startup, "fed_client.json")
if os.path.exists(client_config):
with open(client_config, "r") as f:
config = json.load(f)
return config["client"].get("connection_security", "mtls")
# try admin config
admin_config = os.path.join(startup, "fed_admin.json")
if os.path.exists(admin_config):
with open(admin_config, "r") as f:
config = json.load(f)
return config["admin"].get("connection_security", "mtls")
return "mtls"
[docs]
def check_grpc_server_running(startup: str, host: str, port: int, token=None) -> bool:
conn_sec = _get_conn_sec(startup)
secure = True
if conn_sec == "clear":
secure = False
if secure:
with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
trusted_certs = f.read()
with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
private_key = f.read()
with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
certificate_chain = f.read()
call_credentials = grpc.metadata_call_credentials(
lambda context, callback: callback((("x-custom-token", token),), None)
)
credentials = grpc.ssl_channel_credentials(
certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
)
composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
channel = grpc.secure_channel(target=f"{host}:{port}", credentials=composite_credentials)
else:
channel = grpc.insecure_channel(target=f"{host}:{port}")
try:
grpc.channel_ready_future(channel).result(timeout=10)
except grpc.FutureTimeoutError:
return False
return True
[docs]
def check_socket_server_running(startup: str, host: str, port: int, scheme: str = "https") -> bool:
"""Check if socket-based server (HTTP/HTTPS/TCP/STCP) is running and accessible.
This function performs a socket connection test with optional SSL/TLS.
It's used for HTTP/WebSocket and TCP-based FL servers.
Args:
startup: Path to startup directory containing certificates
host: Server hostname or IP address
port: Server port number
scheme: URL scheme ("http", "https", "tcp", "stcp")
Returns:
True if server is accessible, False otherwise
"""
conn_sec = _get_conn_sec(startup)
secure = True
if conn_sec == "clear":
secure = False
# Determine if we need SSL based on scheme
use_ssl = secure and scheme in ["https", "stcp"]
# Try a socket connection to check if port is reachable
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(10)
try:
if use_ssl:
# For secure connection, wrap socket with SSL and use client certificates
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
ca_path = os.path.join(startup, _get_ca_cert_file_name())
cert_path = os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT))
prv_key_path = os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT))
context.load_verify_locations(ca_path)
context.load_cert_chain(cert_path, prv_key_path)
# Check hostname may fail for localhost, so disable it for preflight check
context.check_hostname = False
ssl_sock = context.wrap_socket(sock, server_hostname=host)
ssl_sock.connect((host, port))
ssl_sock.close()
else:
# For insecure connection, just check if we can connect
sock.connect((host, port))
sock.close()
return True
except (socket.timeout, socket.error, ssl.SSLError, OSError, ConnectionRefusedError):
# Connection failed - server is not accessible
return False
finally:
try:
sock.close()
except Exception:
pass
[docs]
def run_command_in_subprocess(command):
new_env = os.environ.copy()
process = subprocess.Popen(
shlex.split(command),
shell=False,
preexec_fn=os.setsid,
env=new_env,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
)
return process
[docs]
def get_communication_scheme(package_path: str, config_name: str, default_scheme: str = "http") -> str:
"""Read the communication scheme from package configuration files.
This function checks multiple sources to determine the communication scheme:
1. For servers: fed_server.json (service.scheme)
2. For all packages: comm_config.json in local/ or startup/ directories
Args:
package_path: Path to the package directory
config_name: Name of the configuration file (fed_server.json, fed_client.json, fed_admin.json)
default_scheme: Default scheme to return if no scheme is found
Returns:
The communication scheme (e.g., "grpc", "http")
"""
# First try to read from fed_xxx.json
startup = os.path.join(package_path, "startup")
fed_config_file = os.path.join(startup, config_name)
if os.path.exists(fed_config_file):
try:
with open(fed_config_file, "r") as f:
fed_config = json.load(f)
server_conf = fed_config.get("servers", [{}])[0]
service_config = server_conf.get("service", {})
scheme = service_config.get("scheme")
if scheme:
return scheme.lower()
except Exception:
pass
return default_scheme