# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import copy
import hashlib
import logging
import os
import re
import time
from abc import abstractmethod
from enum import Enum
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, JobConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import JobMetaKey
from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher
from nvflare.app_opt.job_launcher.study_data import (
load_study_data_file,
resolve_study_dataset_mounts,
should_mount_study_data,
)
from nvflare.app_opt.job_launcher.workspace_cell_transfer import (
ENV_WORKSPACE_OWNER_FQCN,
ENV_WORKSPACE_TRANSFER_TOKEN,
WorkspaceTransferManager,
)
from nvflare.utils.job_launcher_utils import get_client_job_args, get_job_launcher_spec, get_server_job_args
class JobState(Enum):
    """Launcher-side lifecycle states of a job pod."""

    STARTING = "starting"
    RUNNING = "running"
    TERMINATED = "terminated"
    SUCCEEDED = "succeeded"
    UNKNOWN = "unknown"
class PodPhase(Enum):
    """Pod phase strings as read from a pod's ``status.phase`` field."""

    PENDING = "Pending"
    RUNNING = "Running"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    UNKNOWN = "Unknown"
# Translates the pod phase string reported by the Kubernetes API into a JobState.
# A failed pod is reported as TERMINATED.
POD_STATE_MAPPING = {
    PodPhase.PENDING.value: JobState.STARTING,
    PodPhase.RUNNING.value: JobState.RUNNING,
    PodPhase.SUCCEEDED.value: JobState.SUCCEEDED,
    PodPhase.FAILED.value: JobState.TERMINATED,
    PodPhase.UNKNOWN.value: JobState.UNKNOWN,
}

# Translates a JobState into the JobReturnCode returned by K8sJobHandle.poll().
# Only SUCCEEDED and TERMINATED produce a definite code; all other states map to UNKNOWN.
JOB_RETURN_CODE_MAPPING = {
    JobState.SUCCEEDED: JobReturnCode.SUCCESS,
    JobState.STARTING: JobReturnCode.UNKNOWN,
    JobState.RUNNING: JobReturnCode.UNKNOWN,
    JobState.TERMINATED: JobReturnCode.ABORTED,
    JobState.UNKNOWN: JobReturnCode.UNKNOWN,
}

# Module-argument flags used when job_config carries no "module_args". Entries whose value is
# None are skipped when the container args are assembled, so this default adds no arguments.
DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT = {
    "-m": None,
    "-w": None,
    "-t": None,
    "-d": None,
    "-n": None,
    "-c": None,
    "-p": None,
    "-g": None,
    "-scheme": None,
    "-s": None,
}

# Namespace used when the caller does not pass one.
DEFAULT_NAMESPACE = "default"
# Default limit on consecutive Pending/Unknown phase observations before a pod is treated as
# stuck; K8sJobHandle.enter_states() makes one observation per loop iteration (1 s sleep).
DEFAULT_PENDING_TIMEOUT = 120
# Interpreter used as the container command unless overridden.
DEFAULT_PYTHON_PATH = "/usr/local/bin/python"
# Default mount point of the job workspace volume inside the pod.
WORKSPACE_MOUNT_PATH = "/var/tmp/nvflare/workspace"
# Default ephemeral-storage request/limit, also used as the workspace emptyDir size limit.
DEFAULT_EPHEMERAL_STORAGE = "1Gi"
# Only these startup/ files are read by the job pod at runtime; everything else in
# startup/ is left out to keep the Secret small. local/ is bundled whole with each
# job workspace so job resource files and local custom code keep working.
_STARTUP_KEEP_SUFFIXES = (".crt", ".key", ".pem", ".json")


def _keep_startup_file(fname: str) -> bool:
    """Return True when ``fname`` ends with one of the suffixes kept in the startup Secret."""
    return any(fname.endswith(suffix) for suffix in _STARTUP_KEEP_SUFFIXES)
def _normalize_image_pull_secrets(image_pull_secrets) -> list[str]:
    """Validate an ``image_pull_secrets`` setting and return it as a new list.

    Args:
        image_pull_secrets: ``None`` or a list of Kubernetes Secret names.

    Returns:
        An empty list for ``None``; otherwise a shallow copy of the given list.

    Raises:
        ValueError: if the value is not a list, or any entry is not a non-blank string.
    """
    if image_pull_secrets is None:
        return []
    if not isinstance(image_pull_secrets, list):
        raise ValueError("image_pull_secrets must be a list of Kubernetes Secret names")
    if any(not isinstance(entry, str) or not entry.strip() for entry in image_pull_secrets):
        raise ValueError("image_pull_secrets entries must be non-empty strings")
    return [*image_pull_secrets]
def uuid4_to_rfc1123(uuid_str: str) -> str:
    """Convert a job id (normally a UUID4 string) into a Kubernetes-safe name.

    The value is lower-cased, every character other than ``a-z``, ``0-9`` and ``-`` is
    dropped, leading hyphens are removed, a ``j`` is prepended when the name would start
    with a digit, and the result is cut to 63 characters with trailing hyphens removed.

    Args:
        uuid_str: the raw job id.

    Returns:
        The sanitized name. It is empty when the input has no usable characters.
    """
    name = re.sub(r"[^a-z0-9-]", "", uuid_str.lower())
    # A name must not begin with a hyphen; strip them first so the digit check below
    # looks at the real first character.
    name = name.lstrip("-")
    if name and name[0].isdigit():
        name = "j" + name
    # Kubernetes label limit: 63 chars. Truncation can expose a hyphen that was interior
    # before slicing, so trailing hyphens are stripped afterwards.
    return name[:63].rstrip("-")
def site_name_to_rfc1123(site_name: str, max_length: int = 47) -> str:
    """Convert a site name into a stable RFC1123-safe label with a hash suffix.

    The result is ``<sanitized-name>-<digest>``, where the digest is the first 8 hex
    characters of the SHA-256 of the original site name. The sanitized part is lower-cased,
    stripped of characters outside ``a-z0-9-``, prefixed with ``s`` when it starts with a
    digit, and falls back to ``site`` when nothing usable remains.

    Args:
        site_name: the original site name.
        max_length: length budget for the whole label; the name part is truncated to fit.

    Returns:
        The sanitized label.
    """
    suffix = hashlib.sha256(site_name.encode("utf-8")).hexdigest()[:8]
    head = re.sub(r"[^a-z0-9-]", "", site_name.lower()).strip("-") or "site"
    if head[0].isdigit():
        head = "s" + head
    # Reserve room for the digest and the joining hyphen.
    budget = max_length - len(suffix) - 1
    head = head[:budget].rstrip("-") or "site"
    return f"{head}-{suffix}"
def job_pod_name(job_id: str, site_name: str) -> str:
    """Build a site-scoped Kubernetes pod name for a FL job.

    The name is ``<job-id-prefix>-<site-label>``. The site label is produced by
    ``site_name_to_rfc1123`` with a 20-character budget, and the job id is truncated so the
    joined name fits in 63 characters.

    Args:
        job_id: the (already sanitized) job id.
        site_name: the name of the site launching the job.

    Returns:
        The pod name.
    """
    suffix = site_name_to_rfc1123(site_name, max_length=20)
    # One character is reserved for the hyphen that joins the two parts.
    room_for_job = 63 - len(suffix) - 1
    prefix = job_id[:room_for_job].rstrip("-")
    return "-".join((prefix, suffix))
def study_dataset_volume_name(study: str, dataset: str) -> str:
    """Return the pod volume name used for one dataset of a study.

    Args:
        study: the study name.
        dataset: the dataset name within the study.

    Returns:
        ``data-<study>-<dataset>`` passed through ``site_name_to_rfc1123`` with a
        63-character budget.
    """
    raw_name = f"data-{study}-{dataset}"
    return site_name_to_rfc1123(raw_name, max_length=63)
class K8sJobHandle(JobHandleSpec):
    """Handle for one FL job that runs as a single Kubernetes pod.

    The pod manifest is assembled from ``job_config`` at construction time. Afterwards the
    handle tracks the pod through ``api_instance``: waiting for a state, polling for a return
    code, and deleting the pod. Once ``terminal_state`` is set, the API is not queried again.
    """

    def __init__(
        self,
        job_id: str,
        api_instance,
        job_config: dict,
        namespace=DEFAULT_NAMESPACE,
        timeout=None,
        pending_timeout=DEFAULT_PENDING_TIMEOUT,
        python_path=DEFAULT_PYTHON_PATH,
        workspace_transfer: WorkspaceTransferManager = None,
        workspace_job_id: str = "",
        pod_name: str = None,
    ):
        """Create the handle and build the pod manifest.

        Args:
            job_id: job identifier used in log messages; also the pod name when ``pod_name``
                is not given.
            api_instance: client object providing ``read_namespaced_pod`` and
                ``delete_namespaced_pod``.
            job_config: manifest inputs. ``command`` and ``image`` are required. Optional keys
                read here: ``name``, ``container_name``, ``volume_list``, ``volume_mount_list``,
                ``module_args``, ``set_list``, ``env``, ``resources``, ``image_pull_secrets``
                and ``security_context``.
            namespace: namespace used for all pod API calls.
            timeout: seconds ``enter_states`` may wait. When not None it also replaces
                ``pending_timeout`` as the stuck-observation limit.
            pending_timeout: number of consecutive Pending/Unknown observations after which
                the pod is treated as stuck.
            python_path: interpreter used as the container command.
            workspace_transfer: optional manager whose ``remove_job`` is called when the job
                reaches a terminal state.
            workspace_job_id: id passed to ``workspace_transfer.remove_job``.
            pod_name: name used for pod API calls; defaults to ``job_id``.

        Raises:
            ValueError: if ``job_config`` lacks a non-empty ``command`` or ``image``, or its
                ``image_pull_secrets`` value is invalid.
        """
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.job_id = job_id
        self.pod_name = pod_name if pod_name is not None else job_id
        self.timeout = timeout
        self.terminal_state = None
        self.terminal_return_code = None
        self.workspace_transfer = workspace_transfer
        self.workspace_job_id = workspace_job_id
        self.api_instance = api_instance
        self.namespace = namespace

        self.pod_manifest = {
            "apiVersion": "v1",
            "kind": "Pod",
            "metadata": {"name": None},  # set by job_config['name']
            "spec": {
                "containers": None,  # link to container_list
                "volumes": None,  # link to volume_list
                "restartPolicy": "Never",
            },
        }
        self.volume_list = []
        self.container_list = [
            {
                "image": None,
                "name": None,
                "command": [python_path],
                "args": None,  # args_list + args_dict + args_sets
                "volumeMounts": None,  # volume_mount_list
                "imagePullPolicy": "Always",
            }
        ]

        command = job_config.get("command")
        if not command:
            raise ValueError("job_config must contain a non-empty 'command' key")
        self.container_args_python_args_list = ["-u", "-m", command]
        self.container_volume_mount_list = []
        self._make_manifest(job_config)

        self._stuck_count = 0
        self._max_stuck_count = self.timeout if self.timeout is not None else pending_timeout

    def _make_manifest(self, job_config):
        """Fill ``pod_manifest`` and the container entry from ``job_config``.

        Raises:
            ValueError: if ``image`` is missing/empty or ``image_pull_secrets`` is invalid.
        """
        self.container_volume_mount_list.extend(job_config.get("volume_mount_list", []))

        set_list = job_config.get("set_list")
        self.container_args_module_args_sets = (["--set"] + set_list) if set_list else []

        module_args = job_config.get("module_args")
        if module_args is None:
            module_args = DEFAULT_CONTAINER_ARGS_MODULE_ARGS_DICT.copy()
        self.container_args_module_args_dict = module_args

        # Flatten {flag: value} into [flag, value, ...]; flags without a value are omitted.
        self.container_args_module_args_dict_as_list = []
        for flag, value in module_args.items():
            if value is None:
                continue
            self.container_args_module_args_dict_as_list.extend((flag, str(value)))

        self.volume_list.extend(job_config.get("volume_list", []))

        spec = self.pod_manifest["spec"]
        self.pod_manifest["metadata"]["name"] = job_config.get("name")
        spec["containers"] = self.container_list
        spec["volumes"] = self.volume_list

        image_pull_secrets = _normalize_image_pull_secrets(job_config.get("image_pull_secrets"))
        if image_pull_secrets:
            spec["imagePullSecrets"] = [{"name": name} for name in image_pull_secrets]

        security_context = job_config.get("security_context")
        if security_context:
            spec["securityContext"] = security_context

        image = job_config.get("image")
        if not image:
            raise ValueError("job_config must contain a non-empty 'image' key")

        container = self.container_list[0]
        container["image"] = image
        # Container names must be DNS labels, so the fallback must not contain an underscore.
        container["name"] = job_config.get("container_name", "nvflare-job")
        container["args"] = (
            self.container_args_python_args_list
            + self.container_args_module_args_dict_as_list
            + self.container_args_module_args_sets
        )
        container["volumeMounts"] = self.container_volume_mount_list

        # The resources dict is passed through as given (ephemeral-storage plus any
        # cpu/memory/GPU entries the caller merged in).
        if job_config.get("resources"):
            container["resources"] = job_config["resources"]

        # Drop unset (None) and empty values rather than exporting them as "None" / "".
        env_vars = {k: v for k, v in job_config.get("env", {}).items() if v is not None and str(v)}
        if env_vars:
            container["env"] = [{"name": k, "value": str(v)} for k, v in env_vars.items()]

    def get_manifest(self):
        """Return a deep copy of the pod manifest."""
        return copy.deepcopy(self.pod_manifest)

    def enter_states(self, job_states_to_enter: list):
        """Block until the pod reaches one of the requested states.

        The pod phase is queried about once per second.

        Args:
            job_states_to_enter: a ``JobState`` or a list/tuple of ``JobState`` values.

        Returns:
            True when the pod reached one of the requested states. False when a terminal
            state was already recorded or the pod was not found, when the pod stayed in
            Pending/Unknown for too many observations, when it finished (Succeeded/Failed)
            without matching, or when ``timeout`` seconds elapsed. In the stuck and timeout
            cases the pod is deleted and the return code is set to EXCEPTION.

        Raises:
            ValueError: if any requested state is not a ``JobState``.
        """
        starting_time = time.time()
        if not isinstance(job_states_to_enter, (list, tuple)):
            job_states_to_enter = [job_states_to_enter]
        if not all(isinstance(js, JobState) for js in job_states_to_enter):
            raise ValueError(f"expect job_states_to_enter with valid values, but get {job_states_to_enter}")

        while True:
            if self.terminal_state is not None:
                return False
            pod_phase = self._query_phase()
            # _query_phase records a terminal state itself when the pod no longer exists.
            if self.terminal_state is not None:
                return False
            if self._stuck_in_pending(pod_phase):
                self._terminate_for_timeout("timed out waiting for pod to leave Pending/Unknown phase")
                return False

            job_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
            if job_state in job_states_to_enter:
                return True
            elif pod_phase in [PodPhase.FAILED.value, PodPhase.SUCCEEDED.value]:  # terminal state
                self.terminal_state = POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)
                self._remove_workspace_job()
                return False
            elif self.timeout is not None and time.time() - starting_time > self.timeout:
                self._terminate_for_timeout(f"timed out waiting for pod to enter {job_states_to_enter}")
                return False
            time.sleep(1)

    def _remove_workspace_job(self) -> None:
        """Unregister the job from the workspace transfer manager, at most once."""
        if self.workspace_transfer and self.workspace_job_id:
            self.workspace_transfer.remove_job(self.workspace_job_id)
            self.workspace_job_id = ""

    def terminate(self):
        """Delete the pod and mark the job TERMINATED.

        API errors are logged, not raised: the job is marked TERMINATED and the workspace
        registration is removed whether or not the delete call succeeded.
        """
        from kubernetes.client.rest import ApiException

        try:
            self.api_instance.delete_namespaced_pod(
                name=self.pod_name, namespace=self.namespace, grace_period_seconds=0
            )
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                # Expected when terminate() runs as an idempotent cleanup after the
                # pod already exited gracefully (e.g. server abort path where the SJ
                # left on its own before the safety-net terminate fires). Not an
                # event of interest for operators monitoring logs.
                self.logger.debug(
                    f"job {self.job_id} pod {self.pod_name} not found during termination; assuming terminated"
                )
            else:
                self.logger.error(f"failed to terminate job {self.job_id} pod {self.pod_name}: {e}")
        except Exception as e:
            self.logger.error(f"unexpected error terminating job {self.job_id} pod {self.pod_name}: {e}")

        self.terminal_state = JobState.TERMINATED
        self._remove_workspace_job()
        return None

    def _terminate_for_timeout(self, reason: str):
        """Log ``reason``, delete the pod and force the return code to EXCEPTION."""
        self.logger.warning(f"job {self.job_id} pod {self.pod_name}: {reason}")
        self.terminate()
        self.terminal_return_code = JobReturnCode.EXCEPTION

    def _get_return_code(self, job_state):
        """Return the forced return code if one was set, else the code mapped from ``job_state``."""
        if self.terminal_return_code is not None:
            return self.terminal_return_code
        return JOB_RETURN_CODE_MAPPING.get(job_state)

    def poll(self):
        """Return the job's current return code without blocking.

        A SUCCEEDED or TERMINATED state is recorded as terminal, so later calls answer from
        the recorded state without querying the API.
        """
        if self.terminal_state is not None:
            return self._get_return_code(self.terminal_state)
        job_state = self._query_state()
        # The query itself may have recorded a terminal state (pod not found).
        if self.terminal_state is not None:
            return self._get_return_code(self.terminal_state)
        if job_state in (JobState.SUCCEEDED, JobState.TERMINATED):
            self.terminal_state = job_state
            self._remove_workspace_job()
        return self._get_return_code(job_state)

    def _query_phase(self):
        """Read the pod phase from the API.

        Returns:
            The phase string from the pod status. ``PodPhase.UNKNOWN.value`` when the pod is
            not found (404), in which case the job is also marked TERMINATED. ``None`` when
            the query failed for any other reason, meaning no phase was observed.
        """
        from kubernetes.client.rest import ApiException

        try:
            resp = self.api_instance.read_namespaced_pod(name=self.pod_name, namespace=self.namespace)
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                self.logger.info(
                    f"job {self.job_id} pod {self.pod_name} not found during querying; assuming terminated"
                )
                self.terminal_state = JobState.TERMINATED
                self._remove_workspace_job()
                return PodPhase.UNKNOWN.value
            else:
                self.logger.warning(f"failed to query pod phase for job {self.job_id} pod {self.pod_name}: {e}")
                return None  # no pod phase was observed
        except Exception as e:
            self.logger.warning(f"unexpected error querying pod phase for job {self.job_id} pod {self.pod_name}: {e}")
            return None  # no pod phase was observed
        return resp.status.phase

    def _query_state(self):
        """Query the pod phase and map it to a ``JobState`` (UNKNOWN when unmapped)."""
        pod_phase = self._query_phase()
        return POD_STATE_MAPPING.get(pod_phase, JobState.UNKNOWN)

    def _stuck_in_pending(self, current_phase):
        """Count consecutive Pending/Unknown observations and report when the limit is hit.

        A ``None`` phase (failed query) neither increments nor resets the counter. Any other
        phase resets it.
        """
        if current_phase is None:
            return False
        if current_phase in (PodPhase.PENDING.value, PodPhase.UNKNOWN.value):
            self._stuck_count += 1
            if self._max_stuck_count is not None and self._stuck_count >= self._max_stuck_count:
                return True
        else:
            self._stuck_count = 0
        return False

    def wait(self):
        """Block until the job reaches a terminal state, querying about once per second."""
        while True:
            if self.terminal_state is not None:
                return
            job_state = self._query_state()
            if self.terminal_state is not None:
                return
            if job_state in (JobState.SUCCEEDED, JobState.TERMINATED):
                self.terminal_state = job_state  # persist so poll() stays accurate
                self._remove_workspace_job()
                return
            time.sleep(1)
class K8sJobLauncher(JobLauncherSpec):
    """Base launcher that runs each FL job as a Kubernetes pod.

    Subclasses supply the module arguments for the job process through ``get_module_args``.
    The Kubernetes client is created lazily on the first ``launch_job`` call.
    """

    def __init__(
        self,
        config_file_path: str,
        study_data_pvc_file_path: str,
        timeout=None,
        namespace=DEFAULT_NAMESPACE,
        pending_timeout=DEFAULT_PENDING_TIMEOUT,
        python_path=None,
        security_context: dict = None,
        ephemeral_storage: str = DEFAULT_EPHEMERAL_STORAGE,
        default_python_path: str = None,
        workspace_mount_path: str = WORKSPACE_MOUNT_PATH,
        image_pull_secrets: list[str] = None,
    ):
        """Store and validate the launcher configuration.

        Args:
            config_file_path: kubeconfig file to load; when empty, in-cluster config is used.
            study_data_pvc_file_path: file passed to ``load_study_data_file`` when a job's
                study requires data mounts.
            timeout: passed to each job handle as its ``timeout``.
            namespace: namespace for Secrets and pods.
            pending_timeout: passed to each job handle as its ``pending_timeout``.
            python_path: fallback for ``default_python_path`` when that is None.
            security_context: optional pod-level security context added to every pod.
            ephemeral_storage: default ephemeral-storage quantity for job pods.
            default_python_path: interpreter used when the job spec has no ``python_path``.
            workspace_mount_path: where the job workspace volume is mounted in the pod.
            image_pull_secrets: optional list of image pull Secret names.

        Raises:
            ValueError: if the interpreter path, ``ephemeral_storage`` or
                ``workspace_mount_path`` is not a non-empty string, or
                ``image_pull_secrets`` is invalid.
        """
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.config_file_path = config_file_path
        self.study_data_pvc_file_path = study_data_pvc_file_path
        self.timeout = timeout
        self.namespace = namespace
        self.pending_timeout = pending_timeout

        self.default_python_path = default_python_path if default_python_path is not None else python_path
        if self.default_python_path is None:
            self.default_python_path = DEFAULT_PYTHON_PATH
        if not isinstance(self.default_python_path, str) or not self.default_python_path:
            raise ValueError("default_python_path must be a non-empty string")

        self.security_context = security_context

        if not isinstance(ephemeral_storage, str) or not ephemeral_storage:
            raise ValueError("ephemeral_storage must be a non-empty string")
        self.ephemeral_storage = ephemeral_storage

        if not isinstance(workspace_mount_path, str) or not workspace_mount_path:
            raise ValueError("workspace_mount_path must be a non-empty string")
        self.workspace_mount_path = workspace_mount_path

        self.image_pull_secrets = _normalize_image_pull_secrets(image_pull_secrets)
        self.study_data_pvc_dict = None  # loaded on first use in launch_job
        self.core_v1 = None  # created on first use in launch_job

    def _ensure_k8s_client(self) -> None:
        """Create ``self.core_v1`` on first use; do nothing when it already exists."""
        if self.core_v1 is not None:
            return

        from kubernetes import config
        from kubernetes.client import Configuration
        from kubernetes.client.api import core_v1_api

        # Loading happens outside the try below so that a loader failure is not mistaken
        # for a missing get_default_copy() and silently replaced by an empty configuration.
        if self.config_file_path:
            config.load_kube_config(self.config_file_path)
        else:
            config.load_incluster_config()

        try:
            c = Configuration().get_default_copy()
        except AttributeError:
            # Fallback for Configuration objects that lack get_default_copy().
            c = Configuration()
            c.assert_hostname = False
        Configuration.set_default(c)
        self.core_v1 = core_v1_api.CoreV1Api()

    def _ensure_startup_secret(self, site_name: str, startup_dir: str) -> str:
        """Create or update a k8s Secret containing the site startup kit.

        Only regular files in ``startup_dir`` accepted by ``_keep_startup_file`` are stored,
        base64-encoded and keyed by file name. An existing Secret (HTTP 409 on create) is
        replaced.

        Args:
            site_name: site name used to derive the Secret name.
            startup_dir: directory holding the startup kit; a missing directory gives an
                empty Secret.

        Returns:
            The Secret name.

        Raises:
            ApiException: for any create failure other than 409, or if the replace fails.
        """
        from kubernetes.client.rest import ApiException

        secret_name = f"nvflare-startup-{site_name_to_rfc1123(site_name)}"

        data = {}
        if os.path.isdir(startup_dir):
            for fname in os.listdir(startup_dir):
                if not _keep_startup_file(fname):
                    continue
                fpath = os.path.join(startup_dir, fname)
                if os.path.isfile(fpath):
                    with open(fpath, "rb") as f:
                        data[fname] = base64.b64encode(f.read()).decode()

        secret_body = {
            "apiVersion": "v1",
            "kind": "Secret",
            "metadata": {"name": secret_name, "namespace": self.namespace},
            "type": "Opaque",
            "data": data,
        }
        try:
            self.core_v1.create_namespaced_secret(namespace=self.namespace, body=secret_body)
            self.logger.debug("Created startup Secret %s", secret_name)
        except ApiException as e:
            if getattr(e, "status", None) == 409:
                self.core_v1.replace_namespaced_secret(name=secret_name, namespace=self.namespace, body=secret_body)
                self.logger.debug("Updated startup Secret %s", secret_name)
            else:
                raise
        return secret_name

    @staticmethod
    def _build_resources(k8s_spec: dict, ephemeral_storage: str, gpu_count) -> dict:
        """Build the container ``resources`` dict (requests and limits) for a job."""
        requests = {"ephemeral-storage": ephemeral_storage}
        limits = {"ephemeral-storage": ephemeral_storage}
        for key in ("cpu", "memory"):
            limit_val = k8s_spec.get(key)
            # cpu_request / memory_request allow request < limit; when absent,
            # request mirrors the limit so admission webhooks that require
            # explicit cpu/memory requests (e.g. AKS deployment safeguards) pass.
            request_val = k8s_spec.get(f"{key}_request", limit_val)
            if limit_val:
                limits[key] = limit_val
            if request_val:
                requests[key] = request_val
        if gpu_count:
            limits["nvidia.com/gpu"] = gpu_count
            requests["nvidia.com/gpu"] = gpu_count
        return {"requests": requests, "limits": limits}

    def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
        """Create the pod for a job and wait for it to start running.

        The method registers the job workspace with the workspace transfer manager, ensures
        the startup-kit Secret exists, builds the pod manifest, creates the pod and then
        waits for it to reach the RUNNING state.

        Args:
            job_meta: job metadata; must contain the job id and a k8s image for this site.
            fl_ctx: FL context providing the workspace object, process args, job process
                args and the engine.

        Returns:
            The job handle. If the launch fails after the handle was built (for example pod
            creation is rejected), the error is logged and the handle is returned already
            marked TERMINATED. If the pod never reaches RUNNING, a warning is logged and the
            handle is still returned.

        Raises:
            RuntimeError: if required information is missing or invalid in ``job_meta`` or
                ``fl_ctx``, or the custom folder lies outside the workspace.
            Exception: any error raised before the handle was built is re-raised after the
                workspace registration has been removed.
        """
        self._ensure_k8s_client()

        site_name = fl_ctx.get_identity_name()
        raw_job_id = job_meta.get(JobConstants.JOB_ID)
        if not raw_job_id:
            raise RuntimeError(f"missing {JobConstants.JOB_ID} in job_meta")
        job_id = uuid4_to_rfc1123(raw_job_id)
        pod_name = job_pod_name(job_id, site_name)

        workspace_obj = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
        if workspace_obj is None:
            raise RuntimeError(f"missing {FLContextKey.WORKSPACE_OBJECT} in FLContext")
        app_custom_folder = workspace_obj.get_app_custom_dir(raw_job_id)

        args = fl_ctx.get_prop(FLContextKey.ARGS)
        if args is None:
            raise RuntimeError(f"missing {FLContextKey.ARGS} in FLContext")

        k8s_spec = get_job_launcher_spec(job_meta, site_name, "k8s")
        job_image = k8s_spec.get("image")
        job_ephemeral_storage = k8s_spec.get("ephemeral_storage")
        if job_ephemeral_storage is None:
            job_ephemeral_storage = self.ephemeral_storage
        if not isinstance(job_ephemeral_storage, str) or not job_ephemeral_storage:
            raise RuntimeError(f"launcher_spec['{site_name}']['k8s']['ephemeral_storage'] must be a non-empty string")
        if not job_image:
            raise RuntimeError(
                f"K8sJobLauncher is configured for site '{site_name}' but no job image "
                f"was specified in meta.json for this site. "
                f"Set launcher_spec['{site_name}']['k8s']['image'] (preferred), "
                f"launcher_spec['default']['k8s']['image'] (shared default), "
                f"or resource_spec['{site_name}']['k8s']['image'] (legacy)."
            )

        study = job_meta.get(JobMetaKey.STUDY.value)
        data_mounts = []
        if should_mount_study_data(study):
            if self.study_data_pvc_dict is None:
                self.study_data_pvc_dict = load_study_data_file(self.study_data_pvc_file_path, logger=self.logger)
            data_mounts = resolve_study_dataset_mounts(
                self.study_data_pvc_dict, study, self.study_data_pvc_file_path, logger=self.logger
            )

        # A flat resource_spec entry (no per-launcher sub-dicts) may carry num_of_gpus
        # directly; the k8s launcher spec takes precedence when it sets the key.
        site_resources = (job_meta.get(JobMetaKey.RESOURCE_SPEC.value) or {}).get(site_name) or {}
        flat_gpu_count = (
            0
            if any(k in site_resources for k in ("process", "docker", "k8s"))
            else site_resources.get("num_of_gpus", 0)
        )
        job_resource = k8s_spec["num_of_gpus"] if "num_of_gpus" in k8s_spec else flat_gpu_count

        job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
        if not job_args:
            raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")
        exe_module_entry = job_args.get(JobProcessArgs.EXE_MODULE)
        if not exe_module_entry:
            raise RuntimeError(f"missing {JobProcessArgs.EXE_MODULE} in {FLContextKey.JOB_PROCESS_ARGS}")
        _, job_cmd = exe_module_entry

        workspace_root = args.workspace
        env = {}
        if app_custom_folder:
            # Re-root the custom folder under the in-pod workspace mount point.
            workspace_root_abs = os.path.abspath(workspace_root)
            custom_folder_abs = os.path.abspath(app_custom_folder)
            if os.path.commonpath([workspace_root_abs, custom_folder_abs]) != workspace_root_abs:
                raise RuntimeError(f"custom folder {app_custom_folder} is not under workspace {workspace_root}")
            env["PYTHONPATH"] = os.path.join(
                self.workspace_mount_path, os.path.relpath(custom_folder_abs, workspace_root_abs)
            )

        startup_dir = workspace_obj.get_startup_kit_dir()
        engine = fl_ctx.get_engine()
        owner_cell = getattr(engine, "cell", None) if engine else None
        if owner_cell is None:
            raise RuntimeError("missing parent CellNet cell for workspace transfer")
        workspace_transfer = WorkspaceTransferManager.get_or_create(owner_cell)
        workspace_transfer_token = workspace_transfer.add_job(raw_job_id, workspace_root)

        # Stays None until the handle is built, so the error path can tell the two cases apart.
        job_handle = None
        try:
            startup_secret_name = self._ensure_startup_secret(site_name, startup_dir)
            env[ENV_WORKSPACE_OWNER_FQCN] = workspace_transfer.owner_fqcn
            env[ENV_WORKSPACE_TRANSFER_TOKEN] = workspace_transfer_token

            volume_list = [
                {"name": "workspace-job", "emptyDir": {"sizeLimit": job_ephemeral_storage}},
                {"name": "startup-kit", "secret": {"secretName": startup_secret_name}},
            ]
            volume_mount_list = [
                {"name": "workspace-job", "mountPath": self.workspace_mount_path},
                {
                    "name": "startup-kit",
                    "mountPath": os.path.join(self.workspace_mount_path, "startup"),
                    "readOnly": True,
                },
            ]
            for dataset_mount in data_mounts:
                volume_name = study_dataset_volume_name(dataset_mount.study, dataset_mount.dataset)
                volume_list.append({"name": volume_name, "persistentVolumeClaim": {"claimName": dataset_mount.source}})
                volume_mount_list.append(
                    {
                        "name": volume_name,
                        "mountPath": dataset_mount.mount_path,
                        "readOnly": dataset_mount.read_only,
                    }
                )

            job_config = {
                "name": pod_name,
                "image": job_image,
                # Container names are limited to 63 characters, like the job id itself.
                "container_name": f"container-{job_id}"[:63].rstrip("-"),
                "command": job_cmd,
                "volume_mount_list": volume_mount_list,
                "volume_list": volume_list,
                "module_args": self.get_module_args(job_id, fl_ctx),
                "env": env,
            }
            if self.image_pull_secrets:
                job_config["image_pull_secrets"] = self.image_pull_secrets
            if getattr(args, "set", None) is not None:
                job_config["set_list"] = args.set
            job_config["resources"] = self._build_resources(k8s_spec, job_ephemeral_storage, job_resource)
            if self.security_context:
                job_config["security_context"] = self.security_context

            python_path = k8s_spec.get("python_path", self.default_python_path)
            if not isinstance(python_path, str) or not python_path:
                raise RuntimeError(f"launcher_spec['{site_name}']['k8s']['python_path'] must be a non-empty string")

            job_handle = K8sJobHandle(
                job_id,
                self.core_v1,
                job_config,
                namespace=self.namespace,
                timeout=self.timeout,
                pending_timeout=self.pending_timeout,
                python_path=python_path,
                workspace_transfer=workspace_transfer,
                workspace_job_id=raw_job_id,
                pod_name=pod_name,
            )
            pod_manifest = job_handle.get_manifest()
            self.logger.debug(
                "launch job with k8s_launcher: pod_name=%s namespace=%s image=%s",
                pod_manifest["metadata"]["name"],
                self.namespace,
                job_image,
            )
            self.core_v1.create_namespaced_pod(body=pod_manifest, namespace=self.namespace)
        except Exception as e:
            workspace_transfer.remove_job(raw_job_id)
            if job_handle is None:
                raise
            self.logger.error(f"failed to launch job {job_id}: {e}")
            # The registration was just removed; clear the id so the handle does not
            # try to remove it a second time.
            job_handle.workspace_job_id = ""
            job_handle.terminal_state = JobState.TERMINATED
            return job_handle

        try:
            entered_running = job_handle.enter_states([JobState.RUNNING])
        except BaseException:
            # Do not leave a pod behind when the wait is interrupted.
            job_handle.terminate()
            raise
        if not entered_running:
            self.logger.warning(f"unable to enter running phase {job_id}")
        return job_handle

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Register this launcher when a job is about to be launched."""
        if event_type == EventType.BEFORE_JOB_LAUNCH:
            add_launcher(self, fl_ctx)

    @abstractmethod
    def get_module_args(self, job_id, fl_ctx: FLContext):
        """Return the module arguments for the job process.

        Args:
            job_id: the (sanitized) job id.
            fl_ctx: FLContext

        Returns:
            A dict mapping argument flag to value; it is used as the ``module_args`` entry
            of the job config.
        """
        pass
def _job_args_dict(job_args: dict, arg_names: list) -> dict:
    """Collect the ``(flag, value)`` entries of the named job args into a dict.

    Args:
        job_args: mapping from arg name to a ``(flag, value)`` pair.
        arg_names: names to look up; names that are missing or map to None are skipped.

    Returns:
        A dict mapping each flag to its value, in ``arg_names`` order.
    """
    collected = {}
    for entry in map(job_args.get, arg_names):
        if entry is not None:
            flag, value = entry
            collected[flag] = value
    return collected
class ClientK8sJobLauncher(K8sJobLauncher):
    """K8s job launcher that builds module arguments from the client job arg names."""

    def get_module_args(self, _job_id, fl_ctx: FLContext):
        """Return the client job's module arguments as a flag-to-value dict.

        Raises:
            RuntimeError: if the FLContext has no job process args.
        """
        process_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
        if process_args:
            return _job_args_dict(process_args, get_client_job_args(False, False))
        raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")
class ServerK8sJobLauncher(K8sJobLauncher):
    """K8s job launcher that builds module arguments from the server job arg names."""

    def get_module_args(self, _job_id, fl_ctx: FLContext):
        """Return the server job's module arguments as a flag-to-value dict.

        Raises:
            RuntimeError: if the FLContext has no job process args.
        """
        process_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
        if process_args:
            return _job_args_dict(process_args, get_server_job_args(False, False))
        raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")