# Source code for nvflare.private.fed.app.relay.relay

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
import threading

from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, ReservedKey, WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.signal import Signal
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import SSL_ROOT_CERT, enhance_credential_info
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService, search_file
from nvflare.fuel.utils.log_utils import configure_logging
from nvflare.fuel.utils.url_utils import make_url
from nvflare.private.defs import ClientType
from nvflare.private.fed.authenticator import Authenticator, validate_auth_headers
from nvflare.private.fed.utils.identity_utils import TokenVerifier


class CellnetMonitor:
    """Reacts to the cell network shutting down.

    An instance holds a stop event and a workspace directory. When the
    cellnet reports that it has stopped, the monitor leaves a shutdown marker
    file in the workspace and then sets the stop event.
    """

    def __init__(self, stop_event: threading.Event, workspace: str):
        """Remember the event to set and the workspace to write the marker into.

        Args:
            stop_event: event set once the cellnet has stopped.
            workspace: directory in which the shutdown marker file is created.
        """
        self.workspace = workspace
        self.stop_event = stop_event

    def cellnet_stopped(self):
        """Touch the workspace shutdown marker file, then set the stop event."""
        marker_path = os.path.join(self.workspace, WorkspaceConstants.SHUTDOWN_FILE)
        # Append mode creates the file when absent without truncating an existing
        # one; os.utime then refreshes its timestamps (a "touch").
        with open(marker_path, "a"):
            os.utime(marker_path, None)
        self.stop_event.set()
class _ConfigKey: PROJECT_NAME = "project_name" SERVER_IDENTITY = "server_identity" IDENTITY = "identity" CONNECT_TO = "connect_to"
def parse_arguments(argv=None):
    """Parse the relay's command-line arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which is the previous
            behavior; passing a list allows tests or embedding code to drive
            the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``workspace``, ``relay_config`` and
        ``set`` (a list of ``KEY=VALUE`` strings, or None when ``--set`` is
        not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
    parser.add_argument("--relay_config", "-s", type=str, help="relay config json file", required=True)
    parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
    return parser.parse_args(argv)
def _require_config_value(container, key, what, config_name):
    """Return ``container[key]``, raising RuntimeError if it is missing or empty.

    Args:
        container: dict read from the relay config (top level or connect_to section).
        key: key to look up in ``container``.
        what: label used in the error message (e.g. "project_name", "parent.address").
        config_name: relay config file name as given on the command line.

    Raises:
        RuntimeError: if the value is absent or falsy.
    """
    value = container.get(key)
    if not value:
        raise RuntimeError(f"invalid relay config file {config_name}: missing {what}")
    return value


def _is_truthy_flag(value) -> bool:
    """Interpret a ``--set`` value as a boolean flag.

    Non-string values keep normal Python truthiness. Strings spelling a false
    value ("", "false", "no", "0", "off"; case-insensitive) map to False; any
    other string maps to True, which is what its truthiness already was.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "no", "0", "off")
    return bool(value)


def main(args):
    """Start a relay cell, authenticate it to the server and run until stopped.

    Args:
        args: namespace from ``parse_arguments`` (``workspace``, ``relay_config``, ``set``).

    Raises:
        RuntimeError: if the relay config file is malformed or missing required
            keys, or if authentication yields an unexpected token verifier.
        ValueError: if the root certificate cannot be found in the workspace.
    """
    workspace = Workspace(root_dir=args.workspace)

    # Clear restart/shutdown marker files left behind by a previous run.
    for name in [WorkspaceConstants.RESTART_FILE, WorkspaceConstants.SHUTDOWN_FILE]:
        try:
            f = workspace.get_file_path_in_root(name)
            if os.path.exists(f):
                os.remove(f)
        except Exception as ex:
            print(f"Could not remove file '{name}': {ex}. Please check your system before starting FL.")
            sys.exit(-1)

    configure_logging(workspace)
    logger = logging.getLogger()

    relay_config_file = workspace.get_file_path_in_startup(args.relay_config)
    with open(relay_config_file, "rt") as f:
        relay_config = json.load(f)

    config_name = args.relay_config
    if not isinstance(relay_config, dict):
        raise RuntimeError(f"invalid relay config file {config_name}")

    project_name = _require_config_value(relay_config, _ConfigKey.PROJECT_NAME, _ConfigKey.PROJECT_NAME, config_name)
    server_identity = _require_config_value(
        relay_config, _ConfigKey.SERVER_IDENTITY, _ConfigKey.SERVER_IDENTITY, config_name
    )
    my_identity = _require_config_value(relay_config, _ConfigKey.IDENTITY, _ConfigKey.IDENTITY, config_name)
    parent = _require_config_value(relay_config, _ConfigKey.CONNECT_TO, _ConfigKey.CONNECT_TO, config_name)
    # parent.get(...) below would fail with an AttributeError on a non-dict section.
    if not isinstance(parent, dict):
        raise RuntimeError(f"invalid relay config file {config_name}: {_ConfigKey.CONNECT_TO} must be a dict")

    parent_address = _require_config_value(parent, ConnPropKey.ADDRESS, "parent.address", config_name)
    parent_scheme = _require_config_value(parent, ConnPropKey.SCHEME, "parent.scheme", config_name)
    parent_fqcn = _require_config_value(parent, ConnPropKey.FQCN, "parent.fqcn", config_name)

    cmd_vars = parse_vars(args.set)
    # NOTE(review): --set values may arrive as raw strings (e.g. "false"), which
    # are truthy; normalize so secure mode is only enabled when intended.
    secure_train = _is_truthy_flag(cmd_vars.get("secure_train", False))
    logger.debug(f"{cmd_vars=} {secure_train=}")

    stop_event = threading.Event()
    monitor = CellnetMonitor(stop_event, args.workspace)

    ConfigService.initialize(
        section_files={},
        config_path=[args.workspace],
    )

    root_cert_path = search_file(SSL_ROOT_CERT, args.workspace)
    if not root_cert_path:
        raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {args.workspace}")

    credentials = {
        DriverParams.CA_CERT.value: root_cert_path,
    }
    enhance_credential_info(credentials)

    # Connections are secure unless the parent explicitly asks for CLEAR.
    conn_security = parent.get(ConnPropKey.CONNECTION_SECURITY)
    secure_conn = True
    if conn_security:
        credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security
        if conn_security == ConnectionSecurity.CLEAR:
            secure_conn = False

    parent_url = make_url(parent_scheme, parent_address, secure_conn)

    # Directly under the root server: connect via root_url with a top-level FQCN.
    # Otherwise: become a child of the parent cell via parent_url.
    if parent_fqcn == FQCN.ROOT_SERVER:
        my_fqcn = my_identity
        root_url = parent_url
        parent_url = None
    else:
        my_fqcn = FQCN.join([parent_fqcn, my_identity])
        root_url = None

    flare_decomposers.register()
    cell = Cell(
        fqcn=my_fqcn,
        root_url=root_url,
        secure=secure_conn,
        credentials=credentials,
        create_internal_listener=True,
        parent_url=parent_url,
    )
    NetAgent(cell, agent_closed_cb=monitor.cellnet_stopped)
    cell.start()

    # Ensure the cell is stopped even if authentication or filter setup fails.
    try:
        authenticator = Authenticator(
            cell=cell,
            project_name=project_name,
            client_name=my_identity,
            client_type=ClientType.RELAY,
            expected_sp_identity=server_identity,
            secure_mode=secure_train,
            root_cert_file=credentials.get(DriverParams.CA_CERT.value),
            private_key_file=credentials.get(DriverParams.CLIENT_KEY.value),
            cert_file=credentials.get(DriverParams.CLIENT_CERT.value),
            msg_timeout=5.0,
            retry_interval=2.0,
        )

        abort_signal = Signal()
        shared_fl_ctx = FLContext()
        shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: my_identity})
        token, token_signature, ssid, token_verifier = authenticator.authenticate(
            shared_fl_ctx=shared_fl_ctx,
            abort_signal=abort_signal,
        )

        if secure_train:
            if not isinstance(token_verifier, TokenVerifier):
                raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}")

        # Outgoing messages get auth headers; incoming ones are validated.
        set_add_auth_headers_filters(cell, my_identity, token, token_signature, ssid)
        cell.core_cell.add_incoming_filter(
            channel="*",
            topic="*",
            cb=_validate_auth_headers,
            token_verifier=token_verifier,
            logger=logger,
        )

        # Do not log the auth token itself: it is a credential.
        logger.info(f"Successfully authenticated to {server_identity}: {ssid=}")

        # wait until stopped (CellnetMonitor sets stop_event when the cellnet closes)
        logger.info(f"Started relay {my_identity=} {my_fqcn=} {root_url=} {parent_url=} {parent_fqcn=}")
        stop_event.wait()
    finally:
        cell.stop()

    logger.info(f"Relay {my_fqcn} stopped.")
def _validate_auth_headers(message: CellMessage, token_verifier: TokenVerifier, logger):
    """Incoming-message filter that checks a message's authentication headers.

    ``main`` installs this filter on the relay's core cell for all channels and
    topics. The work is delegated to ``validate_auth_headers``.

    Args:
        message: the incoming cell message to check.
        token_verifier: verifier obtained from authentication with the server.
        logger: logger passed through to the validator.

    Returns:
        Whatever ``validate_auth_headers`` returns, unchanged.
    """
    outcome = validate_auth_headers(message, token_verifier, logger)
    return outcome


if __name__ == "__main__":
    args = parse_arguments()
    sys.exit(mpm.run(main_func=main, run_dir=args.workspace, args=args))