# Source code for nvflare.private.fed.client.client_engine

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
import shutil
import sys
import threading

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, MachineStatus, SystemComponents, WorkspaceConstants
from nvflare.apis.fl_context import FLContext, FLContextManager
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.network_utils import get_open_ports
from nvflare.private.defs import ERROR_MSG_PREFIX, ClientStatusKey, EngineConstant
from nvflare.private.event import fire_event
from nvflare.private.fed.server.job_meta_validator import JobMetaValidator
from nvflare.private.fed.utils.app_deployer import AppDeployer
from nvflare.private.fed.utils.fed_utils import security_close
from nvflare.security.logging import secure_format_exception, secure_log_traceback

from .client_engine_internal_spec import ClientEngineInternalSpec
from .client_executor import ProcessExecutor
from .client_run_manager import ClientRunInfo
from .client_status import ClientStatus
from .fed_client import FederatedClient


def _remove_custom_path():
    """Drop every ``sys.path`` entry that matches ``.*/run_.*/custom``.

    ``sys.path`` is filtered in place, so the list object itself is kept and
    all entries that do not match stay in their original order.
    """
    custom_dir_pattern = re.compile(".*/run_.*/custom")
    sys.path[:] = [entry for entry in sys.path if not custom_dir_pattern.search(entry)]


class ClientEngine(ClientEngineInternalSpec):
    """ClientEngine runs in the client parent process.

    It owns the FL context manager for the parent process and delegates
    job-level operations (start, abort, status, errors) to a ``ProcessExecutor``.
    """

    def __init__(self, client: FederatedClient, args, rank, workers=5):
        """To init the ClientEngine.

        Args:
            client: FL client object
            args: command args; ``args.workspace`` is used as the workspace root
            rank: local process rank
            workers: number of workers; must be >= 1

        Raises:
            ValueError: if ``workers`` is less than 1.
        """
        # Fail fast: reject a bad argument before any collaborator is constructed.
        if workers < 1:
            raise ValueError("workers must >= 1")

        super().__init__()

        self.client = client
        self.client_name = client.client_name
        self.args = args
        self.rank = rank
        self.client_executor = ProcessExecutor(client, os.path.join(args.workspace, "startup"))
        self.admin_agent = None

        self.fl_ctx_mgr = FLContextManager(
            engine=self,
            identity_name=self.client_name,
            job_id="",
            public_stickers={},
            private_stickers={
                SystemComponents.DEFAULT_APP_DEPLOYER: AppDeployer(),
                SystemComponents.JOB_META_VALIDATOR: JobMetaValidator(),
                SystemComponents.FED_CLIENT: client,
                FLContextKey.SECURE_MODE: self.client.secure_train,
                FLContextKey.WORKSPACE_ROOT: args.workspace,
            },
        )

        self.status = MachineStatus.STOPPED

        self.logger = logging.getLogger(self.__class__.__name__)
        # Only components that are FLComponent instances receive engine events.
        self.fl_components = [x for x in self.client.components.values() if isinstance(x, FLComponent)]

    def _job_folder(self, job_id) -> str:
        """Return the path of the run folder for ``job_id`` inside the workspace."""
        return os.path.join(self.args.workspace, WorkspaceConstants.WORKSPACE_PREFIX + str(job_id))

    def fire_event(self, event_type: str, fl_ctx: FLContext):
        """Fire ``event_type`` to the client's FL components using ``fl_ctx``."""
        fire_event(event=event_type, handlers=self.fl_components, ctx=fl_ctx)

    def set_agent(self, admin_agent):
        """Remember the admin agent associated with this engine."""
        self.admin_agent = admin_agent

    def new_context(self) -> FLContext:
        """Create a new FLContext from the engine's context manager."""
        return self.fl_ctx_mgr.new_context()

    def get_component(self, component_id: str) -> object:
        """Look up a client component by id; returns None when it is not configured."""
        return self.client.components.get(component_id)

    def get_engine_status(self):
        """Report the client name and the state of every known job.

        Returns:
            A dict with the client name and a list of job entries. Each entry
            holds the app name (first line of ``fl_app.txt`` in the job's run
            folder, or ``""`` when that file is absent), the job id and the
            status reported by the executor.
        """
        running_jobs = []
        for job_id in self.get_all_job_ids():
            app_name = ""
            app_file = os.path.join(self._job_folder(job_id), "fl_app.txt")
            if os.path.exists(app_file):
                with open(app_file, "r") as f:
                    app_name = f.readline().strip()

            running_jobs.append(
                {
                    ClientStatusKey.APP_NAME: app_name,
                    ClientStatusKey.JOB_ID: job_id,
                    ClientStatusKey.STATUS: self.client_executor.check_status(job_id),
                }
            )

        return {
            ClientStatusKey.CLIENT_NAME: self.client.client_name,
            ClientStatusKey.RUNNING_JOBS: running_jobs,
        }

    def start_app(
        self,
        job_id: str,
        allocated_resource: dict = None,
        token: str = None,
        resource_manager=None,
    ) -> str:
        """Start the deployed client app for ``job_id`` through the executor.

        Args:
            job_id: id of the job whose app is started
            allocated_resource: passed through to the executor
            token: passed through to the executor
            resource_manager: passed through to the executor

        Returns:
            A human-readable message. It is prefixed with ``ERROR_MSG_PREFIX``
            when the app folder for this client does not exist.
        """
        status = self.client_executor.get_status(job_id)
        if status == ClientStatus.STARTED:
            return "Client app already started."

        app_root = os.path.join(
            self._job_folder(job_id),
            WorkspaceConstants.APP_PREFIX + self.client.client_name,
        )
        if not os.path.exists(app_root):
            return f"{ERROR_MSG_PREFIX}: Client app does not exist. Please deploy it before starting client."

        app_custom_folder = os.path.join(app_root, "custom")
        # Make the app's custom code importable. Custom folders of other runs are
        # dropped from sys.path first so only this run's folder is present.
        if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path:
            _remove_custom_path()
            sys.path.append(app_custom_folder)

        self.logger.info("Starting client app. rank: %s", self.rank)

        open_port = get_open_ports(1)[0]

        # The first configured server supplies the connection target and scheme.
        server_config = list(self.client.servers.values())[0]
        self.client_executor.start_app(
            self.client,
            job_id,
            self.args,
            app_custom_folder,
            open_port,
            allocated_resource,
            token,
            resource_manager,
            target=server_config["target"],
            scheme=server_config.get("scheme", "grpc"),
        )

        return "Start the client app..."

    def notify_job_status(self, job_id: str, job_status):
        """Forward a job status notification to the executor."""
        self.client_executor.notify_job_status(job_id, job_status)

    def get_client_name(self):
        """Return the name of the FL client."""
        return self.client.client_name

    def _write_token_file(self, job_id, open_port):
        """Write the client token file into the workspace.

        The file holds one value per line, in this order: client token, ssid,
        job id, client name, open port and the first server's target.
        """
        token_file = os.path.join(self.args.workspace, EngineConstant.CLIENT_TOKEN_FILE)
        # NOTE(review): mode "wt" already truncates; the explicit removal is kept so
        # the file is recreated rather than rewritten - confirm whether that matters.
        if os.path.exists(token_file):
            os.remove(token_file)

        fields = (
            self.client.token,
            self.client.ssid,
            job_id,
            self.client.client_name,
            open_port,
            list(self.client.servers.values())[0]["target"],
        )
        with open(token_file, "wt") as f:
            f.write("".join(f"{value!s}\n" for value in fields))

    def abort_app(self, job_id: str) -> str:
        """Ask the executor to abort the app of ``job_id`` if it is running.

        Returns:
            A message describing either why nothing was done (stopped, not
            started, still starting) or that the abort signal was sent.
        """
        status = self.client_executor.get_status(job_id)
        if status == ClientStatus.STOPPED:
            return "Client app already stopped."

        if status == ClientStatus.NOT_STARTED:
            return "Client app has not started."

        if status == ClientStatus.STARTING:
            return "Client app is starting, please wait for client to have started before abort."

        self.client_executor.abort_app(job_id)

        return "Abort signal has been sent to the client App."

    def abort_task(self, job_id: str) -> str:
        """Ask the executor to abort the current task of ``job_id``.

        Returns:
            A message describing either why nothing was done (not started,
            still starting) or that the abort signal was sent.
        """
        status = self.client_executor.get_status(job_id)
        if status == ClientStatus.NOT_STARTED:
            return "Client app has not started."

        if status == ClientStatus.STARTING:
            return "Client app is starting, please wait for started before abort_task."

        self.client_executor.abort_task(job_id)

        return "Abort signal has been sent to the current task."

    def _stop_client(self, marker_file_name: str) -> None:
        """Fire SYSTEM_END, then run ``shutdown_client`` on a background thread.

        ``marker_file_name`` is the name of the file, inside the workspace, that
        ``shutdown_client`` is given to touch.
        """
        touch_file = os.path.join(self.args.workspace, marker_file_name)
        self.fire_event(EventType.SYSTEM_END, self.new_context())
        thread = threading.Thread(target=shutdown_client, args=(self.client, touch_file))
        thread.start()

    def shutdown(self) -> str:
        """Shut the client down in the background, using the ``shutdown.fl`` marker file."""
        self.logger.info("Client shutdown...")
        self._stop_client("shutdown.fl")
        return "Shutdown the client..."

    def restart(self) -> str:
        """Shut the client down in the background, using the ``restart.fl`` marker file."""
        # Log "restart" (not "shutdown") so the two operations can be told apart in logs.
        self.logger.info("Client restart...")
        self._stop_client("restart.fl")
        return "Restart the client..."

    def deploy_app(self, app_name: str, job_id: str, job_meta: dict, client_name: str, app_data) -> str:
        """Deploy an app into the workspace for ``client_name``.

        Uses the configured ``APP_DEPLOYER`` component when there is one,
        otherwise a default ``AppDeployer``.

        Returns:
            An empty string on success, or the deployer's error prefixed with
            ``ERROR_MSG_PREFIX``.
        """
        workspace = Workspace(root_dir=self.args.workspace, site_name=client_name)
        app_deployer = self.get_component(SystemComponents.APP_DEPLOYER)
        if not app_deployer:
            # use default deployer
            app_deployer = AppDeployer()

        err = app_deployer.deploy(
            workspace=workspace,
            job_id=job_id,
            job_meta=job_meta,
            app_name=app_name,
            app_data=app_data,
            fl_ctx=self.new_context(),
        )
        if err:
            return f"{ERROR_MSG_PREFIX}: {err}"

        return ""

    def delete_run(self, job_id: str) -> str:
        """Delete the run folder of ``job_id`` if it exists and return a message naming it."""
        job_id_folder = self._job_folder(job_id)
        if os.path.exists(job_id_folder):
            shutil.rmtree(job_id_folder)
        return f"Delete run folder: {job_id_folder}."

    def get_current_run_info(self, job_id) -> ClientRunInfo:
        """Return the executor's run info for ``job_id``."""
        return self.client_executor.get_run_info(job_id)

    def get_errors(self, job_id):
        """Return the errors the executor reports for ``job_id``."""
        return self.client_executor.get_errors(job_id)

    def reset_errors(self, job_id):
        """Ask the executor to reset the errors of ``job_id``."""
        self.client_executor.reset_errors(job_id)

    def get_all_job_ids(self):
        """Return the keys of the executor's run processes, used as job ids."""
        return self.client_executor.get_run_processes_keys()
def shutdown_client(federated_client, touch_file):
    """Touch ``touch_file`` and then close the federated client.

    Any exception raised while closing is logged through the secure logging
    helpers and printed; it is not propagated to the caller.

    Args:
        federated_client: the client to shut down
        touch_file: path of the file to create or refresh before shutting down
    """
    # "Touch" semantics: create the file if it is missing, then refresh its timestamps.
    with open(touch_file, "a"):
        os.utime(touch_file)

    try:
        print("About to shutdown the client...")
        federated_client.communicator.heartbeat_done = True
        federated_client.close()
        federated_client.status = ClientStatus.STOPPED
        security_close()
    except Exception as shutdown_error:
        secure_log_traceback()
        print(f"Failed to shutdown client: {secure_format_exception(shutdown_error)}")