Source code for nvflare.tool.api_utils

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import List

from nvflare.fuel.flare_api.api_spec import JobNotFound, NoConnection
from nvflare.fuel.flare_api.flare_api import Session


[docs]def shutdown_system( prod_dir: str, username: str = "admin@nvidia.com", secure_mode: bool = True, timeout_in_sec: int = 30 ): admin_user_dir = os.path.join(prod_dir, username) print("connect to nvflare server") sess = None conn_timeout = 10 try: sess = Session(username=username, startup_path=admin_user_dir, secure_mode=secure_mode) sess.try_connect(conn_timeout) shutdown_system_by_session(sess=sess, timeout_in_sec=timeout_in_sec) except NoConnection: # system is already shutdown return finally: if sess: sess.close()
[docs]def shutdown_system_by_session(sess: Session, timeout_in_sec: int = 20): print("checking running jobs") jobs = sess.list_jobs() active_job_ids = get_running_job_ids(jobs) if len(active_job_ids) > 0: print("Warning: current running jobs will be aborted") abort_jobs(sess, active_job_ids) print("shutdown NVFLARE") sess.api.do_command("shutdown all") wait_for_system_shutdown(sess, timeout_in_sec=timeout_in_sec)
[docs]def get_running_job_ids(jobs: list) -> List[str]: if len(jobs) > 0: running_job_ids = [job for job in jobs if job["status"] == "RUNNING"] return running_job_ids else: return []
[docs]def abort_jobs(sess, job_ids): for job_id in job_ids: try: sess.abort_job(job_id) except JobNotFound: # ignore invalid job id pass
[docs]def wait_for_system_shutdown(sess: Session, timeout_in_sec: int = 30): start = time.time() duration = 0 cnt = 0 status = None while (status is None or status == "started") and duration < timeout_in_sec: try: sys_info = sess.get_system_info() status = sys_info.server_info.status curr = time.time() duration = curr - start if cnt % 25 == 0: print("waiting system to shutdown") cnt += 1 time.sleep(0.2) except Exception: # Server is already shutdown return
[docs]def wait_for_system_start( num_clients: int, prod_dir: str, username: str = "admin", secure_mode: bool = False, second_to_wait: int = 10, timeout_in_sec: int = 30, ): print(f"wait for {second_to_wait} seconds before FL system is up") time.sleep(second_to_wait) # just in case try to connect before server started flare_not_ready = True start = time.time() duration = 0 admin_user_dir = os.path.join(prod_dir, username) conn_timeout = 10.0 while flare_not_ready and duration < timeout_in_sec: print("trying to connect to server") sess = None try: sess = Session(username=username, startup_path=admin_user_dir, secure_mode=secure_mode) sess.try_connect(conn_timeout) sys_info = sess.get_system_info() print(f"Server info:\n{sys_info.server_info}") print("\nClient info") for client in sys_info.client_info: print(client) flare_not_ready = len(sys_info.client_info) < num_clients curr = time.time() duration = curr - start time.sleep(2) except NoConnection: # server is not up yet pass except Exception as e: print("failure", e) finally: if sess: sess.close() if flare_not_ready: raise RuntimeError("can't not connect to server within {timeout_in_sec} sec") else: print("ready to go")