Source code for nvflare.app_opt.job_launcher.docker_launcher

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import posixpath
import re
import time
from abc import abstractmethod

try:
    import docker.errors

    import docker

    _DOCKER_AVAILABLE = True
except ImportError:
    _DOCKER_AVAILABLE = False

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, JobConstants, WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import JobMetaKey
from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher
from nvflare.apis.workspace import Workspace
from nvflare.app_opt.job_launcher.study_data import (
    load_study_data_file,
    resolve_study_dataset_mounts,
    should_mount_study_data,
)
from nvflare.utils.job_launcher_utils import get_client_job_args, get_job_launcher_spec, get_server_job_args


# Docker container status strings
[docs] class DockerStatus: CREATED = "created" RESTARTING = "restarting" RUNNING = "running" PAUSED = "paused" EXITED = "exited" DEAD = "dead"
TERMINAL_STATUSES = {DockerStatus.EXITED, DockerStatus.DEAD} # Docker tmpfs mounts are commonly owned by root; use sticky world-writable mode so the # non-root job-container user can initialize ephemeral top-level workspace directories. _WORKSPACE_TMPFS_MODE = 0o1777 _RESERVED_WORKSPACE_CHILD_NAMES = { WorkspaceConstants.STARTUP_FOLDER_NAME, WorkspaceConstants.SITE_FOLDER_NAME, } def _sanitize_container_name(name: str) -> str: """Sanitize a string to a valid Docker container name. Docker container names allow alphanumeric, hyphens, underscores, and dots. """ name = name.lower() name = re.sub(r"[^a-z0-9\-_.]", "-", name) name = name.strip("-") return name or "nvflare-job" def _exit_code_to_return_code(exit_code: int) -> JobReturnCode: if exit_code == 0: return JobReturnCode.SUCCESS elif exit_code == JobReturnCode.ABORTED: return JobReturnCode.ABORTED else: return JobReturnCode.EXECUTION_ERROR def _safe_workspace_child_path(workspace: str, child_name: str, allow_reserved: bool = False) -> str: """Return a host workspace child path, rejecting paths that escape workspace.""" child_name = str(child_name) normalized_child_name = os.path.normpath(child_name) # normpath catches traversal spellings; the separator checks catch already-normalized nested paths. if ( not child_name or os.path.isabs(child_name) or normalized_child_name != child_name or normalized_child_name in ("", ".", "..") or os.sep in normalized_child_name or (os.altsep and os.altsep in normalized_child_name) ): raise RuntimeError(f"job workspace path must be a single workspace child: {child_name}") if not allow_reserved and normalized_child_name in _RESERVED_WORKSPACE_CHILD_NAMES: raise RuntimeError(f"job workspace path uses reserved workspace name: {child_name}") child_path = os.path.normpath(os.path.join(workspace, child_name)) workspace_real = os.path.realpath(workspace) child_real = os.path.realpath(child_path) if os.path.commonpath([workspace_real, child_real]) != workspace_real: raise RuntimeError(f"job workspace path escapes workspace: {child_name}") if os.path.islink(child_path): raise RuntimeError(f"workspace child path must not be a symlink: {child_name}") return child_path
[docs] class DockerJobHandle(JobHandleSpec): """Handle for a running Docker container job. Modeled on K8sJobHandle: once the container reaches a terminal state, terminal_state is set and all subsequent poll()/wait() calls return immediately without querying Docker. """ def __init__( self, container_id: str, container_name: str, docker_client, timeout: int = 30, ): super().__init__() self.container_id = container_id self.container_name = container_name self.docker_client = docker_client self.timeout = timeout self.terminal_state: JobReturnCode = None # set once, never cleared self.logger = logging.getLogger(self.__class__.__name__) def _get_container(self): """Query Docker for the current container object. Returns None if not found (sets terminal_state) or on API error. """ try: return self.docker_client.containers.get(self.container_id) except docker.errors.NotFound: self.logger.info(f"container {self.container_name} not found; assuming terminated") if self.terminal_state is None: self.terminal_state = JobReturnCode.ABORTED return None except docker.errors.APIError as e: self.logger.warning(f"error querying container {self.container_name}: {e}") return None except Exception as e: self.logger.warning(f"unexpected error querying container {self.container_name}: {e}") return None def _resolve_terminal_return_code(self, container) -> JobReturnCode: """Get the final JobReturnCode from a terminal container using exit code.""" if container.status == DockerStatus.DEAD: return JobReturnCode.ABORTED # EXITED: read actual exit code from container attrs exit_code = container.attrs.get("State", {}).get("ExitCode", 1) return _exit_code_to_return_code(exit_code) def _remove_container(self): """Remove the container after it has reached a terminal state.""" try: container = self.docker_client.containers.get(self.container_id) container.remove(force=True) self.logger.debug(f"removed container {self.container_name}") except docker.errors.NotFound: pass # already gone except docker.errors.APIError as e: self.logger.warning(f"error removing container {self.container_name}: {e}")
[docs] def poll(self) -> JobReturnCode: """Non-blocking status check. Returns UNKNOWN while still running.""" if self.terminal_state is not None: return self.terminal_state container = self._get_container() if container is None: return self.terminal_state if self.terminal_state is not None else JobReturnCode.UNKNOWN if container.status in TERMINAL_STATUSES: rc = self._resolve_terminal_return_code(container) self.terminal_state = rc self._remove_container() return rc return JobReturnCode.UNKNOWN
[docs] def wait(self): """Block until the container reaches a terminal state.""" while True: if self.terminal_state is not None: return container = self._get_container() if container is None: return if container.status in TERMINAL_STATUSES: self.terminal_state = self._resolve_terminal_return_code(container) self._remove_container() return time.sleep(1)
[docs] def terminate(self): """Stop and remove the container. Always sets terminal_state.""" try: container = self.docker_client.containers.get(self.container_id) container.stop(timeout=0) container.remove(force=True) except docker.errors.NotFound: self.logger.info(f"container {self.container_name} not found during termination; assuming terminated") except docker.errors.APIError as e: self.logger.error(f"error terminating container {self.container_name}: {e}") except Exception as e: self.logger.error(f"unexpected error terminating container {self.container_name}: {e}") finally: # Always set terminal_state so poll()/wait() return immediately if self.terminal_state is None: self.terminal_state = JobReturnCode.ABORTED
[docs] def enter_states(self, states_to_enter: list) -> bool: """Poll until the container enters one of the target states. Returns True if the target state was reached, False otherwise (timeout, stuck, or terminal state reached before target). """ starting_time = time.time() if not isinstance(states_to_enter, (list, tuple)): states_to_enter = [states_to_enter] while True: if self.terminal_state is not None: return False container = self._get_container() if container is None: return False status = container.status if status in states_to_enter: return True if status in TERMINAL_STATUSES: self.terminal_state = self._resolve_terminal_return_code(container) self._remove_container() return False if self.timeout is not None and time.time() - starting_time > self.timeout: self.logger.warning(f"container {self.container_name} timed out waiting for {states_to_enter}") self.terminate() return False time.sleep(1)
def _job_args_dict(job_args: dict, arg_names: list) -> dict: """Extract a {flag: value} dict from JOB_PROCESS_ARGS for the given arg names.""" result = {} for name in arg_names: e = job_args.get(name) if not e: continue flag, value = e result[flag] = value return result
[docs] class DockerJobLauncher(JobLauncherSpec): """Launches NVFlare job processes as Docker containers. SP/CP runs as a container started by start_docker.sh (site admin). SJ/CJ containers are started dynamically per job by this launcher. Assumptions: - Docker network already exists (created by start_docker.sh or site admin). - Job containers get an isolated workspace view at /var/tmp/nvflare/workspace: the root is writable ephemeral tmpfs, startup/local are read-only, and only the current job workspace is read-write and persistent on the host. - SP/CP container name is known and reachable via Docker DNS on the network. - parent_url is derived at runtime from the site name and the port in JOB_PROCESS_ARGS. """ WORKSPACE_MOUNT = "/var/tmp/nvflare/workspace" STUDY_DATA_PATH_FILE = "local/study_data.yaml" DEFAULT_PYTHON_PATH = "/usr/local/bin/python" def __init__( self, workspace: str = None, network: str = "nvflare-network", python_path: str = None, timeout: int = 30, default_job_container_kwargs: dict = None, default_job_env: dict = None, default_python_path: str = None, ): """ Args: workspace: host path to the NVFlare workspace directory. Job containers receive an isolated workspace view: startup/local are mounted read-only, and the current job workspace is mounted read-write at /var/tmp/nvflare/workspace/<job_id>. If not provided, reads from NVFL_DOCKER_WORKSPACE environment variable. Must be the HOST path because it is passed directly to the Docker daemon as a volume bind source. network: Docker network name. Must already exist. python_path: Deprecated alias for default_python_path. timeout: max seconds to wait for container to reach RUNNING state (default 30). default_job_container_kwargs: site-level default docker run kwargs applied to every job container launched by this site. Job-level resource_spec[site][docker] takes precedence on conflict. Keys use Docker SDK naming (underscores, not hyphens). Example: {"shm_size": "8g", "ipc_mode": "host"} Note: "volumes", "mounts", "network", "environment", "command", "name", "detach", "user", "working_dir" are controlled by the launcher and cannot be overridden here. default_job_env: site-level default environment variables injected into every job container launched by this site. Useful for site/runtime-specific settings such as NCCL workarounds. Launcher-controlled variables like USER, HOME, and PYTHONPATH still take precedence. default_python_path: Default Python executable path inside job containers. Jobs can override it with launcher_spec[site]["docker"]["python_path"]. """ super().__init__() self.logger = logging.getLogger(self.__class__.__name__) if not workspace: workspace = os.environ.get("NVFL_DOCKER_WORKSPACE") self.workspace = workspace self.network = network self.default_python_path = default_python_path if default_python_path is not None else python_path if self.default_python_path is None: self.default_python_path = self.DEFAULT_PYTHON_PATH if not isinstance(self.default_python_path, str) or not self.default_python_path: raise ValueError("default_python_path must be a non-empty string") self.timeout = timeout default_job_container_kwargs = default_job_container_kwargs or {} _RESERVED_KWARGS = { "volumes", "mounts", "network", "environment", "command", "name", "detach", "user", "working_dir", } reserved_used = _RESERVED_KWARGS & set(default_job_container_kwargs.keys()) if reserved_used: raise ValueError( f"default_job_container_kwargs must not contain reserved keys: {sorted(reserved_used)}. " f"These are controlled by the launcher." ) self.default_job_container_kwargs = default_job_container_kwargs self.default_job_env = default_job_env or {} self._docker_client = None def _get_docker_client(self): if self._docker_client is None: if not _DOCKER_AVAILABLE: raise RuntimeError("docker SDK not installed; install it with: pip install docker") try: client = docker.from_env() client.ping() except Exception as e: raise RuntimeError(f"cannot connect to Docker daemon: {e}") try: client.networks.get(self.network) except docker.errors.NotFound: raise RuntimeError( f"Docker network '{self.network}' does not exist. " f"Create it with: docker network create {self.network}" ) except docker.errors.APIError as e: raise RuntimeError(f"error checking Docker network '{self.network}': {e}") self._docker_client = client return self._docker_client
[docs] def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: job_id = job_meta.get(JobConstants.JOB_ID) if not job_id: raise RuntimeError("missing JOB_ID in job_meta") job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) if not job_args: raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE) if not exe_module_entry: raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in JOB_PROCESS_ARGS") _, exe_module = exe_module_entry site_name = fl_ctx.get_identity_name() docker_spec = get_job_launcher_spec(job_meta, site_name, "docker") job_image = docker_spec.get("image") container_name = _sanitize_container_name(f"{site_name}-{job_id}") if job_image is not None and not isinstance(job_image, str): raise RuntimeError( f"launcher_spec docker image for site '{site_name}' must be a string, " f"got {type(job_image).__name__}: {job_image!r}" ) if not job_image: raise RuntimeError( f"DockerJobLauncher is configured for site '{site_name}' but no job image " f"was specified in meta.json for this site. " f"Set launcher_spec['{site_name}']['docker']['image'] (preferred), " f"launcher_spec['default']['docker']['image'] (shared default), " f"or resource_spec['{site_name}']['docker']['image'] (legacy)." ) workspace = self.workspace if not workspace: raise ValueError( "workspace must be set to the host path of the NVFlare workspace directory, " "or set the NVFL_DOCKER_WORKSPACE environment variable" ) # Derive parent_url at runtime: site name (= container name on Docker DNS) + port # from the original PARENT_URL in job_args. This avoids baking parent_url into # resources.json at provision time. if JobProcessArgs.PARENT_URL in job_args: flag, original_url = job_args[JobProcessArgs.PARENT_URL] port = original_url.rsplit(":", 1)[-1] parent_url = f"tcp://{site_name}:{port}" job_args = dict(job_args) job_args[JobProcessArgs.PARENT_URL] = (flag, parent_url) module_args = self.get_module_args(job_args) module_args_list = [] for flag, value in module_args.items(): if value is not None: module_args_list.extend([flag, str(value)]) # Append --set options (same as K8s launcher) args = fl_ctx.get_prop(FLContextKey.ARGS) set_list = args.set if args is not None and getattr(args, "set", None) is not None else None if set_list: module_args_list.extend(["--set"] + set_list) python_path = docker_spec.get("python_path", self.default_python_path) if not isinstance(python_path, str) or not python_path: raise RuntimeError(f"launcher_spec['{site_name}']['docker']['python_path'] must be a non-empty string") command = [python_path, "-u", "-m", exe_module] + module_args_list # PYTHONPATH: translate app_custom_folder host path to container-internal path # so custom Python code in the job app is importable inside the container. # USER: some libraries (e.g. torch._dynamo) call getpass.getuser() which falls back to # pwd.getpwuid(os.getuid()). When the container runs as a host UID not in /etc/passwd, # this raises KeyError. Setting USER satisfies the env-var fast path in getpass.getuser(). # Pass USER and HOME so libraries that call getpass.getuser() or os.path.expanduser("~") # don't fall back to pwd.getpwuid() — which fails when the host UID has no /etc/passwd entry. environment = { **self.default_job_env, "USER": os.environ.get("USER", "nvflare"), "HOME": os.environ.get("HOME", "/tmp"), } workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) if workspace_obj is not None: python_paths = [] app_custom_folder = workspace_obj.get_app_custom_dir(job_id) if app_custom_folder: python_paths.append(app_custom_folder.replace(workspace, self.WORKSPACE_MOUNT, 1)) site_custom_folder = workspace_obj.get_site_custom_dir() if site_custom_folder and os.path.isdir(site_custom_folder): python_paths.append(site_custom_folder.replace(workspace, self.WORKSPACE_MOUNT, 1)) if python_paths: environment["PYTHONPATH"] = os.pathsep.join(python_paths) # Docker launcher spec: per-job Docker settings (image, shm_size, ipc_mode, ...) live in # launcher_spec[site][docker]. Falls back to nested resource_spec[site][docker] for # backward compatibility. num_of_gpus falls back to flat resource_spec[site] (Option 4). # Site-level defaults (default_job_container_kwargs) are merged in; job-level takes precedence on conflict. _site_rs = (job_meta.get(JobMetaKey.RESOURCE_SPEC.value) or {}).get(site_name) or {} _flat_gpus = 0 if any(k in _site_rs for k in ("process", "docker", "k8s")) else _site_rs.get("num_of_gpus", 0) num_gpus = docker_spec["num_of_gpus"] if "num_of_gpus" in docker_spec else _flat_gpus _RESERVED_KWARGS = { "volumes", "mounts", "network", "environment", "command", "name", "detach", "user", "working_dir", } _NON_CONTAINER_KEYS = {"num_of_gpus", "image", "python_path"} | _RESERVED_KWARGS reserved_in_spec = _RESERVED_KWARGS & set(docker_spec.keys()) if reserved_in_spec: self.logger.warning( f"job {job_id}: launcher_spec['{site_name}']['docker'] contains reserved keys " f"{sorted(reserved_in_spec)} — ignored (controlled by the launcher)" ) job_container_kwargs = {k: v for k, v in docker_spec.items() if k not in _NON_CONTAINER_KEYS} merged_container_kwargs = {**self.default_job_container_kwargs, **job_container_kwargs} # GPU precedence: # 1. explicit job-level device_requests in docker_spec # 2. job-level num_of_gpus translated to device_requests # 3. site-level default device_requests from default_job_container_kwargs # # This preserves the documented rule that job-level resource_spec takes precedence # over site-level defaults, while still allowing fine-grained device_requests overrides. if num_gpus and "device_requests" not in job_container_kwargs: merged_container_kwargs["device_requests"] = [{"Count": num_gpus, "Capabilities": [["gpu"]]}] # Give the job an isolated workspace view. The root tmpfs must be writable by the non-root # container user because server job startup may create ephemeral storage dirs such as # snapshot-storage and jobs-storage. # startup/local are read-only, and only this job's workspace is read-write and persistent on the host. job_workspace_name = WorkspaceConstants.WORKSPACE_PREFIX + str(job_id) host_job_workspace = _safe_workspace_child_path(workspace, job_workspace_name) host_startup_dir = _safe_workspace_child_path( workspace, WorkspaceConstants.STARTUP_FOLDER_NAME, allow_reserved=True ) host_local_dir = _safe_workspace_child_path(workspace, WorkspaceConstants.SITE_FOLDER_NAME, allow_reserved=True) container_job_workspace = posixpath.join(self.WORKSPACE_MOUNT, job_workspace_name) container_startup_dir = posixpath.join(self.WORKSPACE_MOUNT, WorkspaceConstants.STARTUP_FOLDER_NAME) container_local_dir = posixpath.join(self.WORKSPACE_MOUNT, WorkspaceConstants.SITE_FOLDER_NAME) data_mounts = [] # Read study data map from workspace/local/study_data.yaml. # Must use WORKSPACE_MOUNT (container-internal path) for the file read because launch_job # runs inside the SP/CP container. The host path (workspace) does not exist in the container # filesystem. The Docker volume source must remain the host path for the daemon API. # Maps study -> dataset -> {source, mode}; source is a host path for Docker. # Each dataset is mounted at /data/<study>/<dataset>. study_data_file = os.path.join(self.WORKSPACE_MOUNT, self.STUDY_DATA_PATH_FILE) study = job_meta.get(JobMetaKey.STUDY.value) if should_mount_study_data(study): study_data_map = load_study_data_file(study_data_file, logger=self.logger) data_mounts = resolve_study_dataset_mounts(study_data_map, study, study_data_file, logger=self.logger) for dataset_mount in data_mounts: self.logger.info( "mounting study '%s' dataset '%s' from %s -> %s", study, dataset_mount.dataset, dataset_mount.source, dataset_mount.mount_path, ) self.logger.info(f"launching job {job_id} as container {container_name} using image {job_image}") docker_client = self._get_docker_client() try: mounts = [ docker.types.Mount( target=self.WORKSPACE_MOUNT, source=None, type="tmpfs", read_only=False, tmpfs_mode=_WORKSPACE_TMPFS_MODE, ), docker.types.Mount( target=container_startup_dir, source=host_startup_dir, type="bind", read_only=True, ), docker.types.Mount(target=container_local_dir, source=host_local_dir, type="bind", read_only=True), docker.types.Mount( target=container_job_workspace, source=host_job_workspace, type="bind", read_only=False, ), ] for dataset_mount in data_mounts: mounts.append( docker.types.Mount( target=dataset_mount.mount_path, source=dataset_mount.source, type="bind", read_only=dataset_mount.read_only, ) ) container = docker_client.containers.run( job_image, command=command, name=container_name, network=self.network, detach=True, environment=environment if environment else None, mounts=mounts, working_dir=container_job_workspace, # Run as the same user as SP/CP so job-written files are accessible to SP/CP # (e.g. cross_val_results.json written by SJ must be readable/deletable by SP). # Never pass Docker socket to job containers. user=f"{os.getuid()}:{os.getgid()}", **merged_container_kwargs, ) except docker.errors.ImageNotFound: raise RuntimeError(f"image '{job_image}' not found for job {job_id}") except docker.errors.APIError as e: raise RuntimeError(f"error creating container for job {job_id}: {e}") job_handle = DockerJobHandle( container_id=container.id, container_name=container_name, docker_client=docker_client, timeout=self.timeout, ) try: if not job_handle.enter_states([DockerStatus.RUNNING]): self.logger.warning(f"container {container_name} did not reach RUNNING state for job {job_id}") except BaseException: job_handle.terminate() raise # Always return a handle — caller detects failure via poll() return job_handle
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.BEFORE_JOB_LAUNCH: add_launcher(self, fl_ctx)
[docs] @abstractmethod def get_module_args(self, job_args: dict) -> dict: """Return a {flag: value} dict of args to pass to the job module. Args: job_args: JOB_PROCESS_ARGS dict from FLContext (with PARENT_URL already overridden). Returns: dict of {flag: value} pairs to append after '-u -m <module>' in the container command. """ pass
[docs] class ClientDockerJobLauncher(DockerJobLauncher):
[docs] def get_module_args(self, job_args: dict) -> dict: return _job_args_dict(job_args, get_client_job_args(include_exe_module=False, include_set_options=False))
[docs] class ServerDockerJobLauncher(DockerJobLauncher):
[docs] def get_module_args(self, job_args: dict) -> dict: return _job_args_dict(job_args, get_server_job_args(include_exe_module=False, include_set_options=False))