Source code for nvflare.app_common.tie.process_mgr

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shlex
import subprocess
import sys
import threading

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.validation_utils import check_object_type, check_str


class CommandDescriptor:
    def __init__(
        self,
        cmd: str,
        cwd=None,
        env=None,
        log_file_name: str = "",
        log_stdout: bool = True,
        stdout_msg_prefix: str = None,
    ):
        """Constructor of CommandDescriptor.

        A CommandDescriptor bundles everything needed to launch a new process:
        the command itself plus optional settings for its working directory,
        environment and logging.

        Args:
            cmd: the command to be executed to start the new process
            cwd: current work dir for the new process
            env: system env for the new process
            log_file_name: base name of the log file.
            log_stdout: whether to output log messages to stdout.
            stdout_msg_prefix: prefix to be prepended to log message when writing to stdout.
                Since multiple processes could be running within the same terminal window, the prefix can help
                differentiate log messages from these processes.
        """
        # the command is the only mandatory setting, so it is always validated
        check_str("cmd", cmd)

        # optional settings are validated only when a truthy value was supplied
        if cwd:
            check_str("cwd", cwd)

        if env:
            check_object_type("env", env, dict)

        optional_strings = (
            ("log_file_name", log_file_name),
            ("stdout_msg_prefix", stdout_msg_prefix),
        )
        for arg_name, arg_value in optional_strings:
            if arg_value:
                check_str(arg_name, arg_value)

        self.cmd, self.cwd, self.env = cmd, cwd, env
        self.log_file_name, self.log_stdout = log_file_name, log_stdout
        self.stdout_msg_prefix = stdout_msg_prefix
class ProcessManager:
    """Manage the lifecycle (start, poll, stop) of one subprocess and its log output.

    NOTE: the methods of ProcessManager are not thread safe.
    """

    # upper bound (seconds) on how long stop() waits for a killed process to be reaped
    _KILL_WAIT_TIMEOUT = 2.0

    def __init__(self, cmd_desc: "CommandDescriptor"):
        """Constructor of ProcessManager.

        ProcessManager provides methods for managing the lifecycle of a subprocess (start, stop, poll), as well
        as the handling of log file to be used by the subprocess.

        Args:
            cmd_desc: the CommandDescriptor that describes the command of the new process to be started

        NOTE: the methods of ProcessManager are not thread safe.
        """
        check_object_type("cmd_desc", cmd_desc, CommandDescriptor)
        self.process = None
        self.cmd_desc = cmd_desc
        self.log_file = None
        self.msg_prefix = None
        # guards self.log_file, which is written by the log thread and closed by stop()
        self.file_lock = threading.Lock()
        self.logger = get_obj_logger(self)

    def start(
        self,
        fl_ctx: "FLContext",
    ):
        """Start the new process.

        Opens the log file (if one is configured), launches the command with stdout and stderr merged into
        a single pipe, and starts a daemon thread that copies the pipe to the log file and/or sys.stdout.

        Args:
            fl_ctx: FLContext object.

        Returns: None

        Raises:
            RuntimeError: if a log file is configured but fl_ctx does not carry a Workspace object.

        Any error raised while parsing or launching the command is propagated after the log file is closed.
        """
        job_id = fl_ctx.get_job_id()

        if self.cmd_desc.stdout_msg_prefix:
            site_name = fl_ctx.get_identity_name()
            self.msg_prefix = f"[{self.cmd_desc.stdout_msg_prefix}@{site_name}]"

        if self.cmd_desc.log_file_name:
            ws = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            if not isinstance(ws, Workspace):
                self.logger.error(
                    f"FL context prop {FLContextKey.WORKSPACE_OBJECT} should be Workspace but got {type(ws)}"
                )
                raise RuntimeError("bad FLContext object")

            run_dir = ws.get_run_dir(job_id)
            log_file_path = os.path.join(run_dir, self.cmd_desc.log_file_name)
            # the process output is decoded as UTF-8, so write it back as UTF-8 regardless of the locale
            self.log_file = open(log_file_path, "a", encoding="utf-8")

        env = os.environ.copy()
        if self.cmd_desc.env:
            env.update(self.cmd_desc.env)

        try:
            command_seq = shlex.split(self.cmd_desc.cmd)
            self.process = subprocess.Popen(
                command_seq,
                stderr=subprocess.STDOUT,
                cwd=self.cmd_desc.cwd,
                env=env,
                stdout=subprocess.PIPE,
            )
        except Exception:
            # the process never started: do not leave the log file open
            self._close_log_file()
            raise

        log_writer = threading.Thread(target=self._write_log, daemon=True)
        log_writer.start()

    def _write_log(self):
        """Copy the process's output, line by line, to the log file and/or sys.stdout until EOF."""
        # write messages from the process's stdout pipe to log file and sys.stdout.
        # note that depending on how the process flushes out its output, the messages may be buffered/delayed.
        while True:
            raw_line = self.process.stdout.readline()
            if not raw_line:
                break

            # never let an undecodable byte kill this thread: the pipe would stop being drained
            line = raw_line.decode("utf-8", errors="replace")

            # use file_lock to ensure file integrity since the log file could be closed by the self.stop() method!
            with self.file_lock:
                if self.log_file:
                    self.log_file.write(line)
                    self.log_file.flush()

            if self.cmd_desc.log_stdout:
                # lines starting with "\r" are left unprefixed
                if self.msg_prefix and not line.startswith("\r"):
                    line = f"{self.msg_prefix} {line}"
                sys.stdout.write(line)
                sys.stdout.flush()

    def poll(self):
        """Perform a poll request on the process.

        Returns: None if the process is still running; an exit code (int) if process is not running.

        Raises:
            RuntimeError: if no process has been started.
        """
        if not self.process:
            raise RuntimeError("there is no process to poll")
        return self.process.poll()

    def stop(self) -> int:
        """Stop the process. If the process is still running, kill the process. If a log file is open,
        close the log file.

        Returns: the exit code of the process. If killed, returns -9. If the process was still running and
            the kill attempt failed, returns None.

        Raises:
            RuntimeError: if no process has been started.
        """
        rc = self.poll()
        if rc is None:
            # process is still alive
            try:
                self.process.kill()
            except Exception as ex:
                # killing is best effort: report the failure but still close the log file below
                self.logger.warning(f"failed to kill subprocess: {ex}")
            else:
                rc = -9
                try:
                    # reap the killed process so it does not linger as a zombie
                    self.process.wait(timeout=self._KILL_WAIT_TIMEOUT)
                except (subprocess.TimeoutExpired, OSError) as ex:
                    self.logger.warning(f"killed subprocess could not be reaped: {ex}")

        # close the log file if any
        self._close_log_file()
        return rc

    def _close_log_file(self):
        """Close the log file, if one is open, while holding file_lock."""
        with self.file_lock:
            if self.log_file:
                self.logger.debug("closed subprocess log file!")
                self.log_file.close()
                self.log_file = None
def start_process(cmd_desc: CommandDescriptor, fl_ctx: FLContext) -> ProcessManager:
    """Convenience function for starting a subprocess.

    Creates a ProcessManager for the given command and starts it right away.

    Args:
        cmd_desc: the CommandDescriptor that describes the command to be executed
        fl_ctx: FLContext object

    Returns: a ProcessManager object.

    """
    manager = ProcessManager(cmd_desc)
    manager.start(fl_ctx)
    return manager