# Source code for nvflare.tool.package_checker.utils

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shlex
import shutil
import socket
import ssl
import subprocess
import tempfile
import time
from typing import Any, Dict, Optional, Tuple

import grpc
from requests import Request, RequestException, Response, Session, codes
from requests.adapters import HTTPAdapter

from nvflare.fuel.hci.conn import ALL_END


[docs]class NVFlareConfig: OVERSEER = "gunicorn.conf.py" SERVER = "fed_server.json" CLIENT = "fed_client.json" ADMIN = "fed_admin.json"
class NVFlareRole:
    """String identifiers for the roles a package can have."""

    # Role value compared against when server-specific handling is needed.
    SERVER = "server"
    CLIENT = "client"
    ADMIN = "admin"
def try_write_dir(path: str) -> Optional[OSError]:
    """Check that a file can be written inside ``path``.

    If ``path`` does not exist it is created for the duration of the check
    and removed again afterwards. A temporary file is written inside the
    directory and then deleted.

    Args:
        path: directory to probe.

    Returns:
        None when the probe succeeded, otherwise the ``OSError`` that was hit.
    """
    created = False
    try:
        try:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=False)
                # Only mark as created once makedirs actually succeeded.
                created = True
            fd, name = tempfile.mkstemp(dir=path)
            try:
                with os.fdopen(fd, "w") as fp:
                    fp.write("dummy")
            finally:
                # Remove the probe file even when the write itself failed.
                os.remove(name)
        finally:
            # Never leave behind a directory that this check created,
            # regardless of whether the probe succeeded.
            if created:
                shutil.rmtree(path)
    except OSError as e:
        return e
    return None
def try_bind_address(host: str, port: int):
    """Attempt to bind a TCP/IPv4 socket to ``(host, port)``.

    The socket is always closed before returning.

    Args:
        host: host name or address to bind to.
        port: port number to bind to.

    Returns:
        None if the bind succeeded, otherwise the ``OSError`` raised by bind.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, port))
        except OSError as bind_error:
            return bind_error
    return None
def _create_http_session(ca_path=None, cert_path=None, prv_key_path=None):
    """Create a requests ``Session`` for talking to an HTTPS endpoint.

    An ``HTTPAdapter`` with ``max_retries=1`` is mounted for ``https://``.
    When ``ca_path`` is truthy, it is used for server verification and
    ``(cert_path, prv_key_path)`` is set as the client certificate.

    Args:
        ca_path: path to the CA certificate, or None.
        cert_path: path to the client certificate.
        prv_key_path: path to the client private key.

    Returns:
        The configured ``Session``.
    """
    http_session = Session()
    http_session.mount("https://", HTTPAdapter(max_retries=1))
    if not ca_path:
        return http_session
    http_session.verify = ca_path
    http_session.cert = (cert_path, prv_key_path)
    return http_session


def _send_request(
    session, api_point, headers: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None
) -> Response:
    """POST ``payload`` as JSON to ``api_point`` using ``session``.

    Args:
        session: session used to prepare and send the request.
        api_point: URL to post to.
        headers: optional request headers.
        payload: optional JSON body.

    Returns:
        The ``Response`` returned by the session.
    """
    post_request = Request("POST", api_point, json=payload, headers=headers)
    return session.send(session.prepare_request(post_request))
def parse_overseer_agent_args(overseer_agent_conf: dict, required_args: list) -> dict:
    """Extract the required arguments from an overseer agent configuration.

    Args:
        overseer_agent_conf: overseer agent configuration; the arguments are
            read from its ``"args"`` entry.
        required_args: names of the arguments that must be present.

    Returns:
        A dict mapping each required argument name to its configured value.

    Raises:
        Exception: if a required argument is absent or set to None.
    """
    # ``"args"`` may be missing or explicitly null; treat both as "no args" so
    # the caller gets the descriptive error below instead of an AttributeError.
    configured_args = overseer_agent_conf.get("args") or {}
    result = {}
    for k in required_args:
        value = configured_args.get(k)
        if value is None:
            raise Exception(f"overseer agent missing arg '{k}'.")
        result[k] = value
    return result
def construct_dummy_response(overseer_agent_args: dict) -> Response:
    """Build a ``Response`` with status 200 and a fixed JSON body.

    The body names a single service provider, taken from
    ``overseer_agent_args["sp_end_point"]``, as both ``primary_sp`` and the
    only entry of ``sp_list``.

    Args:
        overseer_agent_args: must contain ``"sp_end_point"``.

    Returns:
        The constructed ``Response``.
    """
    primary = {"sp_end_point": overseer_agent_args["sp_end_point"], "primary": True}
    body = json.dumps({"primary_sp": primary, "sp_list": [primary]})
    dummy = Response()
    dummy.status_code = 200
    dummy._content = body.encode()
    return dummy
def is_dummy_overseer_agent(overseer_agent_class: str) -> bool:
    """Return True when the class path names the dummy overseer agent."""
    return overseer_agent_class == "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent"
def get_required_args_for_overseer_agent(overseer_agent_class: str, role: str) -> list:
    """Return the argument names a given overseer agent class requires.

    Args:
        overseer_agent_class: fully qualified class path of the overseer agent.
        role: role of the package; the server role needs extra port arguments
            for the HTTP overseer agent.

    Returns:
        List of required argument names.

    Raises:
        Exception: if the overseer agent class is not one of the supported ones.
    """
    if overseer_agent_class == "nvflare.ha.dummy_overseer_agent.DummyOverseerAgent":
        return ["sp_end_point"]

    if overseer_agent_class != "nvflare.ha.overseer_agent.HttpOverseerAgent":
        raise Exception(f"overseer agent {overseer_agent_class} is not supported.")

    needed = ["overseer_end_point", "role", "project", "name"]
    if role == NVFlareRole.SERVER:
        needed += ["fl_port", "admin_port"]
    return needed
def _prepare_data(args: dict):
    """Build the payload dict from overseer agent arguments.

    Args:
        args: overseer agent arguments; must contain ``"role"`` and
            ``"project"``, plus ``"name"``, ``"fl_port"`` and ``"admin_port"``
            when the role is the server role.

    Returns:
        A dict with ``role`` and ``project``; for the server role it also has
        ``sp_end_point`` formatted as ``name:fl_port:admin_port``.
    """
    data = dict(role=args["role"], project=args["project"])
    if args["role"] == NVFlareRole.SERVER:
        # Ports may be configured as integers; str.join only accepts strings.
        data["sp_end_point"] = ":".join(str(args[key]) for key in ("name", "fl_port", "admin_port"))
    return data


def _get_ca_cert_file_name():
    """Return the file name of the root CA certificate."""
    return "rootCA.pem"


def _get_cert_file_name(role: str):
    """Return the certificate file name for ``role`` (server, else client)."""
    if role == NVFlareRole.SERVER:
        return "server.crt"
    return "client.crt"


def _get_prv_key_file_name(role: str):
    """Return the private key file name for ``role`` (server, else client)."""
    if role == NVFlareRole.SERVER:
        return "server.key"
    return "client.key"
def split_by_len(item, max_len):
    """Split a sliceable sequence into consecutive pieces of at most ``max_len``.

    Args:
        item: sequence supporting ``len()`` and slicing.
        max_len: step between piece starts and maximum piece length.

    Returns:
        List of slices of ``item`` in order; the last one may be shorter.
    """
    pieces = []
    for start in range(0, len(item), max_len):
        pieces.append(item[start : start + max_len])
    return pieces
def check_overseer_running(
    startup: str, overseer_agent_args: dict, role: str, retry: int = 3
) -> Tuple[Optional[Response], Optional[str]]:
    """Checks if overseer is running.

    Posts to the overseer's ``/heartbeat`` endpoint, making at most ``retry``
    attempts with a one second pause between them.

    Args:
        startup: directory holding the CA certificate, certificate and
            private key files for ``role``.
        overseer_agent_args: overseer agent arguments; must contain
            ``"overseer_end_point"`` and whatever the payload needs.
        role: role used to pick the certificate and key file names.
        retry: maximum number of attempts.

    Returns:
        A tuple ``(resp, err)``: ``resp`` is the last response received (or
        None), ``err`` is the message of the request error from the last
        attempt, or None if the last attempt did not raise.
    """
    session = _create_http_session(
        ca_path=os.path.join(startup, _get_ca_cert_file_name()),
        cert_path=os.path.join(startup, _get_cert_file_name(role)),
        prv_key_path=os.path.join(startup, _get_prv_key_file_name(role)),
    )
    data = _prepare_data(overseer_agent_args)
    api_point = overseer_agent_args["overseer_end_point"] + "/heartbeat"
    retry_delay = 1
    resp = None
    err = None
    # Close the session (and its pooled connections) once the check is done.
    with session:
        for attempt in range(1, retry + 1):
            try:
                resp = _send_request(session, api_point=api_point, payload=data)
            except RequestException as e:
                err = str(e)
            else:
                # A request that got through supersedes any earlier error, so
                # a successful retry is not reported as a failure.
                err = None
                if resp:
                    break
            # Every attempt counts, including a non-OK response; otherwise an
            # error status from the overseer would loop forever.
            if attempt < retry:
                time.sleep(retry_delay)
    return resp, err
def check_response(resp: Optional[Response]) -> bool:
    """Return True only for a truthy response whose status code is ``codes.ok``.

    Args:
        resp: response to inspect, or None.

    Returns:
        False when ``resp`` is falsy or its status code differs from
        ``codes.ok``; True otherwise.
    """
    return bool(resp) and resp.status_code == codes.ok
def check_socket_server_running(startup: str, host: str, port: int) -> bool:
    """Check that a TLS socket server answers at ``host:port``.

    Connects with the client certificate and key found in ``startup``,
    verifying the peer against the CA certificate in the same directory
    (host name checking is disabled), sends a hello message and waits for
    data to come back.

    Args:
        startup: directory holding the CA certificate and the client
            certificate and private key.
        host: server host.
        port: server port.

    Returns:
        True if the exchange completed; False if any exception occurred, in
        which case the exception is printed.
    """
    try:
        # SSL communication
        tls_context = ssl.create_default_context()
        tls_context.minimum_version = ssl.TLSVersion.TLSv1_2
        tls_context.verify_mode = ssl.CERT_REQUIRED
        tls_context.check_hostname = False

        ca_file = os.path.join(startup, _get_ca_cert_file_name())
        tls_context.load_verify_locations(ca_file)

        client_cert = os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT))
        client_key = os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT))
        tls_context.load_cert_chain(certfile=client_cert, keyfile=client_key)

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as raw_sock, tls_context.wrap_socket(
            raw_sock
        ) as tls_sock:
            tls_sock.connect((host, port))
            tls_sock.sendall(f"hello{ALL_END}".encode("utf-8"))
            tls_sock.recv()
    except Exception as e:
        print(e)
        return False
    return True
def check_grpc_server_running(startup: str, host: str, port: int, token=None) -> bool:
    """Check that a gRPC server becomes reachable at ``host:port``.

    Opens a secure channel using the CA certificate and the client
    certificate and private key found in ``startup``, and waits up to
    10 seconds for the channel to become ready.

    Args:
        startup: directory holding the CA certificate and the client
            certificate and private key.
        host: server host.
        port: server port.
        token: value attached as ``x-custom-token`` call metadata.

    Returns:
        True if the channel became ready in time, False on timeout.

    Raises:
        OSError: if one of the credential files cannot be read.
    """
    with open(os.path.join(startup, _get_ca_cert_file_name()), "rb") as f:
        trusted_certs = f.read()
    with open(os.path.join(startup, _get_prv_key_file_name(NVFlareRole.CLIENT)), "rb") as f:
        private_key = f.read()
    with open(os.path.join(startup, _get_cert_file_name(NVFlareRole.CLIENT)), "rb") as f:
        certificate_chain = f.read()

    call_credentials = grpc.metadata_call_credentials(
        lambda context, callback: callback((("x-custom-token", token),), None)
    )
    credentials = grpc.ssl_channel_credentials(
        certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
    )
    composite_credentials = grpc.composite_channel_credentials(credentials, call_credentials)
    channel = grpc.secure_channel(target=f"{host}:{port}", credentials=composite_credentials)
    try:
        grpc.channel_ready_future(channel).result(timeout=10)
    except grpc.FutureTimeoutError:
        return False
    finally:
        # The channel is only needed for this probe; release it either way.
        channel.close()
    return True
def run_command_in_subprocess(command):
    """Start ``command`` as a child process in its own session.

    The command string is split with ``shlex.split`` (no shell is involved).
    The child gets a copy of the current environment, a pipe for stdin, and
    stdout and stderr merged into a single text-mode pipe.

    Args:
        command: command line to run.

    Returns:
        The started ``subprocess.Popen`` object.
    """
    new_env = os.environ.copy()
    process = subprocess.Popen(
        shlex.split(command),
        # Equivalent to calling os.setsid() in the child, without using
        # preexec_fn, which is not safe when the parent has threads.
        start_new_session=True,
        env=new_env,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    return process