# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import threading
import time
from abc import ABC, abstractmethod
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import AdminCommandNames, ConnPropKey, FLContextKey, RunProcessKey, SystemConfigs
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs
from nvflare.apis.resource_manager_spec import ResourceManagerSpec
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode
from nvflare.fuel.f3.cellnet.core_cell import FQCN
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message
from nvflare.private.fed.utils.fed_utils import get_job_launcher, get_return_code
from nvflare.security.logging import secure_format_exception, secure_log_traceback
from .client_status import ClientStatus, get_status_message
# [docs]
class ClientExecutor(ABC):
    """Abstract interface for starting, querying and aborting a client-side job app."""

    @abstractmethod
    def start_app(
        self,
        client,
        job_id,
        job_meta,
        args,
        allocated_resource,
        token,
        resource_manager,
        fl_ctx: FLContext,
    ):
        """Start the client app for a job.

        Args:
            client: the FL client object
            job_id: ID of the job to start
            job_meta: metadata of the job
            args: admin command arguments for starting the FL client training
            allocated_resource: resources allocated to the job
            token: token from the resource manager
            resource_manager: the resource manager
            fl_ctx: FLContext
        """

    @abstractmethod
    def check_status(self, job_id) -> str:
        """Check the status of the running client.

        Args:
            job_id: ID of the job

        Returns:
            A client status message.
        """

    @abstractmethod
    def abort_app(self, job_id):
        """Abort the running app.

        Args:
            job_id: ID of the job
        """

    @abstractmethod
    def abort_task(self, job_id):
        """Abort the task that the client is currently executing.

        Args:
            job_id: ID of the job
        """

    @abstractmethod
    def get_run_info(self, job_id):
        """Get the run information.

        Args:
            job_id: ID of the job

        Returns:
            A dict of run information.
        """

    @abstractmethod
    def get_errors(self, job_id):
        """Get the error information.

        Args:
            job_id: ID of the job

        Returns:
            A dict of error information.
        """

    @abstractmethod
    def reset_errors(self, job_id):
        """Reset the error information.

        Args:
            job_id: ID of the job
        """
# [docs]
class JobExecutor(ClientExecutor):
    """Client executor that runs each job's app in a child process started by a job launcher."""

    def __init__(self, client, startup):
        """Initializes the JobExecutor.

        Args:
            client: the FL client object; its cell is used to message the job processes
            startup: startup folder
        """
        self.client = client
        self.logger = get_obj_logger(self)
        self.startup = startup
        # job_id -> {RunProcessKey.JOB_HANDLE: handle, RunProcessKey.STATUS: status}
        self.run_processes = {}
        # taken when inserting/removing job entries, fetching a job handle, and listing job ids
        self.lock = threading.Lock()
        self.job_query_timeout = ConfigService.get_float_var(
            name="job_query_timeout", conf=SystemConfigs.APPLICATION_CONF, default=5.0
        )

    def start_app(
        self,
        client,
        job_id,
        job_meta,
        args,
        allocated_resource,
        token,
        resource_manager: ResourceManagerSpec,
        fl_ctx: FLContext,
    ):
        """Starts the app.

        Writes the job meta file, prepares the job process args, launches the job with the
        job launcher, and starts a thread that waits for the job process to finish.

        Args:
            client: the FL client object
            job_id: the job_id
            job_meta: job metadata
            args: admin command arguments for starting the worker process
            allocated_resource: allocated resources
            token: token from resource manager
            resource_manager: resource manager
            fl_ctx: FLContext

        Raises:
            RuntimeError: if the server config is missing from the FL context, or its
                "service" entry is not a dict.
        """
        # use a deep copy of the args for operation since its content will be changed!
        args = copy.deepcopy(args)

        # update the job meta: rewrite the meta file with the received meta
        workspace = Workspace(args.workspace, site_name=client.client_name)
        meta_file = workspace.get_job_meta_path(job_id)
        with open(meta_file, "w") as f:
            json.dump(job_meta, f, indent=4)

        job_launcher: JobLauncherSpec = get_job_launcher(job_meta, fl_ctx)

        # prepare command args for the job process
        workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
        server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG)
        if not server_config:
            raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context")
        service = server_config[0].get("service", {})
        if not isinstance(service, dict):
            raise RuntimeError(f"expect server config data to be dict but got {type(service)}")

        # every option is prefixed with a space, followed by the extra print_conf option
        command_options = "".join(" " + t for t in args.set) + " print_conf=True"
        args.set.append("print_conf=True")

        # Job process args are the same for all job launchers! Letting each job launcher compute the job
        # args would be error-prone and would require access to internal server components (e.g. cell).
        # We prepare job process args here and save the prepared result in the fl_ctx.
        # This way, the job launcher won't need to compute these args again.
        # The job launcher will only need to use the args properly to launch the job process!
        #
        # Each arg is a tuple of (arg_option, arg_value).
        # Note that the arg_option is fixed for each arg, and is not launcher specific!
        job_args = {
            JobProcessArgs.EXE_MODULE: ("-m", "nvflare.private.fed.app.client.worker_process"),
            JobProcessArgs.JOB_ID: ("-n", job_id),
            JobProcessArgs.CLIENT_NAME: ("-c", client.client_name),
            JobProcessArgs.AUTH_TOKEN: ("-t", client.token),
            JobProcessArgs.TOKEN_SIGNATURE: ("-ts", client.token_signature),
            JobProcessArgs.SSID: ("-d", client.ssid),
            JobProcessArgs.WORKSPACE: ("-m", args.workspace),
            JobProcessArgs.STARTUP_DIR: ("-w", workspace_obj.get_startup_kit_dir()),
            JobProcessArgs.PARENT_URL: ("-p", str(client.cell.get_internal_listener_url())),
            JobProcessArgs.SCHEME: ("-scheme", service.get("scheme", "grpc")),
            JobProcessArgs.TARGET: ("-g", service.get("target")),
            JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_client.json"),
            JobProcessArgs.OPTIONS: ("--set", command_options),
        }

        # pass the parent connection security along only when the internal listener defines it
        params = client.cell.get_internal_listener_params()
        if params:
            parent_conn_sec = params.get(ConnPropKey.CONNECTION_SECURITY)
            if parent_conn_sec:
                job_args[JobProcessArgs.PARENT_CONN_SEC] = ("-pcs", parent_conn_sec)

        fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False)
        job_handle = job_launcher.launch_job(job_meta, fl_ctx)
        self.logger.info(f"Launched job {job_id} with job launcher: {type(job_launcher)} ")

        fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False)
        engine = fl_ctx.get_engine()
        engine.fire_event(EventType.AFTER_JOB_LAUNCH, fl_ctx)

        client.multi_gpu = False

        with self.lock:
            self.run_processes[job_id] = {
                RunProcessKey.JOB_HANDLE: job_handle,
                RunProcessKey.STATUS: ClientStatus.STARTING,
            }

        # the waiter thread frees resources and fires JOB_COMPLETED once the job process ends
        thread = threading.Thread(
            target=self._wait_child_process_finish,
            args=(client, job_id, allocated_resource, token, resource_manager, args.workspace, fl_ctx),
        )
        thread.start()

    def _get_job_launcher(self, job_meta: dict, fl_ctx: FLContext) -> JobLauncherSpec:
        """Fires BEFORE_JOB_LAUNCH and returns the first job launcher found in the FL context.

        Args:
            job_meta: job metadata, stored in the FL context before the event is fired
            fl_ctx: FLContext

        Returns:
            The first element of the FLContextKey.JOB_LAUNCHER prop.

        Raises:
            RuntimeError: if the JOB_LAUNCHER prop is missing, empty, or not a list.
        """
        engine = fl_ctx.get_engine()
        fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False)
        engine.fire_event(EventType.BEFORE_JOB_LAUNCH, fl_ctx)
        job_launcher = fl_ctx.get_prop(FLContextKey.JOB_LAUNCHER)
        if not (job_launcher and isinstance(job_launcher, list)):
            raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.")
        return job_launcher[0]

    def notify_job_status(self, job_id, job_status):
        """Records a new status for a tracked job; does nothing if the job is not tracked.

        Args:
            job_id: the job_id
            job_status: the status to store for the job
        """
        run_process = self.run_processes.get(job_id)
        if run_process:
            run_process[RunProcessKey.STATUS] = job_status

    def _job_fqcn(self, job_id: str):
        """Returns the FQCN of the job's cell: the client cell's FQCN joined with the job id."""
        return FQCN.join([self.client.cell.get_fqcn(), job_id])

    def check_status(self, job_id):
        """Checks the status of the running client.

        Args:
            job_id: the job_id

        Returns:
            A client status message (based on NOT_STARTED when the job is not tracked),
            or an error text if an exception occurred.
        """
        try:
            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
            return get_status_message(process_status)
        except Exception as e:
            self.logger.error(f"check_status execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return "execution exception. Please try again."

    def get_run_info(self, job_id):
        """Gets the run information.

        Args:
            job_id: the job_id

        Returns:
            A dict of run information: the reply payload when the return code is OK, an empty
            dict for any other return code, or a dict with an "error" entry if an exception occurred.
        """
        try:
            request = new_cell_message({}, {})
            return_data = self.client.cell.send_request(
                target=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.SHOW_STATS,
                request=request,
                optional=True,
                timeout=self.job_query_timeout,
            )
            return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
            if return_code == ReturnCode.OK:
                return return_data.payload
            return {}
        except Exception as e:
            self.logger.error(f"get_run_info execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return {"error": "no info collector. Please try again."}

    def get_errors(self, job_id):
        """Get the error information.

        Args:
            job_id: the job_id

        Returns:
            The reply payload when the return code is OK; None for any other return code
            or if an exception occurred.
        """
        try:
            data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}}
            request = new_cell_message({}, data)
            return_data = self.client.cell.send_request(
                target=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.SHOW_ERRORS,
                request=request,
                optional=True,
                timeout=self.job_query_timeout,
            )
            return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
            if return_code == ReturnCode.OK:
                return return_data.payload
            return None
        except Exception as e:
            self.logger.error(f"get_errors execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return None

    def reset_errors(self, job_id):
        """Resets the error information.

        The request is sent fire-and-forget; exceptions are logged, not raised.

        Args:
            job_id: the job_id
        """
        try:
            data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}}
            request = new_cell_message({}, data)
            self.client.cell.fire_and_forget(
                targets=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.RESET_ERRORS,
                message=request,
                optional=True,
            )
        except Exception as e:
            self.logger.error(f"reset_errors execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()

    def abort_app(self, job_id):
        """Aborts the running app.

        Sends ABORT to the job cell and then terminates the job via its job handle. Nothing is
        sent unless the job's recorded status is STARTED. A failed attempt is retried once.

        Args:
            job_id: the job_id
        """
        # When the heartbeat cleanup process tries to abort the worker process, the job may have
        # already terminated. Retry once so a transient failure does not print an error stack trace.
        retry = 1
        while retry >= 0:
            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
            if process_status != ClientStatus.STARTED:
                self.logger.info(f"Client worker process for run: {job_id} was already terminated.")
                break

            try:
                with self.lock:
                    job_handle = self.run_processes[job_id][RunProcessKey.JOB_HANDLE]
                request = new_cell_message({}, {})
                self.client.cell.fire_and_forget(
                    targets=self._job_fqcn(job_id),
                    channel=CellChannel.CLIENT_COMMAND,
                    topic=AdminCommandNames.ABORT,
                    message=request,
                    optional=True,
                )
                self.logger.debug("abort sent to worker")
                # run in a thread and wait for it, so an error raised while terminating
                # stays in that thread and does not trigger the retry below
                t = threading.Thread(target=self._terminate_job, args=[job_handle, job_id])
                t.start()
                t.join()
                break
            except Exception as e:
                if retry == 0:
                    self.logger.error(
                        f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}."
                    )
                    secure_log_traceback()
                retry -= 1
                if retry >= 0:
                    # back off only when another attempt will actually follow
                    time.sleep(5.0)

        self.logger.info("Client worker process is terminated.")

    def send_to_job(
        self,
        job_id,
        channel: str,
        topic: str,
        msg: CellMessage,
        timeout: float,
        optional=False,
    ) -> CellMessage:
        """Send a message to CJ

        Args:
            job_id: id of the job
            channel: message channel
            topic: message topic
            msg: the message to be sent
            timeout: how long to wait for reply
            optional: whether the message is optional

        Returns: reply from CJ
        """
        # send any serializable data to the job cell
        return self.client.cell.send_request(
            target=self._job_fqcn(job_id),
            channel=channel,
            topic=topic,
            request=msg,
            timeout=timeout,
            optional=optional,
        )

    def _terminate_job(self, job_handle, job_id):
        """Waits up to 10 seconds for the job to be untracked, then terminates it.

        terminate() is called on the job handle in every case, including when the job
        entry has already been removed from run_processes.

        Args:
            job_handle: handle of the launched job
            job_id: the job_id
        """
        max_wait = 10.0
        # monotonic clock: the wait cannot be distorted by wall-clock adjustments
        deadline = time.monotonic() + max_wait
        # the entry disappears from run_processes once the job has finished gracefully
        while self.run_processes.get(job_id) and time.monotonic() <= deadline:
            time.sleep(0.05)  # we want to quickly check

        job_handle.terminate()
        self.logger.info(f"run ({job_id}): child worker process terminated")

    def abort_task(self, job_id):
        """Aborts the client executing task.

        The ABORT_TASK message is only sent when the job's recorded status is STARTED.

        Args:
            job_id: the job_id
        """
        process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
        if process_status == ClientStatus.STARTED:
            data = {"command": AdminCommandNames.ABORT_TASK, "data": {}}
            request = new_cell_message({}, data)
            self.client.cell.fire_and_forget(
                targets=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.ABORT_TASK,
                message=request,
                optional=True,
            )
            self.logger.debug("abort_task sent")

    def _wait_child_process_finish(
        self, client, job_id, allocated_resource, token, resource_manager, workspace, fl_ctx
    ):
        """Waits for the job process to end, then cleans up and fires JOB_COMPLETED.

        Runs in the thread started by start_app. If the process ended with UNSAFE_COMPONENT or
        CONFIG_ERROR, the failure is reported to the server. Allocated resources are then freed
        and the job entry is removed from run_processes.

        Args:
            client: the FL client object
            job_id: the job_id
            allocated_resource: allocated resources, freed here when truthy
            token: token from resource manager
            resource_manager: resource manager
            workspace: workspace passed to get_return_code
            fl_ctx: FLContext used to fire the JOB_COMPLETED event
        """
        self.logger.info(f"run ({job_id}): waiting for child worker process to finish.")
        job_handle = self.run_processes.get(job_id, {}).get(RunProcessKey.JOB_HANDLE)
        if job_handle:
            job_handle.wait()
            return_code = get_return_code(job_handle, job_id, workspace, self.logger)
            self.logger.info(f"run ({job_id}): child worker process finished with RC {return_code}")

            if return_code in [ProcessExitCode.UNSAFE_COMPONENT, ProcessExitCode.CONFIG_ERROR]:
                request = new_cell_message(
                    headers={},
                    payload={
                        JobFailureMsgKey.JOB_ID: job_id,
                        JobFailureMsgKey.CODE: return_code,
                        JobFailureMsgKey.REASON: PROCESS_EXIT_REASON[return_code],
                    },
                )
                self.client.cell.fire_and_forget(
                    targets=[FQCN.ROOT_SERVER],
                    channel=CellChannel.SERVER_MAIN,
                    topic=CellChannelTopic.REPORT_JOB_FAILURE,
                    message=request,
                    optional=True,
                )
                self.logger.info(f"reported failure of job {job_id} to server!")

        if allocated_resource:
            resource_manager.free_resources(
                resources=allocated_resource, token=token, fl_ctx=client.engine.new_context()
            )

        with self.lock:
            self.run_processes.pop(job_id, None)
        self.logger.debug(f"run ({job_id}): child worker resources freed.")

        engine = fl_ctx.get_engine()
        fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False)
        fl_ctx.set_prop(FLContextKey.CLIENT_NAME, client.client_name, private=True, sticky=False)
        engine.fire_event(EventType.JOB_COMPLETED, fl_ctx)
        self.logger.debug(f"Fired event JOB_COMPLETED {EventType.JOB_COMPLETED}")

    def get_status(self, job_id):
        """Returns the recorded status of a job, or ClientStatus.STOPPED if it is not tracked.

        Args:
            job_id: the job_id
        """
        return self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED)

    def get_run_processes_keys(self):
        """Returns a list of the job ids currently tracked in run_processes."""
        with self.lock:
            return list(self.run_processes)