Source code for nvflare.private.fed.app.simulator.simulator_worker

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys
import threading
import time
from multiprocessing.connection import Listener

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, RunnerTask, WorkspaceConstants
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.hci.server.authz import AuthorizationService
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.utils.log_utils import dynamic_log_config
from nvflare.private.fed.app.deployer.base_client_deployer import BaseClientDeployer
from nvflare.private.fed.app.utils import check_parent_alive, init_security_content_service
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.client.fed_client import FederatedClient
from nvflare.private.fed.simulator.simulator_app_runner import SimulatorClientAppRunner
from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor
from nvflare.private.fed.simulator.simulator_const import SimulatorConstants
from nvflare.private.fed.utils.fed_utils import fobs_initialize, get_simulator_app_root, register_ext_decomposers
from nvflare.security.logging import secure_format_exception, secure_log_traceback
from nvflare.security.security import EmptyAuthorizer

CELL_CONNECT_CHECK_TIMEOUT = 10.0
FETCH_TASK_RUN_RETRY = 3


class ClientTaskWorker(FLComponent):
    """Worker that runs simulated client tasks inside a simulator child process.

    The worker receives its deployment payload over a connection, builds a
    FederatedClient with its own cell, and then repeatedly runs one task at a
    time, reporting back over the connection after every task.
    """

    def create_client_engine(self, federated_client: FederatedClient, args, rank=0):
        """Create a ClientEngine for the client and announce the system start.

        Args:
            federated_client: the client the engine is attached to
            args: arguments passed through to the ClientEngine constructor
            rank: rank passed through to the ClientEngine constructor
        """
        client_engine = ClientEngine(federated_client, args, rank)
        federated_client.set_client_engine(client_engine)
        # do_one_task() treats run_manager None as "runner not created yet".
        federated_client.run_manager = None

        client_engine.fire_event(EventType.SYSTEM_START, client_engine.new_context())

    def create_client_runner(self, client):
        """Create the ClientRunner for the client to run the ClientApp.

        Args:
            client: the client to run
        """
        app_client_root = client.app_client_root
        args = client.args
        args.client_name = client.client_name
        args.token = client.token

        client_app_runner = SimulatorClientAppRunner()
        client_app_runner.client_runner = client_app_runner.create_client_runner(
            app_client_root, args, args.config_folder, client, False
        )
        client_runner = client_app_runner.client_runner
        with client_runner.engine.new_context() as fl_ctx:
            client_app_runner.start_command_agent(args, client, fl_ctx)
            client_app_runner.sync_up_parents_process(client)
            client_runner.engine.cell = client.cell
            client_runner.init_run(app_client_root, args)

    def do_one_task(self, client, args):
        """Run a single task (or the END_RUN sequence) for the client.

        The ClientRunner is created on first use. Fetching a task is retried up
        to FETCH_TASK_RUN_RETRY times when no task was processed.

        Args:
            client: the client to run the task for
            args: worker arguments; ``args.task_name`` selects the END_RUN sequence

        Returns:
            A tuple ``(interval, stop_run, end_run_client)``: the interval
            reported by the last fetch (1.0 if none happened), whether the run
            must stop, and the name of the client whose abort signal was
            triggered (None otherwise).
        """
        interval = 1.0
        stop_run = False
        end_run_client = None
        # Create the ClientRunManager and ClientRunner for the new client to run
        try:
            if client.run_manager is None:
                self.create_client_runner(client)
                self.logger.info(f"Initialize ClientRunner for client: {client.client_name}")

            with client.run_manager.new_context() as fl_ctx:
                client_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
                self.fire_event(EventType.SWAP_IN, fl_ctx)

                if args.task_name == RunnerTask.END_RUN:
                    client_runner.end_run_events_sequence()
                    self.logger.info("Simulator END_RUN sequence.")
                    stop_run = True
                else:
                    run_task_tries = 0
                    while True:
                        interval, task_processed = client_runner.fetch_and_run_one_task(fl_ctx)
                        if task_processed:
                            self.logger.info(
                                f"Finished one task run for client: {client.client_name} "
                                f"interval: {interval} task_processed: {task_processed}"
                            )

                        # if any client got the END_RUN event, stop the simulator run.
                        if client_runner.run_abort_signal.triggered:
                            client_runner.end_run_events_sequence()
                            end_run_client = client.client_name
                            stop_run = True
                            self.logger.info("End the Simulator run.")
                            break

                        if task_processed:
                            break

                        # Nothing was processed: retry a bounded number of times.
                        run_task_tries += 1
                        if run_task_tries >= FETCH_TASK_RUN_RETRY:
                            break
                        time.sleep(0.5)
        except Exception as e:
            self.logger.error(f"do_one_task execute exception: {secure_format_exception(e)}")
            secure_log_traceback()
            # Any failure while running a task ends the run.
            stop_run = True

        return interval, stop_run, end_run_client

    def release_resources(self, client):
        """Fire SWAP_OUT for the client and clear the runner from the context.

        Args:
            client: the client whose runner is being released
        """
        if client.run_manager:
            with client.run_manager.new_context() as fl_ctx:
                self.fire_event(EventType.SWAP_OUT, fl_ctx)

                fl_ctx.set_prop(FLContextKey.RUNNER, None, private=True)
        self.logger.info(f"Clean up ClientRunner for : {client.client_name} ")

    def run(self, args, conn):
        """Serve task requests over the connection until told to stop.

        Protocol on ``conn``: first receive a payload keyed by
        SimulatorConstants.CLIENT_CONFIG / DEPLOY_ARGS / BUILD_CTX; then, after
        every task, send ``stop_run`` (followed by ``end_run_client`` when
        stopping) and receive a flag telling whether to continue.

        Errors are logged (with traceback) and not re-raised; the client cell is
        always stopped on exit if the client was created.

        Args:
            args: worker arguments (token, simulator_root, root_url, parent_url, task_name)
            conn: connection object providing ``recv()`` and ``send()``
        """
        self.logger.info("ClientTaskWorker started to run")
        client = None
        try:
            data = conn.recv()
            # client_config is unused below; the lookup still fails fast with
            # KeyError when the key is missing from the payload.
            client_config = data[SimulatorConstants.CLIENT_CONFIG]  # noqa: F841
            deploy_args = data[SimulatorConstants.DEPLOY_ARGS]
            build_ctx = data[SimulatorConstants.BUILD_CTX]

            client = self._create_client(args, build_ctx, deploy_args)

            # Make the app's "custom" folder importable.
            app_root = get_simulator_app_root(args.simulator_root, client.client_name)
            app_custom_folder = os.path.join(app_root, "custom")
            if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path:
                sys.path.append(app_custom_folder)

            self.create_client_engine(client, deploy_args)

            while True:
                interval, stop_run, end_run_client = self.do_one_task(client, args)
                conn.send(stop_run)
                if stop_run:
                    conn.send(end_run_client)

                continue_run = conn.recv()
                if not continue_run:
                    self.release_resources(client)
                    break
                time.sleep(interval)

        except Exception as e:
            self.logger.error(f"ClientTaskWorker run error: {secure_format_exception(e)}")
            # Same diagnostics as do_one_task(): keep the traceback for debugging.
            secure_log_traceback()
        finally:
            if client:
                client.cell.stop()

    def _create_client(self, args, build_ctx, deploy_args):
        """Build the FederatedClient, set its status fields and connect its cell.

        Args:
            args: worker arguments (token, simulator_root, root_url, parent_url)
            build_ctx: build context handed to BaseClientDeployer.build()
            deploy_args: arguments handed to create_fed_client()

        Returns:
            The created client, with its cell started and connected.
        """
        deployer = BaseClientDeployer()
        deployer.build(build_ctx)
        client = deployer.create_fed_client(deploy_args)

        client.token = args.token
        self._set_client_status(client, deploy_args, args.simulator_root)
        start = time.time()
        self._create_client_cell(client, args.root_url, args.parent_url)
        self.logger.debug(f"Complete _create_client_cell. Time to create client job cell: {time.time() - start}")
        return client

    def _set_client_status(self, client, deploy_args, simulator_root):
        """Record the app root, deploy args and initial status on the client.

        The ClientRunner is not created here; do_one_task() creates it on first use.
        """
        app_client_root = get_simulator_app_root(simulator_root, client.client_name)
        client.app_client_root = app_client_root
        client.args = deploy_args
        client.simulate_running = False
        client.status = ClientStatus.STARTED

    def _create_client_cell(self, federated_client, root_url, parent_url):
        """Create and start the client's job cell and wait for the server connection.

        Args:
            federated_client: the client that owns the new cell
            root_url: cellnet root URL passed to the Cell
            parent_url: accepted for interface compatibility; see the note below

        Raises:
            RuntimeError: if the cell is not connected to FQCN.ROOT_SERVER within
                CELL_CONNECT_CHECK_TIMEOUT seconds.
        """
        fqcn = FQCN.join([federated_client.client_name, SimulatorConstants.JOB_NAME])
        credentials = {}
        # NOTE(review): the caller-supplied parent_url is discarded and the cell
        # is always created with parent_url=None - confirm this is intended.
        parent_url = None
        cell = Cell(
            fqcn=fqcn,
            root_url=root_url,
            secure=False,
            credentials=credentials,
            create_internal_listener=False,
            parent_url=parent_url,
        )
        cell.start()
        mpm.add_cleanup_cb(cell.stop)
        federated_client.cell = cell
        federated_client.communicator.set_cell(cell)
        federated_client.communicator.set_auth(
            client_name=federated_client.client_name,
            token=federated_client.token,
            token_signature="NA",
            ssid="NA",
        )

        # Poll until the cell reaches the root server, giving up after the timeout.
        start = time.time()
        while not cell.is_cell_connected(FQCN.ROOT_SERVER):
            time.sleep(0.1)
            if time.time() - start > CELL_CONNECT_CHECK_TIMEOUT:
                raise RuntimeError("Could not connect to the server cell.")
def _create_connection(listen_port): address = ("localhost", int(listen_port)) listener = Listener(address) conn = listener.accept() return conn
def main(args):
    """Set up the worker process environment and run the ClientTaskWorker.

    Starts the parent-process checking thread, prepares the workspace folders,
    logging, FOBS, authorization/audit and security content services, optionally
    pins the visible GPU(s), then accepts the connection on ``args.port`` and
    runs the task worker on it.

    Args:
        args: parsed command-line arguments (see parse_arguments()).
    """
    # start parent process checking thread
    parent_pid = args.parent_pid
    stop_event = threading.Event()
    thread = threading.Thread(target=check_parent_alive, args=(parent_pid, stop_event))
    thread.start()

    conn = None
    # Everything after the thread start is inside the try so that a failure
    # during setup still sets stop_event instead of leaving the checking thread
    # running.
    try:
        os.chdir(args.workspace)
        startup = os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME)
        os.makedirs(startup, exist_ok=True)
        local = os.path.join(args.workspace, WorkspaceConstants.SITE_FOLDER_NAME)
        os.makedirs(local, exist_ok=True)
        workspace = Workspace(root_dir=args.workspace, site_name=args.client)

        dynamic_log_config(
            config=args.logging_config, dir_path=args.workspace, reload_path=workspace.get_log_config_file_path()
        )

        fobs_initialize(workspace, job_id=SimulatorConstants.JOB_NAME)
        register_ext_decomposers(args.decomposer_module)
        AuthorizationService.initialize(EmptyAuthorizer())
        # The simulator installs its own auditor instead of calling AuditService.initialize().
        AuditService.the_auditor = SimulatorAuditor()
        init_security_content_service(args.workspace)

        if args.gpu:
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

        conn = _create_connection(args.port)

        task_worker = ClientTaskWorker()
        task_worker.run(args, conn)
    finally:
        stop_event.set()
        # conn is only set once setup (including the auditor) has completed, so
        # the connection and the audit service are closed together, as before.
        if conn is not None:
            conn.close()
            AuditService.close()
def parse_arguments():
    """Parse the command-line options of the simulator worker process.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (option strings, add_argument keyword arguments), in registration order.
    option_specs = (
        (("--workspace", "-o"), {"type": str, "help": "WORKSPACE folder", "required": True}),
        (("--logging_config",), {"type": str, "help": "logging config file", "required": True}),
        (("--client",), {"type": str, "help": "Client name", "required": True}),
        (("--token",), {"type": str, "help": "Client token", "required": True}),
        (("--port",), {"type": str, "help": "Listen port", "required": True}),
        (("--gpu", "-g"), {"type": str, "help": "gpu index number"}),
        (("--parent_pid",), {"type": int, "help": "parent process pid", "required": True}),
        (("--simulator_root", "-root"), {"type": str, "help": "Simulator root folder"}),
        (("--root_url", "-r"), {"type": str, "help": "cellnet root_url"}),
        (("--parent_url", "-p"), {"type": str, "help": "cellnet parent_url"}),
        (("--task_name",), {"type": str, "help": "end_run"}),
        (("--decomposer_module",), {"type": str, "help": "decomposer_module name", "required": True}),
    )

    parser = argparse.ArgumentParser()
    for option_strings, options in option_specs:
        parser.add_argument(*option_strings, **options)

    return parser.parse_args()
if __name__ == "__main__":
    # This is the main program of the simulator worker process when running
    # the NVFlare Simulator.
    args = parse_arguments()
    # main() is handed to the MainProcessMonitor, with the workspace as its run directory.
    mpm.run(main_func=main, run_dir=args.workspace, args=args)
    # Wait two seconds after the monitored run returns before the script ends.
    time.sleep(2)