Source code for nvflare.tool.package_checker.utils

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
from typing import Any, Dict, Optional

import grpc
from requests import Request, Response, Session
from requests.adapters import HTTPAdapter


class NVFlareConfig:
    """File names of the per-role federation configuration files."""

    # One JSON config file per participant role.
    SERVER = "fed_server.json"
    CLIENT = "fed_client.json"
    ADMIN = "fed_admin.json"
class NVFlareRole:
    """String identifiers of the participant roles used throughout this module."""

    # Plain strings so they compare equal to role values read from config.
    SERVER = "server"
    CLIENT = "client"
    ADMIN = "admin"
def try_write_dir(path: str):
    """Checks whether a file can be created inside ``path``.

    If ``path`` does not exist it is created for the duration of the probe and
    the ``path`` directory itself is removed again afterwards, whether or not
    the probe succeeded (parent directories created along the way are kept).

    Args:
        path: directory to probe.

    Returns:
        None if a temporary file could be written and removed, otherwise the
        OSError that was raised.
    """
    created = False
    error = None
    try:
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=False)
            # Flag the directory as ours only once makedirs has succeeded.
            created = True
        fd, name = tempfile.mkstemp(dir=path)
        try:
            with os.fdopen(fd, "w") as fp:
                fp.write("dummy")
        finally:
            # Do not leave the probe file behind, even when the write failed.
            os.remove(name)
    except OSError as e:
        error = e

    if created:
        # Clean up the directory we created even if the probe itself failed.
        try:
            shutil.rmtree(path)
        except OSError as e:
            # Report the first failure; a cleanup error must not mask it.
            if error is None:
                error = e
    return error
def try_bind_address(host: str, port: int):
    """Tries to bind a TCP/IPv4 socket to the given address.

    Args:
        host: host name or IP address to bind to.
        port: port number to bind to.

    Returns:
        None if the bind succeeded, otherwise the OSError raised by ``bind``.
        The probing socket is closed in both cases.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, port))
        except OSError as bind_error:
            return bind_error
    return None
def _create_http_session(ca_path=None, cert_path=None, prv_key_path=None):
    """Creates a requests Session whose HTTPS adapter is built with ``max_retries=1``.

    Args:
        ca_path: CA bundle used to verify the server; when falsy, the session
            keeps its default verification settings and no client cert is set.
        cert_path: client certificate file, used only when ``ca_path`` is given.
        prv_key_path: client private key file, used only when ``ca_path`` is given.

    Returns:
        The configured Session.
    """
    http_session = Session()
    http_session.mount("https://", HTTPAdapter(max_retries=1))
    if not ca_path:
        return http_session
    http_session.verify = ca_path
    http_session.cert = (cert_path, prv_key_path)
    return http_session


def _send_request(
    session, api_point, headers: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None
) -> Response:
    """Sends ``payload`` as a JSON POST request to ``api_point``.

    Args:
        session: session used to prepare and send the request.
        api_point: URL to post to.
        headers: optional request headers.
        payload: optional JSON-serializable request body.

    Returns:
        The Response returned by the session.
    """
    post_request = Request("POST", api_point, json=payload, headers=headers)
    return session.send(session.prepare_request(post_request))
def parse_overseer_agent_args(overseer_agent_conf: dict, required_args: list) -> dict:
    """Extracts the required arguments from an overseer agent configuration.

    Args:
        overseer_agent_conf: overseer agent config; arguments are read from its "args" entry.
        required_args: names of the arguments that must be present.

    Returns:
        A dict mapping each required argument name to its configured value.

    Raises:
        Exception: if a required argument is missing or set to None.
    """

    def _require(arg_name):
        # A value of None is treated the same as a missing argument.
        arg_value = overseer_agent_conf.get("args", {}).get(arg_name)
        if arg_value is None:
            raise Exception(f"overseer agent missing arg '{arg_name}'.")
        return arg_value

    return {arg_name: _require(arg_name) for arg_name in required_args}
def construct_dummy_overseer_response(overseer_agent_conf: dict, role: str) -> Response:
    """Builds a canned HTTP 200 overseer response from the agent configuration.

    The configured "sp_end_point" is reported both as the primary service
    provider and as the only entry of the service provider list.

    Args:
        overseer_agent_conf: overseer agent config with "path" and "args" entries.
        role: role of the participant the response is built for.

    Returns:
        A Response with status code 200 and a JSON-encoded body.

    Raises:
        Exception: if the agent class is unsupported or a required argument is missing.
    """
    agent_args = parse_overseer_agent_args(
        overseer_agent_conf,
        get_required_args_for_overseer_agent(overseer_agent_class=overseer_agent_conf.get("path"), role=role),
    )
    primary = {"sp_end_point": agent_args["sp_end_point"], "primary": True}
    body = json.dumps({"primary_sp": primary, "sp_list": [primary]})

    dummy = Response()
    dummy.status_code = 200
    # The body is injected through the private ``_content`` attribute.
    dummy._content = body.encode()
    return dummy
def get_required_args_for_overseer_agent(overseer_agent_class: str, role: str) -> list:
    """Gets required argument list for a specific overseer agent class.

    Args:
        overseer_agent_class: fully qualified class path of the overseer agent.
        role: participant role; the server role needs two extra port arguments
            for the HTTP overseer agent.

    Returns:
        A new list with the names of the required arguments.

    Raises:
        Exception: if the overseer agent class is not supported.
    """
    if overseer_agent_class == "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent":
        return ["sp_end_point"]

    if overseer_agent_class != "nvflare.ha.overseer_agent.HttpOverseerAgent":
        raise Exception(f"overseer agent {overseer_agent_class} is not supported.")

    needed = ["overseer_end_point", "role", "project", "name"]
    if role == NVFlareRole.SERVER:
        needed += ["fl_port", "admin_port"]
    return needed
def _prepare_data(args: dict):
    """Builds the request data dict from parsed overseer agent arguments.

    Args:
        args: parsed arguments; must contain "role" and "project", and for the
            server role also "name", "fl_port" and "admin_port".

    Returns:
        A dict with "role" and "project", plus "sp_end_point" formatted as
        "name:fl_port:admin_port" when the role is the server role.
    """
    data = dict(role=args["role"], project=args["project"])
    if args["role"] == NVFlareRole.SERVER:
        # Ports may be given as numbers in the config; str() keeps join() from failing on them.
        data["sp_end_point"] = ":".join(str(args[key]) for key in ("name", "fl_port", "admin_port"))
    return data


def _get_ca_cert_file_name():
    """Returns the file name of the root CA certificate."""
    return "rootCA.pem"


def _get_cert_file_name(role: str):
    """Returns the certificate file name for ``role`` (every non-server role uses the client file)."""
    if role == NVFlareRole.SERVER:
        return "server.crt"
    return "client.crt"


def _get_prv_key_file_name(role: str):
    """Returns the private key file name for ``role`` (every non-server role uses the client file)."""
    if role == NVFlareRole.SERVER:
        return "server.key"
    return "client.key"
def split_by_len(item, max_len):
    """Splits a sliceable ``item`` into consecutive chunks of at most ``max_len`` elements.

    Args:
        item: any sequence supporting ``len()`` and slicing (e.g. str or list).
        max_len: chunk size, used as the step of a ``range``.

    Returns:
        A list of slices of ``item``; the last one may be shorter than ``max_len``.
    """
    chunks = []
    for start in range(0, len(item), max_len):
        chunks.append(item[start : start + max_len])
    return chunks
def _get_conn_sec(startup: str):
    """Gets the connection security mode configured in a startup folder.

    Args:
        startup: folder that may contain "fed_client.json" or "fed_admin.json".

    Returns:
        The "connection_security" value of the first config file found, or
        "mtls" when the key is absent or neither file exists.
    """
    # The client config takes precedence; the admin config is consulted only
    # when no client config exists.
    for file_name, section in (("fed_client.json", "client"), ("fed_admin.json", "admin")):
        config_path = os.path.join(startup, file_name)
        if not os.path.exists(config_path):
            continue
        with open(config_path, "r") as f:
            settings = json.load(f)
        return settings[section].get("connection_security", "mtls")
    return "mtls"
def check_grpc_server_running(startup: str, host: str, port: int, token=None) -> bool:
    """Checks whether a gRPC server at ``host:port`` becomes ready within 10 seconds.

    The connection security mode is read from the config in ``startup``. For
    "clear" an insecure channel is used; for every other mode a secure channel
    is built from the root CA, client certificate and client key in ``startup``.

    Args:
        startup: folder holding the config and, for secure modes, the TLS files.
        host: server host.
        port: server port.
        token: value attached as "x-custom-token" call metadata on secure channels.

    Returns:
        True if the channel became ready before the timeout, False otherwise.
    """
    target = f"{host}:{port}"

    if _get_conn_sec(startup) == "clear":
        # No TLS material is needed (or read) for a clear-text connection.
        channel = grpc.insecure_channel(target=target)
    else:
        with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
            trusted_certs = f.read()
        with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
            private_key = f.read()
        with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
            certificate_chain = f.read()

        call_credentials = grpc.metadata_call_credentials(
            lambda context, callback: callback((("x-custom-token", token),), None)
        )
        credentials = grpc.ssl_channel_credentials(
            certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
        )
        composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
        channel = grpc.secure_channel(target=target, credentials=composite_credentials)

    try:
        grpc.channel_ready_future(channel).result(timeout=10)
    except grpc.FutureTimeoutError:
        return False
    finally:
        # The channel is only a probe; release its resources on every path.
        channel.close()
    return True
def run_command_in_subprocess(command):
    """Starts ``command`` in a child process running in its own session.

    The command string is split with ``shlex.split`` and executed without a
    shell. The child gets a copy of the current environment, a piped stdin,
    and stdout/stderr merged into a single text-mode pipe.

    Args:
        command: command line to run, as a single string.

    Returns:
        The started ``subprocess.Popen`` object; the caller is responsible for
        communicating with and waiting on it.
    """
    return subprocess.Popen(
        shlex.split(command),
        # Equivalent to calling os.setsid() in the child, without the
        # thread-safety caveats of preexec_fn.
        start_new_session=True,
        env=os.environ.copy(),
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )