# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging.config
import os
import shlex
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Manager, Process
from multiprocessing.connection import Client
from urllib.parse import urlparse
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import (
FLMetaKey,
JobConstants,
MachineStatus,
RunnerTask,
RunProcessKey,
WorkspaceConstants,
)
from nvflare.apis.job_def import ALL_SITES, JobMetaKey
from nvflare.apis.utils.job_utils import convert_legacy_zipped_app_to_job
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.common.multi_process_executor_constants import CommunicationMetaData
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.f3.stats_pool import StatsPoolManager
from nvflare.fuel.hci.server.authz import AuthorizationService
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.gpu_utils import get_host_gpu_ids
from nvflare.fuel.utils.network_utils import get_open_ports
from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer
from nvflare.private.fed.app.utils import init_security_content_service, kill_child_processes
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.server.job_meta_validator import JobMetaValidator
from nvflare.private.fed.simulator.simulator_app_runner import SimulatorServerAppRunner
from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor
from nvflare.private.fed.simulator.simulator_const import SimulatorConstants
from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, get_simulator_app_root, split_gpus
from nvflare.security.logging import secure_format_exception, secure_log_traceback
from nvflare.security.security import EmptyAuthorizer
# NOTE(review): not referenced in the code visible here; presumably the pool size used when creating clients.
CLIENT_CREATE_POOL_SIZE: int = 200

# Sub-folder of the server workspace, and the JSON file inside it, where the cell stats are dumped.
POOL_STATS_DIR: str = "pool_stats"
SIMULATOR_POOL_STATS: str = "simulator_cell_stats.json"
class SimulatorRunner(FLComponent):
    """Run an NVFlare job in the simulator: one FL server plus simulated clients on this host.

    ``run()`` executes the whole simulation in a child process. That process validates the
    options and the job (``setup()``), prepares a folder per site under the workspace, starts
    the server app in a thread and drives the clients with ``SimulatorClientRunner`` (one runner
    per GPU group, or a single runner when no GPU option is given).
    """

    def __init__(
        self,
        job_folder: str,
        workspace: str,
        clients=None,
        n_clients=None,
        threads=None,
        gpu=None,
        max_clients=100,
        end_run_for_all=False,
    ):
        """Store the simulator options; validation happens in ``setup()``.

        Args:
            job_folder: job folder to simulate; ``setup()`` joins it with the current working directory.
            workspace: simulator workspace folder, "simulator_workspace" when None; joined with the
                current working directory.
            clients: comma-separated client names, e.g. "site-1,site-2".
            n_clients: number of clients named "site-1" .. "site-N", used when ``clients`` is not given.
            threads: number of client threads.
            gpu: GPU group option, parsed with ``split_gpus``.
            max_clients: maximum number of simulated clients allowed.
            end_run_for_all: when set, every client also runs its END_RUN task after the workflow stops.
        """
        super().__init__()
        self.job_folder = job_folder
        self.workspace = workspace
        self.clients = clients
        self.n_clients = n_clients
        self.threads = threads
        self.gpu = gpu
        self.max_clients = max_clients
        self.end_run_for_all = end_run_for_all

        self.ask_to_stop = False
        self.args = None  # argparse Namespace, built by setup()
        self.simulator_root = None
        self.server = None
        self.deployer = SimulatorDeployer()
        self.client_names = []
        self.federated_clients = []
        self.client_config = None
        self.deploy_args = None
        self.build_ctx = None
        self.clients_created = 0

        running_dir = os.getcwd()
        if self.workspace is None:
            self.workspace = "simulator_workspace"
            self.logger.warning(
                f"Simulator workspace is not provided. Set it to the default location:"
                f" {os.path.join(running_dir, self.workspace)}"
            )
        self.workspace = os.path.join(running_dir, self.workspace)

    def _generate_args(
        self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100
    ):
        """Pack the simulator options into an argparse Namespace with an empty ``set`` list."""
        args = Namespace(
            job_folder=job_folder,
            workspace=workspace,
            clients=clients,
            n_clients=n_clients,
            threads=threads,
            gpu=gpu,
            max_clients=max_clients,
        )
        args.set = []
        return args

    @staticmethod
    def _default_client_names(n_clients):
        """Return the default client names "site-1" .. "site-<n_clients>"."""
        return ["site-" + str(i + 1) for i in range(n_clients)]

    def setup(self):
        """Validate the options and the job, prepare the workspace, create the server and deploy the apps.

        Side effects: configures logging, changes the process working directory to the workspace and
        wipes the workspace (its ``startup`` folder is kept).

        Returns:
            True when the simulator is ready to run; False when validation or setup failed
            (the reason is logged).
        """
        self.args = self._generate_args(
            self.job_folder, self.workspace, self.clients, self.n_clients, self.threads, self.gpu, self.max_clients
        )

        if self.args.clients:
            self.client_names = self.args.clients.strip().split(",")
        elif self.args.n_clients:
            self.client_names.extend(self._default_client_names(self.args.n_clients))

        # Prefer a logging config shipped in <workspace>/startup, else the one next to this module.
        log_config_file_path = os.path.join(self.args.workspace, "startup", WorkspaceConstants.LOGGING_CONFIG)
        if not os.path.isfile(log_config_file_path):
            log_config_file_path = os.path.join(os.path.dirname(__file__), WorkspaceConstants.LOGGING_CONFIG)
        logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False)

        self.args.log_config = None
        self.args.config_folder = "config"
        self.args.job_id = SimulatorConstants.JOB_NAME
        self.args.client_config = os.path.join(self.args.config_folder, JobConstants.CLIENT_JOB_CONFIG)
        self.args.env = os.path.join("config", AppFolderConstants.CONFIG_ENV)
        self.args.job_folder = os.path.join(os.getcwd(), self.args.job_folder)
        self.args.end_run_for_all = self.end_run_for_all

        os.makedirs(self.args.workspace, exist_ok=True)
        os.chdir(self.args.workspace)
        fobs_initialize()
        AuthorizationService.initialize(EmptyAuthorizer())
        AuditService.the_auditor = SimulatorAuditor()

        self.simulator_root = self.args.workspace
        self._cleanup_workspace()
        # The cleanup removed and recreated the workspace, which is the current working directory;
        # re-enter it so the process does not keep a deleted directory as its cwd.
        os.chdir(self.args.workspace)
        init_security_content_service(self.args.workspace)

        try:
            data_bytes, job_name, meta = self.validate_job_data()

            if not self.client_names:
                self.client_names = self._extract_client_names_from_meta(meta)
            if not self.client_names:
                self.args.n_clients = 2
                self.logger.warning("The number of simulator clients is not provided. Setting it to default: 2")
                self.client_names.extend(self._default_client_names(self.args.n_clients))

            if self.args.gpu is None and self.args.threads is None:
                self.args.threads = 1
                self.logger.warning("The number of threads is not provided. Set it to default: 1")

            if self.max_clients < len(self.client_names):
                self.logger.error(
                    f"The number of clients ({len(self.client_names)}) can not be more than the "
                    f"max_number of clients ({self.max_clients})"
                )
                return False

            if self.args.gpu and not self._check_gpu_option():
                return False

            if self.args.threads and self.args.threads > len(self.client_names):
                self.logger.error("The number of threads to run can not be larger than the number of clients.")
                return False

            if not (self.args.gpu or self.args.threads):
                self.logger.error("Please provide the number of threads or provide gpu options to run the simulator.")
                return False

            self._validate_client_names(meta, self.client_names)

            # Deploy the FL server
            self.logger.info("Create the Simulator Server.")
            _, self.server = self.deployer.create_fl_server(self.args)

            # Record the scheme and host:port of the server cell's root URL in args.
            parsed_url = urlparse(self.server.get_cell().get_root_url_for_child())
            self.args.sp_target = parsed_url.netloc
            self.args.sp_scheme = parsed_url.scheme

            self.logger.info("Deploy the Apps.")
            self._deploy_apps(job_name, data_bytes, meta, log_config_file_path)
            return True
        except Exception as e:
            self.logger.error(f"Simulator setup error: {secure_format_exception(e)}")
            secure_log_traceback()
            return False

    def _check_gpu_option(self):
        """Validate ``args.gpu`` against the host GPUs and the clients; force 1 thread per group for multi-GPU.

        Returns:
            True when the GPU option is usable, False otherwise (the reason is logged).
        """
        try:
            gpu_groups = split_gpus(self.args.gpu)
        except ValueError as e:
            self.logger.error(f"GPUs group list option in wrong format. Error: {e}")
            return False

        host_gpus = [str(x) for x in get_host_gpu_ids()]
        requested_gpus = set().union(*(group.split(",") for group in gpu_groups))
        if host_gpus and not requested_gpus.issubset(host_gpus):
            # Report the individual GPU ids that are missing, not whole "a,b" groups.
            wrong_gpus = sorted(requested_gpus.difference(host_gpus))
            self.logger.error(f"These GPUs are not available: {wrong_gpus}")
            return False

        if len(gpu_groups) > len(self.client_names):
            self.logger.error(
                f"The number of clients ({len(self.client_names)}) must be larger than or equal to "
                f"the number of GPU groups: ({len(gpu_groups)})"
            )
            return False

        if len(gpu_groups) > 1:
            if self.args.threads and self.args.threads > 1:
                self.logger.info(
                    "When running with multi GPU, each GPU group will run with only 1 thread. "
                    "Set the Threads to 1."
                )
            self.args.threads = 1
        elif len(gpu_groups) == 1 and self.args.threads is None:
            self.args.threads = 1
            self.logger.warning("The number of threads is not provided. Set it to default: 1")
        return True

    def _cleanup_workspace(self):
        """Empty the simulator workspace while preserving its ``startup`` folder.

        ``startup`` (if present) is parked in a temporary directory, the workspace is removed and
        recreated empty, and ``startup`` is moved back.
        """
        os.makedirs(self.simulator_root, exist_ok=True)
        with tempfile.TemporaryDirectory() as temp_dir:
            startup_dir = os.path.join(self.args.workspace, "startup")
            parked_startup = os.path.join(temp_dir, "startup")
            if os.path.exists(startup_dir):
                shutil.move(startup_dir, parked_startup)
            if os.path.exists(self.simulator_root):
                shutil.rmtree(self.simulator_root)
            # Recreate the now-empty root so it exists for the restore below and for later writes.
            os.makedirs(self.simulator_root, exist_ok=True)
            if os.path.exists(parked_startup):
                shutil.move(parked_startup, startup_dir)

    def _setup_local_startup(self, log_config_file_path, workspace):
        """Create ``<workspace>/local`` with the logging config and ``<workspace>/startup`` from the shared one.

        Args:
            log_config_file_path: logging config file copied into ``<workspace>/local``.
            workspace: the site's folder under the simulator root.
        """
        local_dir = os.path.join(workspace, "local")
        site_startup = os.path.join(workspace, "startup")
        os.makedirs(local_dir, exist_ok=True)
        shutil.copyfile(log_config_file_path, os.path.join(local_dir, WorkspaceConstants.LOGGING_CONFIG))

        shared_startup = os.path.join(self.simulator_root, "startup")
        if os.path.isdir(shared_startup):
            shutil.copytree(shared_startup, site_startup)
        else:
            # No shared startup folder was provided; start_server_app also tolerates an empty one.
            os.makedirs(site_startup, exist_ok=True)

    def validate_job_data(self):
        """Zip and validate the job folder.

        Returns:
            (data_bytes, job_name, meta) of the validated job.

        Raises:
            RuntimeError: if the JobMetaValidator rejects the job.
        """
        job_name = split_path(self.args.job_folder)[1]
        data = zip_directory_to_bytes("", self.args.job_folder)
        data_bytes = convert_legacy_zipped_app_to_job(data)
        valid, error, meta = JobMetaValidator().validate(job_name, data_bytes)
        if not valid:
            raise RuntimeError(error)
        return data_bytes, job_name, meta

    def _extract_client_names_from_meta(self, meta):
        """Return the explicitly named client sites of the deploy map (excluding ALL_SITES and "server")."""
        client_names = []
        for participants in meta.get(JobMetaKey.DEPLOY_MAP, {}).values():
            for p in participants:
                if p.upper() != ALL_SITES and p != "server":
                    client_names.append(p)
        return client_names

    def _validate_client_names(self, meta, client_names):
        """Ensure every client has an app in the deploy map.

        Raises:
            RuntimeError: listing the clients that no app is deployed to.
        """
        deploy_map = meta.get(JobMetaKey.DEPLOY_MAP, {})

        def has_app(name):
            # A single ALL_SITES entry deploys that app to every client.
            return any(
                (len(participants) == 1 and participants[0].upper() == ALL_SITES) or name in participants
                for participants in deploy_map.values()
            )

        no_app_clients = [name for name in client_names if not has_app(name)]
        if no_app_clients:
            raise RuntimeError(f"The job does not have App to run for clients: {no_app_clients}")

    def _deploy_apps(self, job_name, data_bytes, meta, log_config_file_path):
        """Unzip the job and copy each app into the app root of every participating site.

        Also writes the job meta as JSON into ``<simulator_root>/server``.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            unzip_all_from_bytes(data_bytes, temp_dir)
            temp_job_folder = os.path.join(temp_dir, job_name)
            for app_name, participants in meta.get(JobMetaKey.DEPLOY_MAP, {}).items():
                if len(participants) == 1 and participants[0].upper() == ALL_SITES:
                    participants = ["server", *self.client_names]
                for p in participants:
                    if p == "server" or p in self.client_names:
                        app_root = get_simulator_app_root(self.simulator_root, p)
                        self._setup_local_startup(log_config_file_path, os.path.join(self.simulator_root, p))
                        shutil.copytree(os.path.join(temp_job_folder, app_name), app_root)

        job_meta_file = os.path.join(self.simulator_root, "server", WorkspaceConstants.JOB_META_FILE)
        with open(job_meta_file, "w") as f:
            json.dump(meta, f, indent=4)

    def split_clients(self, clients: list, gpus: list):
        """Distribute ``clients`` round-robin over the GPU groups in ``gpus``.

        Returns:
            One client list per GPU group; the i-th client goes to group ``i % len(gpus)``.
        """
        groups = [[] for _ in gpus]
        for index, client in enumerate(clients):
            groups[index % len(gpus)].append(client)
        return groups

    def create_clients(self):
        """Create one simulated FL client per client name and mark them all as started."""
        self.logger.info("Create the simulate clients.")
        for client_name in self.client_names:
            self.create_client(client_name)
        self.logger.info("Set the client status ready.")
        self._set_client_status()

    def create_client(self, client_name):
        """Create one FL client through the deployer and keep its config/deploy args/build context."""
        client, self.client_config, self.deploy_args, self.build_ctx = self.deployer.create_fl_client(
            client_name, self.args
        )
        self.federated_clients.append(client)

    def _set_client_status(self):
        """Attach app root and args to every client and mark it STARTED and not running."""
        for client in self.federated_clients:
            client.app_client_root = get_simulator_app_root(self.simulator_root, client.client_name)
            client.args = self.args
            client.simulate_running = False
            client.status = ClientStatus.STARTED

    def run(self):
        """Run the simulation in a child process and return its exit code.

        Returns:
            The code chosen by ``_get_return_code``, or -9 on KeyboardInterrupt.
        """
        try:
            # The Manager runs its own server process; the `with` block shuts it down afterwards.
            with Manager() as manager:
                return_dict = manager.dict()
                process = Process(target=self.run_process, args=(return_dict,))
                process.start()
                process.join()
                return self._get_return_code(return_dict, process, self.workspace)
        except KeyboardInterrupt:
            self.logger.info("KeyboardInterrupt, terminate all the child processes.")
            kill_child_processes(os.getpid())
            return -9

    def _get_return_code(self, return_dict, process, workspace):
        """Choose the simulation's return code.

        A truthy "run_status" in ``return_dict`` wins; otherwise the code is read from the process
        rc file in ``workspace`` (removed after reading), falling back to ``process.exitcode``.
        """
        return_code = return_dict.get("run_status")
        if return_code:
            self.logger.info(f"process run_status: {return_code}")
            return return_code

        rc_file = os.path.join(workspace, FLMetaKey.PROCESS_RC_FILE)
        if not os.path.exists(rc_file):
            return_code = process.exitcode
            self.logger.info(f"return_code from process.exitcode: {return_code}")
            return return_code

        try:
            with open(rc_file, "r") as f:
                return_code = int(f.readline())
            os.remove(rc_file)
            self.logger.info(f"return_code from process_rc_file: {return_code}")
        except Exception:
            self.logger.warning(
                f"Could not get the return_code from {rc_file}, Return the RC from the process:{process.pid}"
            )
            return_code = process.exitcode
        return return_code

    def run_process(self, return_dict):
        """Child-process entry point: run ``simulator_run_main`` under the main process monitor.

        Stores the resulting status in ``return_dict["run_status"]``.
        """
        try:
            run_status = mpm.run(
                main_func=self.simulator_run_main, run_dir=self.workspace, shutdown_grace_time=3, cleanup_grace_time=6
            )
        except Exception as e:
            self.logger.error(f"Simulator main run with exception: {secure_format_exception(e)}")
            run_status = ProcessExitCode.EXCEPTION
        return_dict["run_status"] = run_status

    def simulator_run_main(self):
        """Set up, start the server app, run all client groups, then abort and join the server.

        Returns:
            0 on success, 1 if ``setup()`` failed, 2 if the run raised an exception.
        """
        if not self.setup():
            return 1

        try:
            self.create_clients()
            self.server.engine.run_processes[SimulatorConstants.JOB_NAME] = {
                RunProcessKey.LISTEN_PORT: None,
                RunProcessKey.CONNECTION: None,
                RunProcessKey.CHILD_PROCESS: None,
                RunProcessKey.JOB_ID: SimulatorConstants.JOB_NAME,
                RunProcessKey.PARTICIPANTS: self.server.engine.client_manager.clients,
            }

            self.logger.info("Deploy and start the Server App.")
            server_thread = threading.Thread(target=self.start_server_app, args=[copy.deepcopy(self.args)])
            server_thread.start()

            # Wait until the server app reports STARTED; give up if its thread died first.
            while self.server.engine.engine_info.status != MachineStatus.STARTED:
                time.sleep(1.0)
                if not server_thread.is_alive():
                    raise RuntimeError("Could not start the Server App.")

            if self.args.gpu:
                gpus = split_gpus(self.args.gpu)
                client_groups = self.split_clients(self.federated_clients, gpus)
            else:
                gpus = [None]
                client_groups = [self.federated_clients]

            # One client runner per GPU group; leaving the `with` block waits for all of them.
            with ThreadPoolExecutor(max_workers=len(gpus)) as executor:
                for group, gpu in zip(client_groups, gpus):
                    executor.submit(self.client_run, group, gpu)

            # Abort the server after all clients finished run
            self.server.abort_run()
            server_thread.join()
            run_status = 0
        except Exception as e:
            self.logger.error(f"Simulator run error: {secure_format_exception(e)}")
            run_status = 2
        finally:
            self.deployer.close()
        return run_status

    def client_run(self, clients, gpu):
        """Run one group of clients on the given GPU group (or None) with a SimulatorClientRunner."""
        client_runner = SimulatorClientRunner(self.args, clients, self.client_config, self.deploy_args, self.build_ctx)
        client_runner.run(gpu)

    def start_server_app(self, args):
        """Run the server app from ``<simulator_root>/server``, then dump the stats and close the server.

        Side effects: changes the process working directory and appends the server app's
        ``custom`` folder to ``sys.path``.
        """
        server_workspace = os.path.join(self.simulator_root, "server")
        app_server_root = os.path.join(server_workspace, SimulatorConstants.JOB_NAME, "app_server")
        args.workspace = server_workspace
        os.chdir(args.workspace)
        add_logfile_handler(os.path.join(server_workspace, WorkspaceConstants.LOG_FILE_NAME))
        args.server_config = os.path.join("config", JobConstants.SERVER_JOB_CONFIG)
        # Make the server app's custom code importable.
        sys.path.append(os.path.join(app_server_root, "custom"))

        os.makedirs(os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME), exist_ok=True)
        os.makedirs(os.path.join(args.workspace, WorkspaceConstants.SITE_FOLDER_NAME), exist_ok=True)
        workspace = Workspace(root_dir=args.workspace, site_name="server")

        root_cell = self.server.get_cell()
        self.server.job_cell = self.server.create_job_cell(
            SimulatorConstants.JOB_NAME,
            root_cell.get_root_url_for_child(),
            root_cell.get_internal_listener_url(),
            False,
            None,
        )
        server_app_runner = SimulatorServerAppRunner(self.server)
        snapshot = None
        kv_list = [f"secure_train={self.server.secure_train}"]
        server_app_runner.start_server_app(
            workspace, args, app_server_root, args.job_id, snapshot, self.logger, kv_list=kv_list
        )

        self.dump_stats(workspace)
        self.server.admin_server.stop()
        self.server.close()

    def dump_stats(self, workspace: Workspace):
        """Write the StatsPoolManager stats as JSON to ``<root>/pool_stats/simulator_cell_stats.json``."""
        json_object = json.dumps(StatsPoolManager.to_dict(), indent=4)
        stats_dir = os.path.join(workspace.get_root_dir(), POOL_STATS_DIR)
        # exist_ok: do not crash when the stats folder is already present.
        os.makedirs(stats_dir, exist_ok=True)
        with open(os.path.join(stats_dir, SIMULATOR_POOL_STATS), "w") as outfile:
            outfile.write(json_object)
class SimulatorClientRunner(FLComponent):
    """Run a group of simulated clients, each task in a separate ``simulator_worker`` subprocess.

    ``args.threads`` threads share the client group. Each thread repeatedly picks a client that is
    not currently running, launches a worker process for it and exchanges control messages with
    that process over a ``multiprocessing.connection`` connection.
    """

    def __init__(self, args, clients: list, client_config, deploy_args, build_ctx):
        """
        Args:
            args: simulator arguments (uses ``workspace``, ``threads``, ``end_run_for_all``, ``set``).
            clients: the simulated clients of this group.
            client_config: client config sent to each worker.
            deploy_args: deploy args sent to each worker (workspace adjusted per client).
            build_ctx: build context sent to each worker (with the client's name added).
        """
        super().__init__()
        self.args = args
        self.federated_clients = clients
        self.run_client_index = -1  # round-robin cursor used by get_next_run_client
        self.simulator_root = self.args.workspace
        self.client_config = client_config
        self.deploy_args = deploy_args
        self.build_ctx = build_ctx
        self.kv_list = parse_vars(args.set)
        self.logging_config = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG)
        self.clients_finished_end_run = []

    def run(self, gpu):
        """Run all clients of this group with ``args.threads`` threads, then shut the clients down.

        Args:
            gpu: GPU group passed to the workers as ``--gpu``, or None.
        """
        try:
            self.logger.info("Start the clients run simulation.")
            lock = threading.Lock()
            timeout = self.kv_list.get("simulator_worker_timeout", 60.0)
            # Leaving the `with` block waits for all client threads to finish.
            with ThreadPoolExecutor(max_workers=self.args.threads) as executor:
                for _ in range(self.args.threads):
                    executor.submit(
                        self.run_client_thread, self.args.threads, gpu, lock, self.args.end_run_for_all, timeout
                    )
        except Exception as e:
            self.logger.error(f"SimulatorClientRunner run error: {secure_format_exception(e)}")
        finally:
            for client in self.federated_clients:
                threading.Thread(target=self._shutdown_client, args=[client]).start()

    def _shutdown_client(self, client):
        """Best-effort shutdown of one simulated client; failures are logged, not raised."""
        try:
            client.communicator.heartbeat_done = True
            client.terminate()
            client.status = ClientStatus.STOPPED
            client.communicator.cell.stop()
        except Exception as e:
            # Ignore the exception for the simulator client shutdown
            self.logger.warning(
                f"Exception happened to client {client.client_name} during shutdown: {secure_format_exception(e)}"
            )

    def run_client_thread(self, num_of_threads, gpu, lock, end_run_for_all, timeout=60):
        """Thread loop: run one client's task process at a time until a worker reports stop.

        Args:
            num_of_threads: number of threads sharing this client group.
            gpu: GPU group for the worker processes, or None.
            lock: lock guarding client selection and ``clients_finished_end_run``.
            end_run_for_all: afterwards, also run END_RUN on clients that have not run it yet.
            timeout: seconds to wait for a worker process to accept the connection.
        """
        stop_run = False
        interval = 1
        client_to_run = None  # the client chosen by the previous task, if any
        try:
            while not stop_run:
                time.sleep(interval)
                with lock:
                    # The client chosen by the previous task may have been taken by another thread
                    # during the sleep; pick a free one in that case. Marking it running under the
                    # same lock keeps two threads from choosing the same client.
                    if not client_to_run or client_to_run.simulate_running:
                        client = self.get_next_run_client(gpu)
                    else:
                        client = client_to_run
                    client.simulate_running = True

                stop_run, client_to_run, end_run_client = self.do_one_task(
                    client, num_of_threads, gpu, lock, timeout=timeout
                )
                if end_run_client:
                    with lock:
                        self.clients_finished_end_run.append(end_run_client)
                client.simulate_running = False

            if end_run_for_all:
                self._end_run_clients(gpu, lock, num_of_threads, timeout)
        except Exception as e:
            self.logger.error(f"run_client_thread error: {secure_format_exception(e)}")

    def _all_clients_end_run(self):
        """Return True once every client's name is in ``clients_finished_end_run``."""
        return all(c.client_name in self.clients_finished_end_run for c in self.federated_clients)

    def _end_run_clients(self, gpu, lock, num_of_threads, timeout):
        """After the workflow reaches END_RUN, run the END_RUN task for every client that has not run it.

        Only used when "end_run_for_all" is set. All client threads pick clients from the same pool;
        each thread keeps picking until every client has been handled.
        """
        while True:
            with lock:
                if self._all_clients_end_run():
                    break
                end_run_client = self._pick_next_client()
            if end_run_client:
                self.do_one_task(
                    end_run_client, num_of_threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN
                )
                with lock:
                    end_run_client.simulate_running = False
            else:
                # The remaining clients are still busy in other threads; do not spin on the lock.
                time.sleep(0.1)

    def _pick_next_client(self):
        """Claim a client that has not run END_RUN and is not running; call with the lock held.

        Returns:
            The claimed client (marked running and recorded as finished), or None.
        """
        for client in self.federated_clients:
            if client.client_name not in self.clients_finished_end_run and not client.simulate_running:
                client.simulate_running = True
                self.clients_finished_end_run.append(client.client_name)
                return client
        return None

    def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC):
        """Launch a ``simulator_worker`` process for ``client`` and serve it until it should stop.

        After each message from the worker, the worker is told to continue (True) while the run is
        not stopped and the next selected client is the same client; otherwise it is told to stop.

        Returns:
            (stop_run, next_client, end_run_client) as received from and decided for the worker.
        """
        open_port = get_open_ports(1)[0]
        client_workspace = os.path.join(self.args.workspace, client.client_name)
        logging_config = os.path.join(client_workspace, "local", WorkspaceConstants.LOGGING_CONFIG)

        # Build argv as a list: no re-splitting, so paths with spaces or '#' are passed intact.
        command = [
            sys.executable,
            "-m",
            "nvflare.private.fed.app.simulator.simulator_worker",
            "-o",
            client_workspace,
            "--logging_config",
            logging_config,
            "--client",
            client.client_name,
            "--token",
            client.token,
            "--port",
            str(open_port),
            "--parent_pid",
            str(os.getpid()),
            "--simulator_root",
            self.simulator_root,
            "--root_url",
            str(client.cell.get_root_url_for_child()),
            "--parent_url",
            str(client.cell.get_internal_listener_url()),
            "--task_name",
            str(task_name),
        ]
        if gpu:
            command.extend(["--gpu", str(gpu)])

        new_env = os.environ.copy()
        new_env["PYTHONPATH"] = os.pathsep.join(self._get_new_sys_path())
        # start_new_session=True performs setsid() in the child like preexec_fn=os.setsid,
        # but is safe to use from multiple threads.
        subprocess.Popen(command, start_new_session=True, env=new_env)

        conn = self._create_connection(open_port, timeout=timeout)
        try:
            # Per-call copy: the build context is shared between threads, so mutating it in place
            # could send another client's name to this worker.
            build_ctx = copy.copy(self.build_ctx)
            build_ctx["client_name"] = client.client_name
            deploy_args = copy.deepcopy(self.deploy_args)
            deploy_args.workspace = os.path.join(deploy_args.workspace, client.client_name)
            conn.send(
                {
                    SimulatorConstants.CLIENT_CONFIG: self.client_config,
                    SimulatorConstants.DEPLOY_ARGS: deploy_args,
                    SimulatorConstants.BUILD_CTX: build_ctx,
                }
            )

            end_run_client = None
            while True:
                stop_run = conn.recv()
                if stop_run:
                    end_run_client = conn.recv()
                with lock:
                    if num_of_threads != len(self.federated_clients):
                        next_client = self.get_next_run_client(gpu)
                    else:
                        next_client = client
                if not stop_run and next_client.client_name == client.client_name:
                    conn.send(True)
                else:
                    conn.send(False)
                    break
        finally:
            conn.close()
        return stop_run, next_client, end_run_client

    def _get_new_sys_path(self):
        """Return the non-empty entries of ``sys.path`` except the last one, for the worker's PYTHONPATH."""
        # NOTE(review): the last sys.path entry is intentionally(?) excluded; SimulatorRunner.start_server_app
        # appends the server app's custom folder to sys.path, which is presumably what this drops — confirm.
        return [path for path in sys.path[:-1] if path]

    def _create_connection(self, open_port, timeout=60.0):
        """Connect to the worker listening on localhost:``open_port``, retrying every second.

        Raises:
            RuntimeError: if no connection could be made within ``timeout`` seconds.
        """
        address = ("localhost", open_port)
        authkey = CommunicationMetaData.CHILD_PASSWORD.encode()
        start = time.time()
        while True:
            try:
                return Client(address, authkey=authkey)
            except Exception:
                if time.time() - start > timeout:
                    raise RuntimeError(
                        f"Failed to create connection to the child process in {self.__class__.__name__},"
                        f" timeout: {timeout}"
                    )
                time.sleep(1.0)

    def get_next_run_client(self, gpu):
        """Advance the round-robin cursor to the next client not marked running and return it.

        Loops until such a client is found; callers in this class hold the lock while calling it.
        """
        while True:
            self.run_client_index = (self.run_client_index + 1) % len(self.federated_clients)
            client = self.federated_clients[self.run_client_index]
            if not client.simulate_running:
                break
        self.logger.info(f"Simulate Run client: {client.client_name} on GPU group: {gpu}")
        return client