# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shlex
import subprocess
import sys
import threading
import time
from abc import ABC, abstractmethod
from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey, SystemConfigs
from nvflare.apis.resource_manager_spec import ResourceManagerSpec
from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode
from nvflare.fuel.f3.cellnet.core_cell import FQCN
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message
from nvflare.private.fed.utils.fed_utils import get_return_code
from nvflare.security.logging import secure_format_exception, secure_log_traceback
from .client_status import ClientStatus, get_status_message
class ClientExecutor(ABC):
    """Abstract interface for starting, monitoring and aborting a client's job app."""

    @abstractmethod
    def start_app(
        self,
        client,
        job_id,
        args,
        app_custom_folder,
        listen_port,
        allocated_resource,
        token,
        resource_manager,
        target: str,
        scheme: str,
    ):
        """Launch the client app for the given job.

        Args:
            client: the FL client object
            job_id: the job_id
            args: admin command arguments for starting the FL client training
            app_custom_folder: FL application custom folder
            listen_port: port to listen for commands on
            allocated_resource: allocated resources
            token: token from resource manager
            resource_manager: resource manager
            target: SP target location
            scheme: SP target connection scheme
        """
        ...

    @abstractmethod
    def check_status(self, job_id) -> str:
        """Report the status of the job's running client.

        Args:
            job_id: the job_id

        Returns:
            A client status message.
        """
        ...

    @abstractmethod
    def abort_app(self, job_id):
        """Abort the job's running app.

        Args:
            job_id: the job_id
        """
        ...

    @abstractmethod
    def abort_task(self, job_id):
        """Abort the task the job's client is currently executing.

        Args:
            job_id: the job_id
        """
        ...

    @abstractmethod
    def get_run_info(self, job_id):
        """Fetch run information for the job.

        Args:
            job_id: the job_id

        Returns:
            A dict of run information.
        """

    @abstractmethod
    def get_errors(self, job_id):
        """Fetch error information for the job.

        Args:
            job_id: the job_id

        Returns:
            A dict of error information.
        """

    @abstractmethod
    def reset_errors(self, job_id):
        """Clear the job's recorded error information.

        Args:
            job_id: the job_id
        """
class ProcessExecutor(ClientExecutor):
    """Run the Client executor in a child process.

    Each job started through ``start_app`` runs in its own worker process. Bookkeeping for live jobs is
    kept in ``self.run_processes`` (keyed by job_id); a job's entry is removed by a background thread
    once the corresponding worker process exits.
    """

    def __init__(self, client, startup):
        """To init the ProcessExecutor.

        Args:
            client: the FL client object; its name and cell are used to address job workers
            startup: startup folder
        """
        self.client = client
        self.logger = logging.getLogger(self.__class__.__name__)
        self.startup = startup
        # job_id -> dict keyed by RunProcessKey (listen port, connection, child process, status)
        self.run_processes = {}
        self.lock = threading.Lock()
        # used as the ``timeout`` of request/reply queries sent to a job's worker (stats / errors)
        self.job_query_timeout = ConfigService.get_float_var(
            name="job_query_timeout", conf=SystemConfigs.APPLICATION_CONF, default=5.0
        )

    def _job_fqcn(self, job_id):
        """Return the FQCN that commands for ``job_id`` are sent to (client name joined with job id)."""
        return FQCN.join([self.client.client_name, job_id])

    def start_app(
        self,
        client,
        job_id,
        args,
        app_custom_folder,
        listen_port,
        allocated_resource,
        token,
        resource_manager: ResourceManagerSpec,
        target: str,
        scheme: str,
    ):
        """Starts the app in a separate worker process.

        The worker is launched as ``python -m nvflare.private.fed.app.client.worker_process`` in a new
        process group, and a background thread is started that waits for it to exit and cleans up.

        Args:
            client: the FL client object
            job_id: the job_id
            args: admin command arguments for starting the worker process
            app_custom_folder: FL application custom folder
            listen_port: port to listen the command.
            allocated_resource: allocated resources
            token: token from resource manager
            resource_manager: resource manager
            target: SP target location
            scheme: SP connection scheme
        """
        new_env = os.environ.copy()
        if app_custom_folder:
            # Append the custom folder without leaving an empty PYTHONPATH entry when it was unset.
            current_path = new_env.get("PYTHONPATH", "")
            new_env["PYTHONPATH"] = os.pathsep.join(p for p in (current_path, app_custom_folder) if p)

        # Build argv as a list instead of a string re-split by shlex: shlex.split(cmd, True) treats '#'
        # as a comment start, and any path or --set value containing spaces would be split apart.
        command = [
            sys.executable,
            "-m",
            "nvflare.private.fed.app.client.worker_process",
            "-m",
            args.workspace,
            "-w",
            self.startup,
            "-t",
            client.token,
            "-d",
            client.ssid,
            "-n",
            job_id,
            "-c",
            client.client_name,
            "-p",
            str(client.cell.get_internal_listener_url()),
            "-g",
            target,
            "-scheme",
            scheme,
            "-s",
            "fed_client.json",
            "--set",
            *args.set,
            "print_conf=True",
        ]
        # use os.setsid to create new process group ID (lets _terminate_process kill the whole group)
        process = subprocess.Popen(command, preexec_fn=os.setsid, env=new_env)
        self.logger.info(f"Worker child process ID: {process.pid}")

        client.multi_gpu = False

        with self.lock:
            self.run_processes[job_id] = {
                RunProcessKey.LISTEN_PORT: listen_port,
                RunProcessKey.CONNECTION: None,
                RunProcessKey.CHILD_PROCESS: process,
                RunProcessKey.STATUS: ClientStatus.STARTING,
            }

        thread = threading.Thread(
            target=self._wait_child_process_finish,
            args=(client, job_id, allocated_resource, token, resource_manager, args.workspace),
        )
        thread.start()

    def notify_job_status(self, job_id, job_status):
        """Record a new status for a tracked job; does nothing if the job is not tracked.

        Args:
            job_id: the job_id
            job_status: the new status value to store
        """
        run_process = self.run_processes.get(job_id)
        if run_process:
            run_process[RunProcessKey.STATUS] = job_status

    def check_status(self, job_id):
        """Checks the status of the running client.

        Args:
            job_id: the job_id

        Returns:
            A client status message (NOT_STARTED status if the job is not tracked).
        """
        try:
            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
            return get_status_message(process_status)
        except Exception as e:
            self.logger.error(f"check_status execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return "execution exception. Please try again."

    def get_run_info(self, job_id):
        """Gets the run information from the job's worker.

        Args:
            job_id: the job_id

        Returns:
            The reply payload on success, ``{}`` for a non-OK reply, or an error dict on exception.
        """
        try:
            request = new_cell_message({}, {})
            return_data = self.client.cell.send_request(
                target=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.SHOW_STATS,
                request=request,
                optional=True,
                timeout=self.job_query_timeout,
            )
            if return_data.get_header(MessageHeaderKey.RETURN_CODE) == ReturnCode.OK:
                return return_data.payload
            return {}
        except Exception as e:
            self.logger.error(f"get_run_info execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return {"error": "no info collector. Please try again."}

    def get_errors(self, job_id):
        """Get the error information from the job's worker.

        Args:
            job_id: the job_id

        Returns:
            The reply payload on success; None for a non-OK reply or on exception.
        """
        try:
            data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}}
            request = new_cell_message({}, data)
            return_data = self.client.cell.send_request(
                target=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.SHOW_ERRORS,
                request=request,
                optional=True,
                timeout=self.job_query_timeout,
            )
            if return_data.get_header(MessageHeaderKey.RETURN_CODE) == ReturnCode.OK:
                return return_data.payload
            return None
        except Exception as e:
            self.logger.error(f"get_errors execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()
            return None

    def reset_errors(self, job_id):
        """Asks the job's worker (fire-and-forget) to reset its error information.

        Args:
            job_id: the job_id
        """
        try:
            data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}}
            request = new_cell_message({}, data)
            self.client.cell.fire_and_forget(
                targets=self._job_fqcn(job_id),
                channel=CellChannel.CLIENT_COMMAND,
                topic=AdminCommandNames.RESET_ERRORS,
                message=request,
                optional=True,
            )
        except Exception as e:
            self.logger.error(f"reset_errors execution exception: {secure_format_exception(e)}.")
            secure_log_traceback()

    def abort_app(self, job_id):
        """Aborts the running app.

        If the job has STARTED status, an ABORT command is sent to its worker and the worker process is
        then terminated via ``_terminate_process`` (this call blocks until that finishes). If anything in
        the attempt raises, it is retried once after a 5 second pause.

        Args:
            job_id: the job_id
        """
        # When the heartbeat cleanup tries to abort the worker process, the job may already have
        # terminated; retry to avoid printing the error stack trace.
        retry = 1
        while retry >= 0:
            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
            if process_status != ClientStatus.STARTED:
                self.logger.info(f"Client worker process for run: {job_id} was already terminated.")
                break
            try:
                with self.lock:
                    child_process = self.run_processes[job_id][RunProcessKey.CHILD_PROCESS]
                request = new_cell_message({}, {})
                self.client.cell.fire_and_forget(
                    targets=self._job_fqcn(job_id),
                    channel=CellChannel.CLIENT_COMMAND,
                    topic=AdminCommandNames.ABORT,
                    message=request,
                    optional=True,
                )
                self.logger.debug("abort sent to worker")
                t = threading.Thread(target=self._terminate_process, args=[child_process, job_id])
                t.start()
                t.join()
                break
            except Exception as e:
                if retry == 0:
                    self.logger.error(
                        f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}."
                    )
                    secure_log_traceback()
                retry -= 1
                if retry >= 0:
                    # only pause when another attempt follows; the last failure exits immediately
                    time.sleep(5.0)

        self.logger.info("Client worker process is terminated.")

    def _terminate_process(self, child_process, job_id):
        """Wait for the worker to exit on its own, kill its process group if it does not, then terminate it.

        Args:
            child_process: the ``subprocess.Popen`` handle of the job's worker
            job_id: the job_id
        """
        max_wait = 10.0
        done = False
        # monotonic clock so the wait is not skewed by wall-clock adjustments
        start = time.monotonic()
        while True:
            if not self.run_processes.get(job_id):
                # already finished gracefully (_wait_child_process_finish removed the entry)
                done = True
                break
            if time.monotonic() - start > max_wait:
                # waited enough
                break
            time.sleep(0.05)  # we want to quickly check

        # kill the sub-process group directly
        if not done:
            self.logger.debug(f"still not done after {max_wait} secs")
            try:
                # 9 == SIGKILL; the worker was started with os.setsid, so this reaches its whole group
                os.killpg(os.getpgid(child_process.pid), 9)
                self.logger.debug("kill signal sent")
            except OSError as e:
                # best effort: the process (group) may already be gone
                self.logger.debug(
                    f"failed to kill process group of pid {child_process.pid}: {secure_format_exception(e)}"
                )

        child_process.terminate()
        self.logger.info(f"run ({job_id}): child worker process terminated")

    def abort_task(self, job_id):
        """Aborts the client executing task (only if the job has STARTED status).

        Args:
            job_id: the job_id
        """
        process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
        if process_status != ClientStatus.STARTED:
            return
        data = {"command": AdminCommandNames.ABORT_TASK, "data": {}}
        request = new_cell_message({}, data)
        self.client.cell.fire_and_forget(
            targets=self._job_fqcn(job_id),
            channel=CellChannel.CLIENT_COMMAND,
            topic=AdminCommandNames.ABORT_TASK,
            message=request,
            optional=True,
        )
        self.logger.debug("abort_task sent")

    def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager, workspace):
        """Background-thread body: wait for the job's worker to exit, then clean up after it.

        If the worker exits with UNSAFE_COMPONENT or CONFIG_ERROR, the failure is reported to the server.
        The job's allocated resources are then freed and its ``run_processes`` entry removed.
        """
        self.logger.info(f"run ({job_id}): waiting for child worker process to finish.")
        child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS)
        try:
            if child_process:
                child_process.wait()
                return_code = get_return_code(child_process, job_id, workspace, self.logger)
                self.logger.info(f"run ({job_id}): child worker process finished with RC {return_code}")
                if return_code in [ProcessExitCode.UNSAFE_COMPONENT, ProcessExitCode.CONFIG_ERROR]:
                    self._report_job_failure(job_id, return_code)
        finally:
            # Clean up even if waiting/reporting raised; a lingering entry would make abort_app and
            # _terminate_process keep treating the job as alive.
            self._release_job(client, job_id, allocated_resource, token, resource_manager)
        self.logger.debug(f"run ({job_id}): child worker resources freed.")

    def _report_job_failure(self, job_id, return_code):
        """Send a fire-and-forget REPORT_JOB_FAILURE message for ``job_id`` to the root server."""
        request = new_cell_message(
            headers={},
            payload={
                JobFailureMsgKey.JOB_ID: job_id,
                JobFailureMsgKey.CODE: return_code,
                JobFailureMsgKey.REASON: PROCESS_EXIT_REASON[return_code],
            },
        )
        self.client.cell.fire_and_forget(
            targets=[FQCN.ROOT_SERVER],
            channel=CellChannel.SERVER_MAIN,
            topic=CellChannelTopic.REPORT_JOB_FAILURE,
            message=request,
            optional=True,
        )
        self.logger.info(f"reported failure of job {job_id} to server!")

    def _release_job(self, client, job_id, allocated_resource, token, resource_manager):
        """Free the job's allocated resources (if any) and drop its ``run_processes`` entry.

        The entry is dropped even if freeing the resources raises.
        """
        try:
            if allocated_resource:
                resource_manager.free_resources(
                    resources=allocated_resource, token=token, fl_ctx=client.engine.new_context()
                )
        finally:
            with self.lock:
                self.run_processes.pop(job_id, None)

    def get_status(self, job_id):
        """Return the stored status of the job, or ``ClientStatus.STOPPED`` if it is not tracked."""
        return self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED)

    def get_run_processes_keys(self):
        """Return a snapshot list of the job_ids currently tracked."""
        with self.lock:
            return list(self.run_processes)