# Source code for nvflare.tool.api_utils

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import List, Optional

from nvflare.fuel.flare_api.api_spec import JobNotFound, NoConnection, TargetType
from nvflare.fuel.flare_api.flare_api import Session


class SystemStartTimeout(RuntimeError):
    """Raised by ``wait_for_system_start`` when the FL system is not ready before its deadline."""
def _client_names(client_info: list) -> List[str]:
    """Return the truthy ``name`` attributes of the entries in *client_info*, in order.

    Entries without a ``name`` attribute, or whose name is empty/None, are skipped.
    """
    names = []
    for client in client_info:
        name = getattr(client, "name", None)
        if name:
            names.append(name)
    return names


def _format_ready_clients(client_names: List[str], ready_count: int, expected_count: int) -> str:
    """Build the "Clients ready: X/Y" status line.

    The comma-separated client names are appended in parentheses when
    *client_names* is non-empty.
    """
    names = ""
    if client_names:
        names = f" ({', '.join(client_names)})"
    return f"Clients ready: {ready_count}/{expected_count}{names}"
def shutdown_system(
    prod_dir: str,
    username: str = "admin@nvidia.com",
    secure_mode: bool = True,
    timeout_in_sec: int = 30,
    wait: bool = True,
    verbose: bool = True,
    conn_timeout: float = 10,
) -> dict:
    """Connect to the FL server and shut the whole system down.

    Running jobs are aborted first via ``shutdown_system_by_session``.

    Args:
        prod_dir: provisioned workspace directory; ``<prod_dir>/<username>`` is
            used as the ``Session`` startup path.
        username: admin user whose startup kit is used to connect.
        secure_mode: forwarded to ``Session``.
        timeout_in_sec: timeout forwarded to the shutdown request.
        wait: whether the shutdown request waits for completion.
        verbose: print progress messages when true.
        conn_timeout: seconds allowed for the initial connection attempt.

    Returns:
        Status dict with keys ``server_reachable``, ``already_stopped``,
        ``active_job_ids``, ``active_jobs_aborted`` and ``wait``. If the
        connection raises ``NoConnection`` the system is reported as already
        stopped.
    """
    from nvflare.tool.cli_output import print_human

    admin_user_dir = os.path.join(prod_dir, username)
    if verbose:
        print_human("connect to nvflare server")
    sess = None
    try:
        sess = Session(username=username, startup_path=admin_user_dir, secure_mode=secure_mode)
        sess.try_connect(conn_timeout)
        return shutdown_system_by_session(sess=sess, timeout_in_sec=timeout_in_sec, wait=wait, verbose=verbose)
    except NoConnection:
        # system is already shutdown
        return {
            "server_reachable": False,
            "already_stopped": True,
            "active_job_ids": [],
            # Keep the same key set as the reachable-server result so callers
            # can read every field regardless of which path was taken.
            "active_jobs_aborted": False,
            "wait": wait,
        }
    finally:
        if sess is not None:
            sess.close()
def shutdown_system_by_session(
    sess: Session, timeout_in_sec: int = 20, wait: bool = True, verbose: bool = True
) -> dict:
    """Abort running jobs, then shut down all FL sites through an open session.

    Args:
        sess: connected admin session.
        timeout_in_sec: timeout forwarded to ``sess.shutdown``.
        wait: forwarded to ``sess.shutdown``; also selects the progress message.
        verbose: print progress messages when true.

    Returns:
        Status dict marking the server as reachable, listing the ids of jobs
        that were RUNNING (and were therefore aborted), and echoing *wait*.
    """
    from nvflare.tool.cli_output import print_human

    def _say(message):
        # Progress output is optional; route every message through one switch.
        if verbose:
            print_human(message)

    _say("checking running jobs")
    running = get_running_job_ids(sess.list_jobs())
    if running:
        _say("Warning: current running jobs will be aborted")
        abort_jobs(sess, running)

    _say("shutdown NVFLARE and wait for completion" if wait else "shutdown NVFLARE")
    sess.shutdown(TargetType.ALL, wait=wait, timeout=timeout_in_sec)

    return {
        "server_reachable": True,
        "already_stopped": False,
        "active_job_ids": running,
        "active_jobs_aborted": bool(running),
        "wait": wait,
    }
def get_running_job_ids(jobs: list) -> List[str]:
    """Return the ids of jobs whose ``status`` is exactly ``"RUNNING"``.

    *jobs* may be None. Non-dict entries are ignored. The id is taken from
    ``job_id`` and falls back to ``id``. Jobs with no truthy id are skipped.
    Input order is preserved.
    """

    def _id_of(record):
        return record.get("job_id") or record.get("id")

    candidate_ids = (
        _id_of(record)
        for record in (jobs or [])
        if isinstance(record, dict) and record.get("status") == "RUNNING"
    )
    return [candidate for candidate in candidate_ids if candidate]
def abort_jobs(sess, job_ids):
    """Call ``sess.abort_job`` for each id in *job_ids*, in order.

    Ids rejected with ``JobNotFound`` are skipped so the remaining ids are
    still processed.
    """
    for current_id in job_ids:
        try:
            sess.abort_job(current_id)
        except JobNotFound:
            # ignore invalid job id
            continue
def wait_for_system_start(
    num_clients: int,
    prod_dir: str,
    username: str = "admin",
    secure_mode: bool = False,
    second_to_wait: int = 10,
    timeout_in_sec: int = 30,
    poll_interval: float = 2.0,
    conn_timeout: float = 10.0,
    expected_clients: Optional[List[str]] = None,
):
    """Poll the FL server until the required clients have registered.

    Args:
        num_clients: number of registered clients required when
            *expected_clients* is not given.
        prod_dir: provisioned workspace directory; ``<prod_dir>/<username>`` is
            used as the ``Session`` startup path.
        username: admin user whose startup kit is used to connect.
        secure_mode: forwarded to ``Session``.
        second_to_wait: fixed delay before polling starts; it is not counted
            against *timeout_in_sec*.
        timeout_in_sec: polling budget, measured after the initial delay.
        poll_interval: pause between polls that did not find the system ready.
        conn_timeout: upper bound for each connection attempt; each attempt is
            also capped by the time left before the deadline.
        expected_clients: if given, the system is ready only when every one of
            these client names is registered, and *num_clients* is not used
            for the readiness check.

    Returns:
        The object returned by ``sess.get_system_info()`` once ready.

    Raises:
        SystemStartTimeout: if readiness is not reached before the deadline.
            The message includes the last recorded error or waiting reason.
    """
    from nvflare.tool.cli_output import print_human

    if second_to_wait > 0:
        print_human(f"wait for {second_to_wait} seconds before FL system is up")
        time.sleep(second_to_wait)  # just in case try to connect before server started

    flare_not_ready = True
    expected_client_set = set(expected_clients or [])
    deadline = time.time() + timeout_in_sec
    admin_user_dir = os.path.join(prod_dir, username)
    last_error = None
    while flare_not_ready and time.time() < deadline:
        sess = None
        try:
            sess = Session(username=username, startup_path=admin_user_dir, secure_mode=secure_mode)
            # Never let a single connection attempt overrun the overall deadline.
            remaining = max(deadline - time.time(), 0.1)
            sess.try_connect(min(conn_timeout, remaining))
            sys_info = sess.get_system_info()
            client_names = _client_names(sys_info.client_info)
            if expected_client_set:
                registered_clients = set(client_names)
                missing_clients = sorted(expected_client_set - registered_clients)
                # Count only expected clients that have registered, so that the
                # "X/Y ready" figure agrees with the missing-client list.
                # Unexpected extra clients must not inflate it (previously
                # this could print e.g. "site-2 (2/2 ready)" while waiting).
                ready_count = len(expected_client_set & registered_clients)
                expected_count = len(expected_client_set)
                flare_not_ready = bool(missing_clients)
            else:
                missing_clients = []
                ready_count = len(sys_info.client_info)
                expected_count = num_clients
                flare_not_ready = ready_count < num_clients
            if flare_not_ready:
                if missing_clients:
                    last_error = f"waiting for clients: {', '.join(missing_clients)}"
                    print_human(
                        f"Waiting for clients: {', '.join(missing_clients)} ({ready_count}/{expected_count} ready)"
                    )
                else:
                    last_error = f"{ready_count} of {num_clients} clients registered"
                    print_human(f"Waiting for clients: {ready_count}/{expected_count} ready")
            else:
                print_human(_format_ready_clients(client_names, ready_count, expected_count))
                print_human("\nReady to go.")
                return sys_info
        except NoConnection:
            # server is not up yet
            last_error = "server is not reachable"
        except Exception as e:
            # Best effort: any other failure is recorded and the poll retried.
            last_error = str(e)
        finally:
            if sess:
                try:
                    sess.close()
                except Exception as e:
                    last_error = str(e)
        remaining = deadline - time.time()
        if flare_not_ready and remaining > 0:
            time.sleep(min(poll_interval, remaining))

    detail = f"; last error: {last_error}" if last_error else ""
    client_target = (
        f"expected clients {', '.join(sorted(expected_client_set))}" if expected_client_set else f"{num_clients} clients"
    )
    raise SystemStartTimeout(f"cannot connect to server with {client_target} within {timeout_in_sec} sec{detail}")