# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shlex
import subprocess
import sys
import threading
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_object_type, check_str
class StopMethod:
    """Names of the supported ways to stop a managed subprocess."""

    # forcefully stop the process (ProcessManager.stop uses Popen.kill() for this value)
    KILL = "kill"
    # ask the process to stop (ProcessManager.stop uses Popen.terminate() for this value)
    TERMINATE = "terminate"
class CommandDescriptor:
    def __init__(
        self,
        cmd: str,
        cwd=None,
        env=None,
        log_file_name: str = "",
        log_stdout: bool = True,
        stdout_msg_prefix: str = None,
        stop_method=StopMethod.KILL,
    ):
        """Describe a command to be run as a new process.

        A CommandDescriptor holds everything needed to launch the new process and
        to decide how its output is logged and how it is stopped.

        Args:
            cmd: the command line used to start the new process
            cwd: working directory for the new process (optional)
            env: extra environment variables (dict) for the new process (optional)
            log_file_name: base name of the log file; empty means no log file.
            log_stdout: whether the process output should also be echoed to stdout.
            stdout_msg_prefix: prefix prepended to each line echoed to stdout. Several
                processes may share one terminal window, so a prefix helps tell their
                output apart.
            stop_method: StopMethod.KILL or StopMethod.TERMINATE

        Raises:
            ValueError: if stop_method is not one of the StopMethod values.
        """
        # validate in a fixed order: cmd, cwd, env, log_file_name, stdout_msg_prefix, stop_method
        check_str("cmd", cmd)
        if cwd:
            check_str("cwd", cwd)
        if env:
            check_object_type("env", env, dict)
        for arg_name, arg_value in (("log_file_name", log_file_name), ("stdout_msg_prefix", stdout_msg_prefix)):
            # optional string args are only type-checked when they are set (truthy)
            if arg_value:
                check_str(arg_name, arg_value)

        allowed_methods = [StopMethod.KILL, StopMethod.TERMINATE]
        if stop_method not in allowed_methods:
            raise ValueError(f"invalid stop_method '{stop_method}': must be one of {allowed_methods}")

        self.cmd, self.cwd, self.env = cmd, cwd, env
        self.log_file_name = log_file_name
        self.log_stdout = log_stdout
        self.stdout_msg_prefix = stdout_msg_prefix
        self.stop_method = stop_method
class ProcessManager:
    def __init__(self, cmd_desc: CommandDescriptor, stop_method="kill"):
        """Constructor of ProcessManager.

        ProcessManager provides methods for managing the lifecycle of a subprocess (start, stop, poll), as well
        as the handling of log file to be used by the subprocess.

        Args:
            cmd_desc: the CommandDescriptor that describes the command of the new process to be started
            stop_method: how to stop the process ("kill" or "terminate"). It is combined with
                ``cmd_desc.stop_method``: if either of them is StopMethod.TERMINATE the process is terminated,
                otherwise it is killed. Unrecognized values are ignored.

        NOTE: the methods of ProcessManager are not thread safe.
        """
        check_object_type("cmd_desc", cmd_desc, CommandDescriptor)
        self.process = None
        self.cmd_desc = cmd_desc
        self.stop_method = stop_method
        self.log_file = None
        self.msg_prefix = None
        # guards self.log_file, which is written by the log-writer thread and closed by stop()
        self.file_lock = threading.Lock()
        self.logger = get_obj_logger(self)

    def start(
        self,
        fl_ctx: FLContext,
    ):
        """Start the new process.

        Opens the log file (if cmd_desc.log_file_name is set) in the job's run dir, launches the
        command with stdout and stderr merged into one pipe, and starts a daemon thread that
        copies the pipe's output to the log file and/or sys.stdout.

        Args:
            fl_ctx: FLContext object.

        Returns: None

        Raises:
            RuntimeError: if a log file is requested but fl_ctx does not carry a Workspace object.
        """
        job_id = fl_ctx.get_job_id()
        if self.cmd_desc.stdout_msg_prefix:
            site_name = fl_ctx.get_identity_name()
            self.msg_prefix = f"[{self.cmd_desc.stdout_msg_prefix}@{site_name}]"

        if self.cmd_desc.log_file_name:
            ws = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            if not isinstance(ws, Workspace):
                self.logger.error(
                    f"FL context prop {FLContextKey.WORKSPACE_OBJECT} should be Workspace but got {type(ws)}"
                )
                raise RuntimeError("bad FLContext object")
            run_dir = ws.get_run_dir(job_id)
            log_file_path = os.path.join(run_dir, self.cmd_desc.log_file_name)
            # explicit encoding: _write_log writes UTF-8-decoded text, which the locale
            # default encoding may not be able to represent
            self.log_file = open(log_file_path, "a", encoding="utf-8")

        env = os.environ.copy()
        if self.cmd_desc.env:
            env.update(self.cmd_desc.env)

        command_seq = shlex.split(self.cmd_desc.cmd)
        try:
            self.process = subprocess.Popen(
                command_seq,
                stderr=subprocess.STDOUT,
                cwd=self.cmd_desc.cwd,
                env=env,
                stdout=subprocess.PIPE,
            )
        except Exception:
            # don't leak the log file handle when the process could not be started
            self._close_log_file()
            raise

        log_writer = threading.Thread(target=self._write_log, daemon=True)
        log_writer.start()

    def _write_log(self):
        """Copy lines from the process's stdout pipe to the log file and/or sys.stdout until EOF."""
        # note that depending on how the process flushes out its output, the messages may be buffered/delayed.
        pipe = self.process.stdout
        while True:
            raw = pipe.readline()
            if not raw:
                break
            # errors="replace": a single non-UTF-8 byte must not kill this thread, otherwise nobody
            # drains the pipe any more and the child can block once the pipe buffer is full.
            line = raw.decode("utf-8", errors="replace")

            # use file_lock to ensure file integrity since the log file could be closed by the self.stop() method!
            with self.file_lock:
                if self.log_file:
                    self.log_file.write(line)
                    self.log_file.flush()

            if self.cmd_desc.log_stdout:
                # lines starting with "\r" are left unprefixed (e.g. carriage-return progress updates)
                if self.msg_prefix and not line.startswith("\r"):
                    line = f"{self.msg_prefix} {line}"
                sys.stdout.write(line)
                sys.stdout.flush()
        # EOF reached: release the pipe's file descriptor
        pipe.close()

    def poll(self):
        """Perform a poll request on the process.

        Returns: None if the process is still running; an exit code (int) if process is not running.

        Raises:
            RuntimeError: if no process has been started.
        """
        if not self.process:
            raise RuntimeError("there is no process to poll")
        return self.process.poll()

    def _effective_stop_method(self) -> str:
        """Resolve the stop method: TERMINATE if either the descriptor or this manager asks for it, else KILL."""
        if StopMethod.TERMINATE in (self.cmd_desc.stop_method, self.stop_method):
            return StopMethod.TERMINATE
        return StopMethod.KILL

    def _close_log_file(self):
        """Close the log file if one is open (safe to call repeatedly)."""
        with self.file_lock:
            if self.log_file:
                self.log_file.close()
                self.log_file = None
                self.logger.info("closed subprocess log file!")

    def stop(self) -> int:
        """Stop the process.

        If the process is still running, kill or terminate it (see the stop_method constructor arg).
        If a log file is open, close the log file.

        Returns: the exit code of the process if it had already exited; -9 if it was killed;
            -15 if it was terminated. If sending the signal failed, the result of re-polling
            the process (None if it is still running).
        """
        self.logger.info(f"stopping process: {self.cmd_desc.cmd}")
        rc = self.poll()
        if rc is None:
            # process is still alive
            stop_method = self._effective_stop_method()
            self.logger.info(f"process still running - {stop_method} process: {self.cmd_desc.cmd}")
            try:
                if stop_method == StopMethod.KILL:
                    self.process.kill()
                    rc = -9
                else:
                    self.process.terminate()
                    rc = -15
            except Exception as ex:
                # best effort: the process may have exited between poll() and the signal
                self.logger.debug(f"ignored exception {ex} from {stop_method}")
                rc = self.process.poll()
        else:
            self.logger.info(f"process already stopped: {rc=}")

        self._close_log_file()
        return rc
def start_process(cmd_desc: CommandDescriptor, fl_ctx: FLContext, stop_method="kill") -> ProcessManager:
    """Create a ProcessManager for the command and start its process.

    Args:
        cmd_desc: the CommandDescriptor that describes the command to be executed
        fl_ctx: FLContext object passed to ProcessManager.start
        stop_method: how to stop the process, forwarded to ProcessManager

    Returns: the ProcessManager that owns the started process.
    """
    manager = ProcessManager(cmd_desc, stop_method)
    manager.start(fl_ctx)
    return manager
def run_command(cmd_desc: CommandDescriptor) -> str:
    """Run a command to completion and return its standard output.

    Only ``cmd_desc.cmd``, ``cmd_desc.cwd`` and ``cmd_desc.env`` are used; the environment is the
    current process environment updated with ``cmd_desc.env``.

    Args:
        cmd_desc: the CommandDescriptor of the command to run

    Returns: the command's stdout decoded as UTF-8. Stderr is captured (so it does not reach the
        terminal) but discarded.

    Raises:
        UnicodeDecodeError: if the command's stdout is not valid UTF-8.
    """
    env = os.environ.copy()
    if cmd_desc.env:
        env.update(cmd_desc.env)

    # subprocess.run drains stdout and stderr concurrently (via communicate()) and waits for the
    # child. Reading the two pipes alternately line by line can deadlock once the child fills
    # one pipe's buffer while we block on the other.
    result = subprocess.run(
        shlex.split(cmd_desc.cmd),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=cmd_desc.cwd,
        env=env,
    )
    return result.stdout.decode("utf-8")