# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import copy
import hashlib
import logging
import os
import re
import time
from abc import abstractmethod
from datetime import datetime
from enum import Enum
import yaml
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, JobConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import JobMetaKey
from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher
from nvflare.app_opt.job_launcher.study_data import (
load_study_data_file,
resolve_study_dataset_mounts,
should_mount_study_data,
)
from nvflare.app_opt.job_launcher.workspace_cell_transfer import (
ENV_WORKSPACE_OWNER_FQCN,
ENV_WORKSPACE_TRANSFER_TOKEN,
WorkspaceTransferManager,
)
from nvflare.utils.job_launcher_utils import get_client_job_args, get_job_launcher_spec, get_server_job_args
[docs]
class JobState(Enum):
    """Job lifecycle state as tracked by the launcher.

    Kubernetes pod phases are translated into these states by ``POD_STATE_MAPPING``,
    and these states are translated into ``JobReturnCode`` values by
    ``JOB_RETURN_CODE_MAPPING``.
    """

    STARTING = "starting"
    RUNNING = "running"
    TERMINATED = "terminated"
    SUCCEEDED = "succeeded"
    UNKNOWN = "unknown"
[docs]
class PodPhase(Enum):
    """Kubernetes pod phase strings (the values of ``pod.status.phase``)."""

    PENDING = "Pending"
    RUNNING = "Running"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    UNKNOWN = "Unknown"
[docs]
class PendingPodAction(Enum):
    """Decision taken for a pod that is still starting.

    WAIT: keep polling; the pending-resources timer is not running.
    WAIT_FOR_RESOURCES: the pod is blocked on CPU/memory/GPU; keep polling, but
        count the time against the handle's ``pending_timeout``.
    FAIL: the pod cannot start; terminate it.
    """

    WAIT = "wait"
    WAIT_FOR_RESOURCES = "wait_for_resources"
    FAIL = "fail"
# Pod phase string -> JobState. Phases missing from this map are treated as
# JobState.UNKNOWN by the callers (they use .get(..., JobState.UNKNOWN)).
POD_STATE_MAPPING = {
    PodPhase.PENDING.value: JobState.STARTING,
    PodPhase.RUNNING.value: JobState.RUNNING,
    PodPhase.SUCCEEDED.value: JobState.SUCCEEDED,
    PodPhase.FAILED.value: JobState.TERMINATED,
    PodPhase.UNKNOWN.value: JobState.UNKNOWN,
}
# JobState -> return code reported by the handle's poll(). Non-terminal states
# report UNKNOWN; a failed pod (TERMINATED) reports ABORTED.
JOB_RETURN_CODE_MAPPING = {
    JobState.SUCCEEDED: JobReturnCode.SUCCESS,
    JobState.STARTING: JobReturnCode.UNKNOWN,
    JobState.RUNNING: JobReturnCode.UNKNOWN,
    JobState.TERMINATED: JobReturnCode.ABORTED,
    JobState.UNKNOWN: JobReturnCode.UNKNOWN,
}
# Ordered flag -> value map used when job_config has no "module_args". Entries
# whose value is None are skipped when the container args are assembled, so with
# these defaults no module flags are passed. Insertion order is the order in
# which flags appear on the command line.
DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT = {
    "-m": None,
    "-w": None,
    "-t": None,
    "-d": None,
    "-n": None,
    "-c": None,
    "-p": None,
    "-g": None,
    "-scheme": None,
    "-s": None,
}
DEFAULT_NAMESPACE = "default"
# Seconds a pod may stay Pending while waiting for CPU/memory/GPU resources.
DEFAULT_PENDING_TIMEOUT = 120
DEFAULT_PYTHON_PATH = "/usr/local/bin/python"
# Seconds to sleep between pod status queries.
POLL_INTERVAL = 1
# Only warning events younger than this many seconds fail an already-scheduled pod.
SCHEDULED_EVENT_FAILURE_MAX_AGE = 60
WORKSPACE_MOUNT_PATH = "/var/tmp/nvflare/workspace"
# Kubernetes resource quantity used as the launcher's default ephemeral-storage value.
DEFAULT_EPHEMERAL_STORAGE = "1Gi"
# Container "waiting" reasons that make a scheduled-but-Pending pod fail
# immediately instead of waiting for it to start.
_PENDING_FAILURE_WAITING_REASONS = {
    "CreateContainerConfigError",
    "CreateContainerError",
    "ErrImagePull",
    "ErrImageNeverPull",
    "ImagePullBackOff",
    "InvalidImageName",
    "RunContainerError",
    "CrashLoopBackOff",
}
# Reasons of "Warning" pod events that are treated as startup failures.
_PENDING_FAILURE_EVENT_REASONS = {
    "BackOff",
    "Failed",
    "FailedAttachVolume",
    "FailedCreatePodSandBox",
    "FailedMount",
    "FailedScheduling",
    "FailedSync",
    "InspectFailed",
    "InvalidImageName",
    "NetworkNotReady",
}
# Files the job pod actually reads from startup/ at runtime. All other files in
# startup/ are dropped to keep the Secret small. local/ is bundled whole with
# each job workspace, so job resource files and local custom code keep working.
_STARTUP_KEEP_SUFFIXES = (".crt", ".key", ".pem", ".json")
def _keep_startup_file(fname: str) -> bool:
    """Return True when ``fname`` ends with one of ``_STARTUP_KEEP_SUFFIXES``."""
    return any(fname.endswith(suffix) for suffix in _STARTUP_KEEP_SUFFIXES)
def _normalize_image_pull_secrets(image_pull_secrets) -> list[str]:
    """Validate a list of image pull Secret names and return a copy of it.

    Args:
        image_pull_secrets: ``None`` (meaning no pull secrets) or a list of
            non-blank strings naming Kubernetes Secrets.

    Returns:
        A new list with the same entries, or ``[]`` for ``None``.

    Raises:
        ValueError: if the value is not a list, or an entry is not a non-blank string.
    """
    if image_pull_secrets is None:
        return []
    if not isinstance(image_pull_secrets, list):
        raise ValueError("image_pull_secrets must be a list of Kubernetes Secret names")
    has_bad_entry = any(not (isinstance(entry, str) and entry.strip()) for entry in image_pull_secrets)
    if has_bad_entry:
        raise ValueError("image_pull_secrets entries must be non-empty strings")
    return [*image_pull_secrets]
def _normalize_pending_timeout(pending_timeout, field_name="pending_timeout"):
    """Validate a timeout expressed in seconds.

    Args:
        pending_timeout: ``None`` (no timeout) or a non-negative ``int``/``float``.
            ``bool`` is rejected even though it is an ``int`` subclass.
        field_name: name used in the error message.

    Returns:
        The value unchanged (``None`` stays ``None``).

    Raises:
        ValueError: for non-numeric values, negative values, and NaN.
    """
    if pending_timeout is None:
        return None
    error = f"{field_name} must be a non-negative number of seconds or None"
    if isinstance(pending_timeout, bool) or not isinstance(pending_timeout, (int, float)):
        raise ValueError(error)
    # "not x >= 0" rejects negatives AND NaN. A NaN would pass "x < 0", and since
    # every comparison with NaN is False it would silently disable the timeout.
    if not pending_timeout >= 0:
        raise ValueError(error)
    return pending_timeout
def _obj_text(*values) -> str:
    """Join the truthy ``values`` (converted with ``str``) using single spaces."""
    parts = [str(value) for value in values if value]
    return " ".join(parts)
def _is_cpu_memory_gpu_shortage(message: str) -> bool:
    """Return True if ``message`` contains "Insufficient <resource>" for CPU, memory or a GPU.

    Matching is case-insensitive. Trailing punctuation after the resource name is
    ignored. Any resource whose name contains "gpu" (e.g. ``nvidia.com/gpu``) counts.
    """
    if not message:
        return False
    resource_names = (
        raw.lower().rstrip(".,;:")
        for raw in re.findall(r"\binsufficient\s+([a-z0-9./_-]+)", message, flags=re.IGNORECASE)
    )
    return any(name in {"cpu", "memory"} or "gpu" in name for name in resource_names)
def _timestamp_to_seconds(value):
    """Convert a timestamp to POSIX seconds, or return None if it cannot be converted.

    Accepts ``datetime`` objects, numbers (returned unchanged), and ISO-8601
    strings, where a "Z" suffix is read as UTC. Anything else, including an
    unparsable string, yields None.
    """
    if isinstance(value, datetime):
        return value.timestamp()
    if isinstance(value, (int, float)):
        return value
    if not isinstance(value, str):
        return None
    iso_text = value.strip().replace("Z", "+00:00")
    try:
        return datetime.fromisoformat(iso_text).timestamp()
    except ValueError:
        return None
def _event_timestamp(event):
    """Return the first usable timestamp of a Kubernetes event, in seconds, or None.

    Candidates are tried in this order: ``event_time``, ``series.last_observed_time``,
    ``last_timestamp``, ``first_timestamp``, ``metadata.creation_timestamp``.
    Missing attributes are tolerated.
    """
    series = getattr(event, "series", None)
    metadata = getattr(event, "metadata", None)
    candidates = (
        getattr(event, "event_time", None),
        getattr(series, "last_observed_time", None),
        getattr(event, "last_timestamp", None),
        getattr(event, "first_timestamp", None),
        getattr(metadata, "creation_timestamp", None),
    )
    converted = map(_timestamp_to_seconds, candidates)
    return next((seconds for seconds in converted if seconds is not None), None)
def _event_sort_key(event):
    """Sort key for events: their timestamp in seconds, or 0 when none is usable."""
    event_time = _event_timestamp(event)
    if event_time is None:
        return 0
    return event_time
def _is_recent_event(event, now, max_age) -> bool:
    """Return True if ``event`` happened at most ``max_age`` seconds before ``now``.

    Events without a usable timestamp are never considered recent.
    """
    event_time = _event_timestamp(event)
    return event_time is not None and now - event_time <= max_age
[docs]
def uuid4_to_rfc1123(uuid_str: str) -> str:
    """Turn a (UUID-style) job id into an RFC 1123 label usable as a Kubernetes name.

    The id is lower-cased. Characters other than ``[a-z0-9-]`` are removed, as
    are leading hyphens, because a label must begin with an alphanumeric
    character. A leading digit gets a "j" prefix. The result is truncated to 63
    characters with trailing hyphens stripped.

    Returns:
        The label, which may be empty if ``uuid_str`` has no usable characters.
    """
    name = re.sub(r"[^a-z0-9-]", "", uuid_str.lower()).lstrip("-")
    # Prefix with a letter if it starts with a digit
    if name and name[0].isdigit():
        name = "j" + name
    # Kubernetes label limit: 63 chars; strip trailing hyphens after truncation
    # (truncation can expose a hyphen that was interior before slicing)
    return name[:63].rstrip("-")
[docs]
def site_name_to_rfc1123(site_name: str, max_length: int = 47) -> str:
    """Convert a site name into a stable RFC1123-safe label with a hash suffix.

    The label is ``<head>-<8 hex chars of sha256(site_name)>``. ``head`` is the
    lower-cased site name restricted to ``[a-z0-9-]``, with outer hyphens
    stripped. It falls back to "site" when nothing is left, and is prefixed
    with "s" when it starts with a digit. ``head`` is then truncated so that
    the whole label fits in ``max_length`` characters. The hash keeps distinct
    site names distinct even when their sanitized heads collide.
    """
    suffix = hashlib.sha256(site_name.encode("utf-8")).hexdigest()[:8]
    label = re.sub(r"[^a-z0-9-]", "", site_name.lower()).strip("-") or "site"
    if label[0].isdigit():
        label = f"s{label}"
    # Leave room for the "-" separator plus the hash suffix.
    head_budget = max_length - len(suffix) - 1
    head = label[:head_budget].rstrip("-") or "site"
    return f"{head}-{suffix}"
[docs]
def job_pod_name(job_id: str, site_name: str) -> str:
    """Build a site-scoped Kubernetes pod name for a FL job.

    The name is ``<job_id prefix>-<site label>``, where the site label comes from
    ``site_name_to_rfc1123(site_name, max_length=20)``. ``job_id`` is truncated
    (and trailing hyphens stripped) so the result is at most 63 characters.
    ``job_id`` itself is not sanitized here.
    """
    site_part = site_name_to_rfc1123(site_name, max_length=20)
    budget = 63 - len(site_part) - 1
    return "-".join((job_id[:budget].rstrip("-"), site_part))
[docs]
def study_dataset_volume_name(study: str, dataset: str) -> str:
    """Return the RFC1123-safe pod volume name (at most 63 chars) for a study dataset."""
    raw_name = f"data-{study}-{dataset}"
    return site_name_to_rfc1123(raw_name, max_length=63)
def _load_yaml_file(file_path: str, label: str):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed object.

    Args:
        file_path: path of the YAML file.
        label: human-readable kind of file, used in error messages.

    Returns:
        The parsed YAML document (``None`` for an empty file).

    Raises:
        ValueError: if the file is missing, cannot be read, or cannot be
            decoded/parsed. The original exception is chained as the cause.
    """
    try:
        # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16
        # with a BOM) instead of relying on the platform locale encoding.
        # Undecodable bytes then surface as a yaml.YAMLError (reader error)
        # rather than a bare UnicodeDecodeError without file context.
        with open(file_path, "rb") as f:
            return yaml.safe_load(f)
    except FileNotFoundError as e:
        raise ValueError(f"{label} file '{file_path}' was not found") from e
    except OSError as e:
        raise ValueError(f"Could not read {label} file '{file_path}': {e}") from e
    except yaml.YAMLError as e:
        raise ValueError(f"Could not parse {label} file '{file_path}': {e}") from e
[docs]
def load_study_job_spec_file(file_path: str, logger: logging.Logger = None) -> dict:
    """Load the study -> pod spec file mapping from a YAML file.

    An empty file yields ``{}``. If a ``logger`` is given, a warning is logged
    that built-in pod manifests will be used.

    Args:
        file_path: path of the study job spec YAML file.
        logger: optional logger for the empty-file warning.

    Returns:
        The mapping of study name to pod spec YAML path (relative or absolute).

    Raises:
        ValueError: if the file cannot be loaded, is not a mapping, or contains a
            study name or pod spec path that is not a non-empty string.
    """
    spec = _load_yaml_file(file_path, "study job spec")
    if spec is None:
        spec = {}
    elif not isinstance(spec, dict):
        raise ValueError(f"file at study_job_spec_file_path '{file_path}' does not contain a dictionary.")

    if not spec:
        if logger:
            logger.warning(
                "study job spec file '%s' has no study entries; built-in pod manifests will be used", file_path
            )
        return spec

    for study_name, pod_spec_path in spec.items():
        if not (isinstance(study_name, str) and study_name):
            raise ValueError(f"study name {study_name!r} in '{file_path}' must be a non-empty string.")
        if not (isinstance(pod_spec_path, str) and pod_spec_path):
            raise ValueError(
                f"study job spec entry for study '{study_name}' in '{file_path}' must be a non-empty pod YAML file path."
            )
    return spec
[docs]
def resolve_study_job_spec_path(
    study_job_spec: dict, study: str, file_path: str, logger: logging.Logger = None
) -> str | None:
    """Return the pod spec YAML path configured for ``study``, or None.

    Relative entries are resolved against the directory containing the study job
    spec file (``file_path``); absolute entries are returned as-is. None is
    returned when ``study`` or ``study_job_spec`` is empty, or when the study has
    no entry. In the last case a warning is logged if a ``logger`` is given.
    """
    if not (study and study_job_spec):
        return None
    pod_spec_file = study_job_spec.get(study)
    if pod_spec_file is not None:
        if os.path.isabs(pod_spec_file):
            return pod_spec_file
        spec_dir = os.path.dirname(file_path)
        return os.path.join(spec_dir, pod_spec_file)
    if logger:
        logger.warning(
            "study job spec file '%s' has no entry for study '%s'; built-in pod manifest will be used",
            file_path,
            study,
        )
    return None
[docs]
def load_pod_spec_file(file_path: str) -> dict:
    """Load a Kubernetes Pod manifest from a YAML file.

    A missing or empty ``kind`` is accepted; any other kind than "Pod" is rejected.

    Raises:
        ValueError: if the file cannot be loaded, is not a mapping, or defines a
            kind other than Pod.
    """
    pod_spec = _load_yaml_file(file_path, "pod spec")
    if not isinstance(pod_spec, dict):
        raise ValueError(f"pod spec file '{file_path}' must contain a Kubernetes Pod dictionary.")
    # A falsy kind is treated as "Pod".
    if (pod_spec.get("kind") or "Pod") != "Pod":
        raise ValueError(f"pod spec file '{file_path}' must define kind: Pod.")
    return pod_spec
def _ensure_manifest_mapping(parent: dict, key: str, label: str) -> dict:
    """Return ``parent[key]`` as a dict, creating an empty one if it is missing or None.

    Raises:
        ValueError: if ``parent[key]`` exists but is not a dict.
    """
    value = parent.get(key)
    if value is None:
        value = parent[key] = {}
    elif not isinstance(value, dict):
        raise ValueError(f"{label} must be a dictionary.")
    return value
def _ensure_manifest_containers(spec: dict) -> list[dict]:
    """Return ``spec["containers"]``, guaranteeing a non-empty list of dicts.

    A missing/None value is replaced with ``[{}]``; an empty list gets one empty
    container appended in place.

    Raises:
        ValueError: if the value is not a list or contains a non-dict entry.
    """
    containers = spec.get("containers")
    if containers is None:
        containers = spec["containers"] = [{}]
    elif not isinstance(containers, list):
        raise ValueError("pod spec containers must be a list.")
    elif not containers:
        containers.append({})
    if any(not isinstance(entry, dict) for entry in containers):
        raise ValueError("pod spec containers entries must be dictionaries.")
    return containers
def _prepare_pod_manifest_template(pod_manifest_template: dict) -> dict:
    """Return a normalized deep copy of a pod manifest template.

    The copy has ``apiVersion`` (default "v1"), ``kind: Pod``, a ``metadata`` dict,
    and a ``spec`` dict with a non-empty ``containers`` list. The input is never
    modified.

    Raises:
        ValueError: if the template is not a dict, defines a kind other than Pod,
            or has malformed metadata/spec/containers.
    """
    if not isinstance(pod_manifest_template, dict):
        raise ValueError("pod manifest template must be a dictionary.")
    # A falsy kind is treated as "Pod".
    if (pod_manifest_template.get("kind") or "Pod") != "Pod":
        raise ValueError("pod manifest template must define kind: Pod.")
    manifest = copy.deepcopy(pod_manifest_template)
    manifest.setdefault("apiVersion", "v1")
    manifest["kind"] = "Pod"
    _ensure_manifest_mapping(manifest, "metadata", "pod manifest metadata")
    _ensure_manifest_containers(_ensure_manifest_mapping(manifest, "spec", "pod manifest spec"))
    return manifest
def _merge_named_items(template_items, job_items, label: str) -> list:
    """Merge job-specific list entries into template entries, matching by ``name``.

    The result (all entries deep-copied) is:

    1. the template entries in order, each replaced by the job entry with the
       same ``name`` if there is one;
    2. named job entries that matched no template entry, in job order;
    3. job entries without a non-empty string ``name``, in job order.

    ``None`` for either list means "empty".

    Raises:
        ValueError: if either argument is not a list, or an entry is not a dict.
    """
    template_items = [] if template_items is None else template_items
    job_items = [] if job_items is None else job_items
    if not (isinstance(template_items, list) and isinstance(job_items, list)):
        raise ValueError(f"{label} must be a list.")

    named_overrides = {}
    anonymous = []
    for entry in job_items:
        if not isinstance(entry, dict):
            raise ValueError(f"{label} entries must be dictionaries.")
        entry_name = entry.get("name")
        if isinstance(entry_name, str) and entry_name:
            named_overrides[entry_name] = entry
        else:
            anonymous.append(entry)

    merged = []
    consumed = set()
    for entry in template_items:
        if not isinstance(entry, dict):
            raise ValueError(f"{label} entries must be dictionaries.")
        entry_name = entry.get("name")
        if entry_name in named_overrides:
            consumed.add(entry_name)
            entry = named_overrides[entry_name]
        merged.append(copy.deepcopy(entry))

    merged.extend(copy.deepcopy(item) for name, item in named_overrides.items() if name not in consumed)
    merged.extend(copy.deepcopy(anonymous))
    return merged
def _select_job_container(containers: list[dict], container_name: str) -> dict:
    """Pick the container that should run the FL job.

    Returns the first container whose name is ``container_name`` (when it is a
    non-empty string) or "nvflare_job"; otherwise the first container.
    """
    wanted_names = {n for n in (container_name, "nvflare_job") if isinstance(n, str) and n}
    fallback = containers[0]
    return next((c for c in containers if c.get("name") in wanted_names), fallback)
[docs]
class K8sJobHandle(JobHandleSpec):
    """Handle for one FL job that runs as a single Kubernetes pod.

    The handle builds the pod manifest, either from scratch or on top of a
    per-study pod manifest template; ``get_manifest()`` returns a copy of it.
    The handle does not create the pod itself. It tracks the pod by name through
    ``api_instance``: reading the pod, listing the pod's events, and deleting
    the pod on termination.

    While the pod is Pending it is classified on every poll as still starting,
    waiting for CPU/memory/GPU resources, or failed to start. A startup failure
    terminates the pod immediately. Waiting for resources longer than
    ``pending_timeout`` seconds terminates it too. In both cases the reported
    return code becomes ``JobReturnCode.EXCEPTION``.
    """

    def __init__(
        self,
        job_id: str,
        api_instance,
        job_config: dict,
        namespace=DEFAULT_NAMESPACE,
        timeout=None,
        pending_timeout=DEFAULT_PENDING_TIMEOUT,
        python_path=DEFAULT_PYTHON_PATH,
        workspace_transfer: WorkspaceTransferManager = None,
        workspace_job_id: str = "",
        pod_name: str = None,
        pod_manifest_template: dict = None,
    ):
        """Build the pod manifest for a job.

        Args:
            job_id: FL job id; also the pod name unless ``pod_name`` is given.
            api_instance: Kubernetes client exposing ``read_namespaced_pod``,
                ``delete_namespaced_pod`` and ``list_namespaced_event``.
            job_config: job settings. Keys read: ``name``, ``image`` and
                ``command`` (both required, non-empty), ``container_name``,
                ``module_args``, ``set_list``, ``volume_list``,
                ``volume_mount_list``, ``image_pull_secrets``,
                ``security_context``, ``resources``, ``env``.
            namespace: namespace of the pod.
            timeout: seconds ``enter_states`` waits before terminating the pod;
                None waits forever.
            pending_timeout: seconds a pod may wait for CPU/memory/GPU resources;
                None disables the check, 0 fails on the first resource shortage.
            python_path: interpreter used as the container command.
            workspace_transfer: optional manager from which the job workspace is
                removed once the job ends.
            workspace_job_id: id of this job in ``workspace_transfer``.
            pod_name: pod name, defaulting to ``job_id``.
            pod_manifest_template: optional Pod manifest to build upon. It is
                deep-copied and never modified.

        Raises:
            ValueError: on invalid ``pending_timeout``, template, or job_config.
        """
        super().__init__()
        # Created first so every later step can log.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.job_id = job_id
        self.pod_name = pod_name if pod_name is not None else job_id
        self.timeout = timeout
        self.terminal_state = None
        self.terminal_return_code = None
        self.workspace_transfer = workspace_transfer
        self.workspace_job_id = workspace_job_id
        self.api_instance = api_instance
        self.namespace = namespace
        self.pending_timeout = _normalize_pending_timeout(pending_timeout)
        self.python_path = python_path
        self.uses_pod_manifest_template = pod_manifest_template is not None
        if self.uses_pod_manifest_template:
            self.pod_manifest = _prepare_pod_manifest_template(pod_manifest_template)
        else:
            self.pod_manifest = {
                "apiVersion": "v1",
                "kind": "Pod",
                "metadata": {"name": None},  # set by job_config['name']
                "spec": {
                    "containers": None,  # link to container_list
                    "volumes": None,  # link to volume_list
                    "restartPolicy": "Never",
                },
            }
        self.volume_list = []
        if self.uses_pod_manifest_template:
            spec = self.pod_manifest["spec"]
            self.container_list = spec["containers"]
            self.job_container = _select_job_container(self.container_list, job_config.get("container_name"))
        else:
            self.container_list = [
                {
                    "image": None,
                    "name": None,
                    "command": [python_path],
                    "args": None,  # args_list + args_dict + args_sets
                    "volumeMounts": None,  # volume_mount_list
                    "imagePullPolicy": "Always",
                }
            ]
            self.job_container = self.container_list[0]
        command = job_config.get("command")
        if not command:
            raise ValueError("job_config must contain a non-empty 'command' key")
        self.container_args_python_args_list = ["-u", "-m", command]
        self.container_volume_mount_list = []
        self._make_manifest(job_config)
        self._stuck_count = 0
        self._pending_since = None
        # Kept for diagnostics only; unit is seconds, not poll iterations like _stuck_count.
        self._pending_timeout_secs = self.pending_timeout
        self._last_event_query_failed = False
        self._pending_timer_paused_at = None

    def _make_manifest(self, job_config):
        """Fill ``self.pod_manifest`` and the job container from ``job_config``.

        In template mode, volumes, imagePullSecrets, volumeMounts and env are
        merged by name into the template's entries. Otherwise they are set
        directly.

        Raises:
            ValueError: if ``image`` is missing/empty, or the template or
                config lists are malformed.
        """
        # A missing or None list is treated as empty.
        self.container_volume_mount_list.extend(job_config.get("volume_mount_list") or [])
        set_list = job_config.get("set_list")
        if not set_list:
            self.container_args_module_args_sets = list()
        else:
            self.container_args_module_args_sets = ["--set"] + set_list
        if job_config.get("module_args") is None:
            self.container_args_module_args_dict = DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT.copy()
        else:
            self.container_args_module_args_dict = job_config.get("module_args")
        # Flags whose value is None are left off the command line.
        self.container_args_module_args_dict_as_list = list()
        for k, v in self.container_args_module_args_dict.items():
            if v is None:
                continue
            self.container_args_module_args_dict_as_list.append(k)
            self.container_args_module_args_dict_as_list.append(str(v))
        job_volume_list = job_config.get("volume_list") or []
        metadata = _ensure_manifest_mapping(self.pod_manifest, "metadata", "pod manifest metadata")
        spec = _ensure_manifest_mapping(self.pod_manifest, "spec", "pod manifest spec")
        metadata["name"] = job_config.get("name")
        spec["restartPolicy"] = "Never"
        if self.uses_pod_manifest_template:
            spec["volumes"] = _merge_named_items(spec.get("volumes"), job_volume_list, "pod spec volumes")
        else:
            self.volume_list.extend(job_volume_list)
            spec["containers"] = self.container_list
            spec["volumes"] = self.volume_list
        image_pull_secrets = _normalize_image_pull_secrets(job_config.get("image_pull_secrets"))
        if image_pull_secrets:
            image_pull_secret_refs = [{"name": name} for name in image_pull_secrets]
            if self.uses_pod_manifest_template:
                spec["imagePullSecrets"] = _merge_named_items(
                    spec.get("imagePullSecrets"), image_pull_secret_refs, "pod spec imagePullSecrets"
                )
            else:
                spec["imagePullSecrets"] = image_pull_secret_refs
        security_context = job_config.get("security_context")
        if security_context:
            spec["securityContext"] = security_context
        image = job_config.get("image")
        if not image:
            raise ValueError("job_config must contain a non-empty 'image' key")
        container = self.job_container
        container["image"] = image
        # None/empty container_name falls back to the default name, the same
        # rule _select_job_container uses.
        container["name"] = job_config.get("container_name") or "nvflare_job"
        container["command"] = [self.python_path]
        container["args"] = (
            self.container_args_python_args_list
            + self.container_args_module_args_dict_as_list
            + self.container_args_module_args_sets
        )
        if self.uses_pod_manifest_template:
            container.setdefault("imagePullPolicy", "Always")
            container["volumeMounts"] = _merge_named_items(
                container.get("volumeMounts"), self.container_volume_mount_list, "container volumeMounts"
            )
        else:
            container["volumeMounts"] = self.container_volume_mount_list
        # resources now always includes ephemeral-storage; GPU limits are merged
        # into the same dict only when requested for the job.
        if job_config.get("resources"):
            container["resources"] = job_config["resources"]
        env = job_config.get("env") or {}
        # Unset (None) and empty values are dropped instead of becoming the
        # literal strings "None" / "".
        env_vars = {k: v for k, v in env.items() if v is not None and str(v)}
        if env_vars:
            env_items = [{"name": k, "value": str(v)} for k, v in env_vars.items()]
            if self.uses_pod_manifest_template:
                container["env"] = _merge_named_items(container.get("env"), env_items, "container env")
            else:
                container["env"] = env_items

    def get_manifest(self):
        """Return a deep copy of the pod manifest."""
        return copy.deepcopy(self.pod_manifest)

    def enter_states(self, job_states_to_enter: list):
        """Poll until the pod's JobState is one of ``job_states_to_enter``.

        Args:
            job_states_to_enter: a JobState or a list/tuple of JobStates.

        Returns:
            True if a requested state was reached. False if the job ended first,
            the pod was terminated for a startup failure or resource timeout, or
            ``self.timeout`` elapsed (the pod is then terminated).

        Raises:
            ValueError: if any requested state is not a JobState.
        """
        starting_time = time.time()
        if not isinstance(job_states_to_enter, (list, tuple)):
            job_states_to_enter = [job_states_to_enter]
        if not all(isinstance(js, JobState) for js in job_states_to_enter):
            raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}")
        while True:
            if self.terminal_state is not None:
                return False
            pod = self._query_pod()
            if self.terminal_state is not None:
                return False
            pod_phase = self._get_pod_phase(pod)
            now = time.time()
            if self._handle_starting_pod(pod, pod_phase, now=now):
                return False
            job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
            if job_state in job_states_to_enter:
                return True
            elif pod_phase in [PodPhase.FAILED.value, PodPhase.SUCCEEDED.value]:  # terminal state
                self.terminal_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
                self._remove_workspace_job()
                return False
            elif self.timeout is not None and now - starting_time >= self.timeout:
                self._terminate_for_timeout(f"timed out waiting for pod to enter {job_states_to_enter}")
                return False
            time.sleep(POLL_INTERVAL)

    def _remove_workspace_job(self) -> None:
        """Remove this job from the workspace transfer manager, at most once."""
        if self.workspace_transfer and self.workspace_job_id:
            self.workspace_transfer.remove_job(self.workspace_job_id)
            self.workspace_job_id = ""

    def terminate(self):
        """Delete the pod immediately (grace period 0) and record TERMINATED.

        API errors are logged, not raised. A 404 means the pod is already gone.
        In that case TERMINATED is recorded only if no terminal state exists yet,
        so an already observed outcome (e.g. SUCCEEDED) is kept. The workspace
        job is always removed.
        """
        from kubernetes.client.rest import ApiException

        try:
            self.api_instance.delete_namespaced_pod(
                name=self.pod_name, namespace=self.namespace, grace_period_seconds=0
            )
            self.terminal_state = JobState.TERMINATED
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                # Expected when terminate() runs as an idempotent cleanup after the
                # pod already exited gracefully (e.g. server abort path where the SJ
                # left on its own before the safety-net terminate fires). Not an
                # event of interest for operators monitoring logs.
                self.logger.debug(
                    f"job {self.job_id} pod {self.pod_name} not found during termination; assuming terminated"
                )
                # Record what the log line says, matching _query_pod's 404 handling.
                if self.terminal_state is None:
                    self.terminal_state = JobState.TERMINATED
            else:
                self.logger.error(f"failed to terminate job {self.job_id} pod {self.pod_name}: {e}")
                self.terminal_state = JobState.TERMINATED
        except Exception as e:
            self.logger.error(f"unexpected error terminating job {self.job_id} pod {self.pod_name}: {e}")
            self.terminal_state = JobState.TERMINATED
        self._remove_workspace_job()
        return None

    def _terminate_for_timeout(self, reason: str):
        """Terminate the pod because a timeout expired (reported as EXCEPTION)."""
        self._terminate_for_exception(reason)

    def _terminate_for_exception(self, reason: str):
        """Log ``reason``, terminate the pod, and force the EXCEPTION return code."""
        self.logger.warning(f"job {self.job_id} pod {self.pod_name}: {reason}")
        self.terminate()
        self.terminal_return_code = JobReturnCode.EXCEPTION

    def _get_return_code(self, job_state):
        """Return the forced return code if set, otherwise the code mapped from ``job_state``."""
        if self.terminal_return_code is not None:
            return self.terminal_return_code
        return JOB_RETURN_CODE_MAPPING.get(job_state)

    def poll(self):
        """Check the pod once, without sleeping, and return the job's return code.

        While the job is starting or running the mapped code is
        ``JobReturnCode.UNKNOWN``. Terminal states are recorded so later calls do
        not query the API again.
        """
        if self.terminal_state is not None:
            return self._get_return_code(self.terminal_state)
        pod = self._query_pod()
        if self.terminal_state is not None:
            return self._get_return_code(self.terminal_state)
        pod_phase = self._get_pod_phase(pod)
        if self._handle_starting_pod(pod, pod_phase):
            return self._get_return_code(self.terminal_state)
        job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
        if job_state in (JobState.SUCCEEDED, JobState.TERMINATED):
            self.terminal_state = job_state
            self._remove_workspace_job()
        return self._get_return_code(job_state)

    def _query_pod(self):
        """Read the pod from the API; return None on any error.

        A 404 marks the job TERMINATED and removes the workspace job. Other
        errors are logged as warnings.
        """
        from kubernetes.client.rest import ApiException

        try:
            return self.api_instance.read_namespaced_pod(name=self.pod_name, namespace=self.namespace)
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                self.logger.info(
                    f"job {self.job_id} pod {self.pod_name} not found during querying; assuming terminated"
                )
                self.terminal_state = JobState.TERMINATED
                self._remove_workspace_job()
            else:
                self.logger.warning(f"failed to query pod for job {self.job_id} pod {self.pod_name}: {e}")
            return None
        except Exception as e:
            self.logger.warning(f"unexpected error querying pod for job {self.job_id} pod {self.pod_name}: {e}")
            return None

    def _query_phase(self):
        """Query the pod and return its phase string.

        Returns "Unknown" if the pod is gone and a terminal state was recorded,
        or None if the phase could not be determined.
        """
        pod = self._query_pod()
        if pod is None and self.terminal_state is not None:
            return PodPhase.UNKNOWN.value
        return self._get_pod_phase(pod)

    def _get_pod_phase(self, pod):
        """Return ``pod.status.phase``, or None when the pod or its phase is missing."""
        if pod is None:
            return None
        phase = getattr(getattr(pod, "status", None), "phase", None)
        if not phase:
            self.logger.warning(f"pod phase is missing for job {self.job_id} pod {self.pod_name}")
            return None
        return phase

    def _query_state(self):
        """Query the pod and return its JobState (UNKNOWN when undeterminable)."""
        pod_phase = self._query_phase()
        return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)

    def _stuck_in_pending(self, current_phase, now=None):
        """Advance the pending-resources timer and report whether it expired.

        The timer starts at the first Pending observation and excludes paused
        intervals. It is reset whenever the pod is not Pending. Always False when
        ``pending_timeout`` is None or the phase is unknown; True immediately
        when ``pending_timeout`` is 0.
        """
        if current_phase is None:
            return False
        if current_phase == PodPhase.PENDING.value:
            self._stuck_count += 1
            if self.pending_timeout is None:
                return False
            current_time = time.time() if now is None else now
            if self._pending_since is None:
                self._pending_since = current_time
                self._pending_timer_paused_at = None
            else:
                self._resume_pending_timer(current_time)
            if self.pending_timeout == 0:
                return True
            if current_time - self._pending_since >= self.pending_timeout:
                return True
        else:
            self._reset_pending_timer()
        return False

    def _handle_starting_pod(self, pod, pod_phase, now=None) -> bool:
        """Classify a starting pod and act on it.

        Returns:
            True if the pod was terminated (startup failure or resource-wait
            timeout), otherwise False.

        Side effects: the pending timer is paused while the pod/phase is unknown,
        or while events cannot be queried for an unscheduled Pending pod. It is
        reset otherwise, unless the pod is waiting for resources.
        """
        self._last_event_query_failed = False
        action, detail = self._classify_starting_pod(pod, pod_phase, now=now)
        if action == PendingPodAction.FAIL:
            self._terminate_for_exception(f"pod startup failure: {detail}")
            return True
        if action == PendingPodAction.WAIT_FOR_RESOURCES:
            if self._stuck_in_pending(pod_phase, now=now):
                self._terminate_for_timeout(f"timed out waiting for CPU/memory/GPU resources: {detail}")
                return True
            return False
        if pod_phase is None:
            self._pause_pending_timer(now)
            return False
        if pod_phase == PodPhase.PENDING.value and self._pending_since is not None and self._last_event_query_failed:
            if not self._pod_is_scheduled(getattr(pod, "status", None)):
                # Events could not be read, so resource shortage cannot be
                # confirmed; freeze the timer instead of resetting it.
                self._pause_pending_timer(now)
                return False
        self._reset_pending_timer()
        return False

    def _pause_pending_timer(self, now=None):
        """Pause a running pending timer (no-op if not started or already paused)."""
        if self._pending_since is None or self._pending_timer_paused_at is not None:
            return
        self._pending_timer_paused_at = time.time() if now is None else now

    def _resume_pending_timer(self, now=None):
        """Resume a paused pending timer, shifting its start by the paused duration."""
        if self._pending_timer_paused_at is None:
            return
        current_time = time.time() if now is None else now
        paused_duration = max(0, current_time - self._pending_timer_paused_at)
        self._pending_since += paused_duration
        self._pending_timer_paused_at = None

    def _reset_pending_timer(self):
        """Stop and clear the pending timer and the Pending poll counter."""
        self._stuck_count = 0
        self._pending_since = None
        self._pending_timer_paused_at = None

    def _classify_starting_pod(self, pod, pod_phase, now=None):
        """Decide how to treat the pod in its current phase.

        Returns:
            ``(PendingPodAction, detail)``. An Unknown phase fails. Non-Pending
            phases wait. A scheduled Pending pod fails on a known container
            waiting failure or a recent warning event (FailedScheduling
            ignored). An unscheduled Pending pod is judged from its PodScheduled
            condition, then from its warning events.
        """
        if pod_phase == PodPhase.UNKNOWN.value:
            return PendingPodAction.FAIL, "pod phase is Unknown"
        if pod_phase != PodPhase.PENDING.value:
            return PendingPodAction.WAIT, ""
        status = getattr(pod, "status", None)
        if self._pod_is_scheduled(status):
            failure = self._get_container_waiting_failure(status)
            if failure:
                return PendingPodAction.FAIL, failure
            failure = self._get_event_failure(ignore_failed_scheduling=True, now=now)
            if failure:
                return PendingPodAction.FAIL, failure
            return PendingPodAction.WAIT, "pod is scheduled and still starting"
        action, detail = self._classify_unscheduled_pod(status)
        if action != PendingPodAction.WAIT:
            return action, detail
        event_action, event_detail = self._classify_unscheduled_events()
        if event_action != PendingPodAction.WAIT:
            return event_action, event_detail
        return PendingPodAction.WAIT, "pod is pending without a scheduler failure"

    def _pod_is_scheduled(self, status) -> bool:
        """True if the pod has a node name or a PodScheduled=True condition."""
        node_name = getattr(status, "node_name", None)
        if isinstance(node_name, str) and node_name:
            return True
        for condition in self._get_pod_conditions(status):
            if getattr(condition, "type", None) == "PodScheduled" and getattr(condition, "status", None) == "True":
                return True
        return False

    def _classify_unscheduled_pod(self, status):
        """Classify from the first PodScheduled=False condition.

        A CPU/memory/GPU shortage waits for resources, "Unschedulable" fails,
        anything else waits.
        """
        for condition in self._get_pod_conditions(status):
            if getattr(condition, "type", None) != "PodScheduled":
                continue
            condition_status = getattr(condition, "status", None)
            if condition_status != "False":
                continue
            reason = getattr(condition, "reason", None)
            message = getattr(condition, "message", None)
            detail = _obj_text(reason, message) or "pod is not scheduled"
            if _is_cpu_memory_gpu_shortage(detail):
                return PendingPodAction.WAIT_FOR_RESOURCES, detail
            if reason == "Unschedulable":
                return PendingPodAction.FAIL, detail
        return PendingPodAction.WAIT, ""

    def _classify_unscheduled_events(self):
        """Classify an unscheduled pod from its newest relevant Warning event.

        Events are scanned newest first (no age limit). A resource shortage
        waits for resources; a known failure reason fails.
        """
        for event in sorted(self._query_pod_events(), key=_event_sort_key, reverse=True):
            reason = getattr(event, "reason", None)
            message = getattr(event, "message", None)
            event_type = getattr(event, "type", None)
            if event_type != "Warning":
                continue
            detail = _obj_text(reason, message) or "pod event reported startup issue"
            if _is_cpu_memory_gpu_shortage(detail):
                return PendingPodAction.WAIT_FOR_RESOURCES, detail
            if reason == "FailedScheduling":
                return PendingPodAction.FAIL, detail
            if reason in _PENDING_FAILURE_EVENT_REASONS:
                return PendingPodAction.FAIL, detail
        return PendingPodAction.WAIT, ""

    def _get_container_waiting_failure(self, status):
        """Return the detail of the first (init) container waiting for a failure reason, or ""."""
        for container_status in self._get_all_container_statuses(status):
            waiting = getattr(getattr(container_status, "state", None), "waiting", None)
            if not waiting:
                continue
            reason = getattr(waiting, "reason", None)
            message = getattr(waiting, "message", None)
            detail = _obj_text(reason, message) or "container is waiting"
            if reason in _PENDING_FAILURE_WAITING_REASONS:
                return detail
        return ""

    def _get_event_failure(self, ignore_failed_scheduling=False, now=None):
        """Return the newest recent Warning event with a failure reason, or "".

        Only events at most SCHEDULED_EVENT_FAILURE_MAX_AGE seconds old count.
        """
        now = time.time() if now is None else now
        for event in sorted(self._query_pod_events(), key=_event_sort_key, reverse=True):
            reason = getattr(event, "reason", None)
            if ignore_failed_scheduling and reason == "FailedScheduling":
                continue
            event_type = getattr(event, "type", None)
            if event_type != "Warning" or reason not in _PENDING_FAILURE_EVENT_REASONS:
                continue
            if not _is_recent_event(event, now, SCHEDULED_EVENT_FAILURE_MAX_AGE):
                continue
            message = getattr(event, "message", None)
            return _obj_text(reason, message) or "pod event reported startup issue"
        return ""

    def _query_pod_events(self):
        """List the pod's events; on error log, set ``_last_event_query_failed``, return []."""
        from kubernetes.client.rest import ApiException

        self._last_event_query_failed = False
        try:
            resp = self.api_instance.list_namespaced_event(
                namespace=self.namespace,
                field_selector=f"involvedObject.name={self.pod_name}",
            )
        except ApiException as e:
            self._last_event_query_failed = True
            self.logger.warning(f"failed to query events for job {self.job_id} pod {self.pod_name}: {e}")
            return []
        except Exception as e:
            self._last_event_query_failed = True
            self.logger.warning(f"unexpected error querying events for job {self.job_id} pod {self.pod_name}: {e}")
            return []
        items = getattr(resp, "items", None)
        return items if isinstance(items, (list, tuple)) else []

    def _get_pod_conditions(self, status):
        """Return ``status.conditions`` if it is a list/tuple, else []."""
        conditions = getattr(status, "conditions", None)
        return conditions if isinstance(conditions, (list, tuple)) else []

    def _get_all_container_statuses(self, status):
        """Return init container statuses followed by container statuses."""
        result = []
        for attr_name in ("init_container_statuses", "container_statuses"):
            statuses = getattr(status, attr_name, None)
            if isinstance(statuses, (list, tuple)):
                result.extend(statuses)
        return result

    def wait(self):
        """Block until the job ends or is terminated for a startup failure/resource timeout.

        ``self.timeout`` is not applied here.
        """
        while True:
            if self.terminal_state is not None:
                return
            pod = self._query_pod()
            if self.terminal_state is not None:
                return
            pod_phase = self._get_pod_phase(pod)
            if self._handle_starting_pod(pod, pod_phase):
                return
            job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
            if job_state in (JobState.SUCCEEDED, JobState.TERMINATED):
                self.terminal_state = job_state  # persist so poll() stays accurate
                self._remove_workspace_job()
                return
            time.sleep(POLL_INTERVAL)
[docs]
class K8sJobLauncher(JobLauncherSpec):
def __init__(
    self,
    config_file_path: str,
    study_data_pvc_file_path: str = None,
    timeout=None,
    namespace=DEFAULT_NAMESPACE,
    pending_timeout=DEFAULT_PENDING_TIMEOUT,
    python_path=None,
    security_context: dict = None,
    ephemeral_storage: str = DEFAULT_EPHEMERAL_STORAGE,
    default_python_path: str = None,
    workspace_mount_path: str = WORKSPACE_MOUNT_PATH,
    image_pull_secrets: list[str] = None,
    study_job_spec_file_path: str = None,
):
    """Validate and store the launcher configuration.

    The Kubernetes client (``core_v1``) and the YAML-derived caches start empty
    and are filled later.

    ``default_python_path`` takes precedence over the older ``python_path``
    argument; if both are None, ``DEFAULT_PYTHON_PATH`` is used.

    Raises:
        ValueError: for an empty/non-string optional file path, an invalid
            ``pending_timeout``, an empty python path, ``ephemeral_storage`` or
            ``workspace_mount_path``, or invalid ``image_pull_secrets``.
    """
    super().__init__()
    self.logger = logging.getLogger(self.__class__.__name__)
    self.config_file_path = config_file_path

    def _checked_optional_path(value, field_name):
        # None is allowed; otherwise the value must be a non-empty string.
        if value is not None and (not isinstance(value, str) or not value):
            raise ValueError(f"{field_name} must be a non-empty string or None")
        return value

    self.study_data_pvc_file_path = _checked_optional_path(study_data_pvc_file_path, "study_data_pvc_file_path")
    self.study_job_spec_file_path = _checked_optional_path(study_job_spec_file_path, "study_job_spec_file_path")
    self.timeout = timeout
    self.namespace = namespace
    self.pending_timeout = _normalize_pending_timeout(pending_timeout)

    chosen_python = python_path if default_python_path is None else default_python_path
    self.default_python_path = DEFAULT_PYTHON_PATH if chosen_python is None else chosen_python
    if not (isinstance(self.default_python_path, str) and self.default_python_path):
        raise ValueError("default_python_path must be a non-empty string")

    self.security_context = security_context
    if not (isinstance(ephemeral_storage, str) and ephemeral_storage):
        raise ValueError("ephemeral_storage must be a non-empty string")
    self.ephemeral_storage = ephemeral_storage
    if not (isinstance(workspace_mount_path, str) and workspace_mount_path):
        raise ValueError("workspace_mount_path must be a non-empty string")
    self.workspace_mount_path = workspace_mount_path
    self.image_pull_secrets = _normalize_image_pull_secrets(image_pull_secrets)

    # Populated on demand.
    self.study_data_pvc_dict = None
    self.study_job_spec_dict = None
    self.pod_manifest_template_dict = {}
    self.core_v1 = None
def _get_pod_manifest_template(self, study: str):
    """Return the pod manifest template configured for ``study``, or None.

    The study job spec file is loaded on first use. Each referenced pod spec
    file is loaded once and cached in ``pod_manifest_template_dict`` by its
    resolved path. None is returned when no study job spec file is configured,
    ``study`` is empty, or the study has no entry.
    """
    if not (self.study_job_spec_file_path and study):
        return None
    if self.study_job_spec_dict is None:
        self.study_job_spec_dict = load_study_job_spec_file(self.study_job_spec_file_path, logger=self.logger)
    template_path = resolve_study_job_spec_path(
        self.study_job_spec_dict, study, self.study_job_spec_file_path, logger=self.logger
    )
    if not template_path:
        return None
    cache = self.pod_manifest_template_dict
    template = cache.get(template_path)
    if template is None:
        # load_pod_spec_file always returns a dict, so None means "not cached yet".
        template = load_pod_spec_file(template_path)
        cache[template_path] = template
    return template
def _ensure_startup_secret(self, site_name: str, startup_dir: str) -> str:
    """Publish the site's startup kit as an Opaque Kubernetes Secret.

    Every regular file directly inside ``startup_dir`` that passes
    ``_keep_startup_file`` is base64-encoded into the Secret's ``data`` under
    its file name. A missing ``startup_dir`` yields a Secret with empty data.
    The Secret is created in ``self.namespace``. If the API answers 409
    (already exists), it is replaced instead. Any other API error propagates.

    Args:
        site_name: site identity, converted via ``site_name_to_rfc1123`` to
            build the Secret name.
        startup_dir: directory holding the startup kit files.

    Returns:
        The Secret name.
    """
    from kubernetes.client.rest import ApiException

    secret_name = f"nvflare-startup-{site_name_to_rfc1123(site_name)}"

    encoded_files = {}
    entries = os.listdir(startup_dir) if os.path.isdir(startup_dir) else []
    for entry in entries:
        if not _keep_startup_file(entry):
            continue
        entry_path = os.path.join(startup_dir, entry)
        if not os.path.isfile(entry_path):
            continue
        with open(entry_path, "rb") as fh:
            encoded_files[entry] = base64.b64encode(fh.read()).decode()

    body = {
        "apiVersion": "v1",
        "kind": "Secret",
        "metadata": {"name": secret_name, "namespace": self.namespace},
        "type": "Opaque",
        "data": encoded_files,
    }

    try:
        self.core_v1.create_namespaced_secret(namespace=self.namespace, body=body)
    except ApiException as exc:
        # 409 means the Secret already exists: overwrite it with the current kit.
        if getattr(exc, "status", None) != 409:
            raise
        self.core_v1.replace_namespaced_secret(name=secret_name, namespace=self.namespace, body=body)
        self.logger.debug("Updated startup Secret %s", secret_name)
    else:
        self.logger.debug("Created startup Secret %s", secret_name)
    return secret_name
def _replace_pod_manifest_template_namespace(self, pod_manifest: dict) -> None:
    """Align an explicit ``metadata.namespace`` in ``pod_manifest`` with ``self.namespace``.

    The manifest is modified in place. If the metadata declares no
    namespace, nothing changes. If it declares a different namespace, the
    value is overwritten with ``self.namespace`` and a warning is logged.
    """
    metadata = _ensure_manifest_mapping(pod_manifest, "metadata", "pod manifest metadata")
    if "namespace" in metadata and metadata["namespace"] != self.namespace:
        declared_namespace = metadata["namespace"]
        metadata["namespace"] = self.namespace
        self.logger.warning(
            "job pod is launched in namespace '%s' instead of metadata.namespace '%s'",
            self.namespace,
            declared_namespace,
        )
[docs]
def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
    """Launch the job described by ``job_meta`` as a Kubernetes pod.

    Steps:

    - Resolve and validate this site's k8s launcher settings.
    - Publish the site startup kit as a Secret.
    - Register the job workspace with the workspace-transfer manager.
    - Create the pod and wait for it to reach the RUNNING state.

    Args:
        job_meta: job metadata. Must contain ``JobConstants.JOB_ID``.
        fl_ctx: FLContext supplying the workspace object, parsed args,
            job process args and the engine (whose ``cell`` owns the
            workspace transfer).

    Returns:
        The K8sJobHandle for the pod. If the handle was built but building
        the manifest or creating the pod raised, the handle is returned
        already marked ``JobState.TERMINATED`` with
        ``JobReturnCode.EXCEPTION``.

    Raises:
        RuntimeError: if required metadata or context entries are missing,
            or a k8s launcher setting is invalid.
    """
    if self.core_v1 is None:
        self.core_v1 = self._load_core_v1_client()

    site_name = fl_ctx.get_identity_name()
    raw_job_id = job_meta.get(JobConstants.JOB_ID)
    if not raw_job_id:
        raise RuntimeError(f"missing {JobConstants.JOB_ID} in job_meta")
    job_id = uuid4_to_rfc1123(raw_job_id)
    pod_name = job_pod_name(job_id, site_name)

    workspace_obj = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
    if workspace_obj is None:
        raise RuntimeError(f"missing {FLContextKey.WORKSPACE_OBJECT} in FLContext")
    app_custom_folder = workspace_obj.get_app_custom_dir(raw_job_id)
    args = fl_ctx.get_prop(FLContextKey.ARGS)
    if args is None:
        raise RuntimeError(f"missing {FLContextKey.ARGS} in FLContext")

    k8s_spec = get_job_launcher_spec(job_meta, site_name, "k8s")
    job_pending_timeout = k8s_spec["pending_timeout"] if "pending_timeout" in k8s_spec else self.pending_timeout
    try:
        job_pending_timeout = _normalize_pending_timeout(
            job_pending_timeout, f"launcher_spec['{site_name}']['k8s']['pending_timeout']"
        )
    except ValueError as e:
        raise RuntimeError(str(e)) from e

    job_image = k8s_spec.get("image")
    job_ephemeral_storage = k8s_spec.get("ephemeral_storage")
    if job_ephemeral_storage is None:
        job_ephemeral_storage = self.ephemeral_storage
    if not isinstance(job_ephemeral_storage, str) or not job_ephemeral_storage:
        raise RuntimeError(f"launcher_spec['{site_name}']['k8s']['ephemeral_storage'] must be a non-empty string")
    if not job_image:
        raise RuntimeError(
            f"K8sJobLauncher is configured for site '{site_name}' but no job image "
            f"was specified in meta.json for this site. "
            f"Set launcher_spec['{site_name}']['k8s']['image'] (preferred), "
            f"launcher_spec['default']['k8s']['image'] (shared default), "
            f"or resource_spec['{site_name}']['k8s']['image'] (legacy)."
        )
    # Validate python_path before any side effect (startup Secret write,
    # workspace-transfer registration), so a bad value fails cleanly.
    python_path = k8s_spec.get("python_path", self.default_python_path)
    if not isinstance(python_path, str) or not python_path:
        raise RuntimeError(f"launcher_spec['{site_name}']['k8s']['python_path'] must be a non-empty string")

    study = job_meta.get(JobMetaKey.STUDY.value)
    pod_manifest_template = self._get_pod_manifest_template(study)
    data_mounts = []
    if should_mount_study_data(study) and self.study_data_pvc_file_path:
        if self.study_data_pvc_dict is None:
            self.study_data_pvc_dict = load_study_data_file(self.study_data_pvc_file_path, logger=self.logger)
        data_mounts = resolve_study_dataset_mounts(
            self.study_data_pvc_dict, study, self.study_data_pvc_file_path, logger=self.logger
        )
    if pod_manifest_template is not None and data_mounts:
        self.logger.warning(
            "study_job_spec_file_path '%s' is used for study '%s'; matching entries from "
            "study_data_pvc_file_path '%s' will be added as extra volume mounts",
            self.study_job_spec_file_path,
            study,
            self.study_data_pvc_file_path,
        )

    # A flat resource_spec num_of_gpus is honored only when the site entry
    # is not split into per-launcher sections ("process"/"docker"/"k8s").
    site_resources = (job_meta.get(JobMetaKey.RESOURCE_SPEC.value) or {}).get(site_name) or {}
    flat_gpu_count = (
        0
        if any(k in site_resources for k in ("process", "docker", "k8s"))
        else site_resources.get("num_of_gpus", 0)
    )
    job_resource = k8s_spec["num_of_gpus"] if "num_of_gpus" in k8s_spec else flat_gpu_count

    job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
    if not job_args:
        raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")
    exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE)
    if not exe_module_entry:
        raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in {FLContextKey.JOB_PROCESS_ARGS}")
    _, job_cmd = exe_module_entry

    workspace_root = args.workspace
    env = {}
    if app_custom_folder:
        workspace_root_abs = os.path.abspath(workspace_root)
        custom_folder_abs = os.path.abspath(app_custom_folder)
        if os.path.commonpath([workspace_root_abs, custom_folder_abs]) != workspace_root_abs:
            raise RuntimeError(f"custom folder {app_custom_folder} is not under workspace {workspace_root}")
        # Re-root the custom folder under the pod's workspace mount path.
        env["PYTHONPATH"] = os.path.join(
            self.workspace_mount_path, os.path.relpath(custom_folder_abs, workspace_root_abs)
        )

    startup_dir = workspace_obj.get_startup_kit_dir()
    engine = fl_ctx.get_engine()
    owner_cell = getattr(engine, "cell", None) if engine else None
    if owner_cell is None:
        raise RuntimeError("missing parent CellNet cell for workspace transfer")
    workspace_transfer = WorkspaceTransferManager.get_or_create(owner_cell)
    workspace_transfer_token = workspace_transfer.add_job(raw_job_id, workspace_root)

    # Explicit sentinel (instead of probing locals()) telling the error path
    # whether a handle exists that can be returned in a terminated state.
    job_handle = None
    try:
        startup_secret_name = self._ensure_startup_secret(site_name, startup_dir)
        env[ENV_WORKSPACE_OWNER_FQCN] = workspace_transfer.owner_fqcn
        env[ENV_WORKSPACE_TRANSFER_TOKEN] = workspace_transfer_token
        volume_list = [
            {"name": "workspace-job", "emptyDir": {"sizeLimit": job_ephemeral_storage}},
            {"name": "startup-kit", "secret": {"secretName": startup_secret_name}},
        ]
        volume_mount_list = [
            {"name": "workspace-job", "mountPath": self.workspace_mount_path},
            {
                "name": "startup-kit",
                "mountPath": os.path.join(self.workspace_mount_path, "startup"),
                "readOnly": True,
            },
        ]
        for dataset_mount in data_mounts:
            volume_name = study_dataset_volume_name(dataset_mount.study, dataset_mount.dataset)
            volume_list.append({"name": volume_name, "persistentVolumeClaim": {"claimName": dataset_mount.source}})
            volume_mount_list.append(
                {
                    "name": volume_name,
                    "mountPath": dataset_mount.mount_path,
                    "readOnly": dataset_mount.read_only,
                }
            )

        job_config = {
            "name": pod_name,
            "image": job_image,
            "container_name": f"container-{job_id}",
            "command": job_cmd,
            "volume_mount_list": volume_mount_list,
            "volume_list": volume_list,
            "module_args": self.get_module_args(job_id, fl_ctx),
            "env": env,
        }
        if self.image_pull_secrets:
            job_config["image_pull_secrets"] = self.image_pull_secrets
        # args was checked for None above.
        if getattr(args, "set", None) is not None:
            job_config["set_list"] = args.set
        job_config["resources"] = self._build_container_resources(k8s_spec, job_ephemeral_storage, job_resource)
        if self.security_context:
            job_config["security_context"] = self.security_context

        job_handle = K8sJobHandle(
            job_id,
            self.core_v1,
            job_config,
            namespace=self.namespace,
            timeout=self.timeout,
            pending_timeout=job_pending_timeout,
            python_path=python_path,
            workspace_transfer=workspace_transfer,
            workspace_job_id=raw_job_id,
            pod_name=pod_name,
            pod_manifest_template=pod_manifest_template,
        )
        pod_manifest = job_handle.get_manifest()
        if pod_manifest_template is not None:
            self._replace_pod_manifest_template_namespace(pod_manifest)
        self.logger.debug(
            "launch job with k8s_launcher: pod_name=%s namespace=%s image=%s",
            pod_manifest["metadata"]["name"],
            self.namespace,
            job_image,
        )
        self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace)
    except Exception as e:
        workspace_transfer.remove_job(raw_job_id)
        if job_handle is None:
            raise
        self.logger.error("failed to launch job %s: %s", job_id, e)
        job_handle.terminal_state = JobState.TERMINATED
        job_handle.terminal_return_code = JobReturnCode.EXCEPTION
        return job_handle

    try:
        entered_running = job_handle.enter_states([JobState.RUNNING])
    except BaseException:
        # Covers interrupts too: do not leave a pod running untracked.
        job_handle.terminate()
        raise
    if not entered_running:
        self.logger.warning(f"unable to enter running phase {job_id}")
    return job_handle

def _load_core_v1_client(self):
    """Load the kube config (from file, or in-cluster) and return a CoreV1Api client."""
    from kubernetes import config
    from kubernetes.client import Configuration
    from kubernetes.client.api import core_v1_api

    # Config-loading errors must surface. Only the Configuration copy is
    # guarded, because older clients lack Configuration.get_default_copy().
    if self.config_file_path:
        config.load_kube_config(self.config_file_path)
    else:
        config.load_incluster_config()
    try:
        c = Configuration().get_default_copy()
    except AttributeError:
        c = Configuration()
        c.assert_hostname = False
    Configuration.set_default(c)
    return core_v1_api.CoreV1Api()

@staticmethod
def _build_container_resources(k8s_spec: dict, ephemeral_storage: str, gpu_count) -> dict:
    """Build the container ``resources`` (requests/limits) dict for the job pod.

    ephemeral-storage is always both requested and limited. cpu and memory
    are added when set in ``k8s_spec``. A truthy ``gpu_count`` is requested
    and limited as ``nvidia.com/gpu``.
    """
    resources = {
        "requests": {"ephemeral-storage": ephemeral_storage},
        "limits": {"ephemeral-storage": ephemeral_storage},
    }
    for key in ("cpu", "memory"):
        limit_val = k8s_spec.get(key)
        # cpu_request / memory_request allow request < limit; when absent,
        # request mirrors the limit so admission webhooks that require
        # explicit cpu/memory requests (e.g. AKS deployment safeguards) pass.
        request_val = k8s_spec.get(f"{key}_request", limit_val)
        if limit_val:
            resources["limits"][key] = limit_val
        if request_val:
            resources["requests"][key] = request_val
    if gpu_count:
        resources["limits"]["nvidia.com/gpu"] = gpu_count
        resources["requests"]["nvidia.com/gpu"] = gpu_count
    return resources
[docs]
def handle_event(self, event_type: str, fl_ctx: FLContext):
    """Register this launcher via ``add_launcher`` on ``EventType.BEFORE_JOB_LAUNCH``.

    All other event types are ignored.
    """
    if event_type != EventType.BEFORE_JOB_LAUNCH:
        return
    add_launcher(self, fl_ctx)
[docs]
@abstractmethod
def get_module_args(self, job_id, fl_ctx: FLContext):
    """Return the module arguments for the job process started in the pod.

    Subclasses must implement this. ``launch_job`` calls it with the
    RFC 1123-normalized job id and stores the result in the job config
    under ``"module_args"``.

    Args:
        job_id: normalized job id, as passed by ``launch_job``.
        fl_ctx: FLContext of the launch.

    Returns:
        The module args. The subclasses in this module return a dict
        mapping argument flags to values.
    """
def _job_args_dict(job_args: dict, arg_names: list) -> dict:
    """Collect the ``(flag, value)`` entries of ``job_args`` selected by ``arg_names``.

    Names that are absent from ``job_args``, or map to ``None``, are skipped.
    Each remaining entry is unpacked as a ``(flag, value)`` pair and stored
    as ``flag -> value``, in ``arg_names`` order. A later entry with the same
    flag overwrites an earlier one.
    """
    selected = (job_args.get(name) for name in arg_names)
    present = (entry for entry in selected if entry is not None)
    return {flag: value for flag, value in present}
[docs]
class ClientK8sJobLauncher(K8sJobLauncher):
    """K8sJobLauncher whose module args are selected by ``get_client_job_args``."""

    def get_module_args(self, _job_id, fl_ctx: FLContext):
        """Return the client job-process module args as a ``{flag: value}`` dict.

        Raises:
            RuntimeError: if ``fl_ctx`` carries no ``JOB_PROCESS_ARGS``.
        """
        process_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
        if process_args:
            return _job_args_dict(process_args, get_client_job_args(False, False))
        raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")
[docs]
class ServerK8sJobLauncher(K8sJobLauncher):
    """K8sJobLauncher whose module args are selected by ``get_server_job_args``."""

    def get_module_args(self, _job_id, fl_ctx: FLContext):
        """Return the server job-process module args as a ``{flag: value}`` dict.

        Raises:
            RuntimeError: if ``fl_ctx`` carries no ``JOB_PROCESS_ARGS``.
        """
        process_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
        if process_args:
            return _job_args_dict(process_args, get_server_job_args(False, False))
        raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")