# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import pathlib
import shutil
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from nvflare.apis.client_engine_spec import ClientEngineSpec
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import (
Job,
JobDataKey,
JobMetaKey,
SubmitRecordKey,
SubmitRecordState,
job_from_meta,
new_job_id,
)
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec, RunStatus
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.storage import WORKSPACE, StorageException, StorageSpec
from nvflare.apis.utils.format_check import check_job_app_name, check_job_id
from nvflare.apis.utils.job_submit_token import (
canonical_job_content_hash,
canonical_json_hash,
submit_record_scope_hashes,
submitter_to_dict,
)
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes
# Store tag put on job objects whose status has moved past SUBMITTED, so that
# schedule scans passing it as ``skip_tag`` no longer read their meta.
_OBJ_TAG_SCHEDULED = "scheduled"
# Sidecar namespaces, created next to the job store root (not inside it):
# one for submit-token records, one for the per-job index of those records.
_SUBMIT_RECORD_URI_ROOT = "job_submit_records"
_SUBMIT_RECORD_JOB_INDEX_URI_ROOT = "job_submit_record_index"
# Meta key of a job index object that holds the list of submit-record URIs.
_SUBMIT_RECORD_URIS_KEY = "submit_record_uris"
class JobInfo:
    """Plain record handed to a job filter for each job found during a store scan.

    The constructor arguments are kept as-is: ``meta`` is the job's meta dict,
    ``job_id`` its ID and ``uri`` the store URI of the job object.
    """

    def __init__(self, meta: dict, job_id: str, uri: str):
        self.job_id = job_id
        self.uri = uri
        self.meta = meta
class _JobFilter(ABC):
    """Visitor interface applied to each job found by ``SimpleJobDefManager._scan``."""

    @abstractmethod
    def filter_job(self, info: JobInfo) -> bool:
        """Inspect one scanned job; return a truthy value to keep scanning, falsy to stop."""
        ...
class _StatusFilter(_JobFilter):
    """Collect jobs whose meta status is one of the requested statuses."""

    def __init__(self, status_to_check):
        """
        Args:
            status_to_check: a single status, or a list/tuple/set of statuses, compared
                against each job's ``JobMetaKey.STATUS`` meta value.
        """
        self.result = []
        if isinstance(status_to_check, (list, tuple, set, frozenset)):
            # Any collection of statuses is accepted. Previously a tuple was wrapped
            # as a single element and therefore never matched anything.
            status_to_check = list(status_to_check)
        else:
            # a single status value: wrap it so the membership test below is uniform
            status_to_check = [status_to_check]
        self.status_to_check = status_to_check

    def filter_job(self, info: JobInfo):
        status = info.meta.get(JobMetaKey.STATUS.value)
        if status in self.status_to_check:
            self.result.append(job_from_meta(info.meta))
        # never stop the scan early
        return True
class _AllJobsFilter(_JobFilter):
    """Collect every scanned job regardless of its meta."""

    def __init__(self):
        self.result = []

    def filter_job(self, info: JobInfo):
        job = job_from_meta(info.meta)
        self.result.append(job)
        # every job is wanted, so the scan always continues
        return True
class _ReviewerFilter(_JobFilter):
    """Collect jobs that have no approvals entry for the given reviewer yet."""

    def __init__(self, reviewer_name):
        """Not used yet, for use in future implementations."""
        self.result = []
        self.reviewer_name = reviewer_name

    def filter_job(self, info: JobInfo):
        # Look up by the string key, like every other meta access in this module.
        # The bare enum member only matched when JobMetaKey is a str-mixin enum.
        approvals = info.meta.get(JobMetaKey.APPROVALS.value)
        if not approvals or self.reviewer_name not in approvals:
            self.result.append(job_from_meta(info.meta))
        return True
class _ScheduleJobFilter(_JobFilter):
    """Collect SUBMITTED jobs for the scheduler.

    This filter is optimized because it is used so frequently (every 1 sec): a job whose
    status is set to anything other than SUBMITTED is tagged in the store, so scans that
    skip that tag will not read its meta file again.
    """

    def __init__(self, store):
        self.result = []
        self.store = store

    def filter_job(self, info: JobInfo):
        status = info.meta.get(JobMetaKey.STATUS.value)
        if status == RunStatus.SUBMITTED.value:
            self.result.append(job_from_meta(info.meta))
        elif status:
            # past SUBMITTED: exclude this job from all future schedule scans
            self.store.tag_object(uri=info.uri, tag=_OBJ_TAG_SCHEDULED)
        return True
class SimpleJobDefManager(JobDefManagerSpec):
    """Job definition manager that keeps jobs in a ``StorageSpec`` job store component.

    Each job is one store object at ``<uri_root>/<job_id>`` whose meta is the job meta.
    Submit-token records, and a per-job index of those records, live in two sidecar
    namespaces that are siblings of ``uri_root``.
    """

    def __init__(self, uri_root: str = "jobs", job_store_id: str = "job_store"):
        """Create the manager.

        Args:
            uri_root: root URI of job objects. It is overridden by the ``NVFL_JOB_STORE_ROOT``
                environment variable when that variable is set and non-empty.
            job_store_id: component ID used to fetch the job store from the engine.
        """
        super().__init__()
        self.uri_root = uri_root
        # if env var is defined, use it to override uri_root!
        job_store_root = os.environ.get("NVFL_JOB_STORE_ROOT")
        if job_store_root:
            self.uri_root = job_store_root
        # Create the root that is actually used (the env override when present),
        # not the default argument.
        os.makedirs(self.uri_root, exist_ok=True)
        self.job_store_id = job_store_id
        # Submit-token records are a sidecar namespace beside the job store so they are not
        # enumerated as jobs by stores that scan uri_root directly.
        normalized_root = self.uri_root.rstrip(os.sep) or self.uri_root
        sidecar_parent = os.path.dirname(normalized_root)
        self.submit_record_uri_root = os.path.join(sidecar_parent, _SUBMIT_RECORD_URI_ROOT)
        self.submit_record_job_index_uri_root = os.path.join(sidecar_parent, _SUBMIT_RECORD_JOB_INDEX_URI_ROOT)
        # In-process lock serializing submit-record creation, updates and delete-marking.
        self._submit_record_lock = threading.Lock()

    def _get_job_store(self, fl_ctx):
        """Return the job store component of the engine held by ``fl_ctx``.

        Raises:
            TypeError: if the engine is neither a server nor a client engine, or the
                component registered under ``job_store_id`` is not a StorageSpec.
        """
        engine = fl_ctx.get_engine()
        if not isinstance(engine, (ServerEngineSpec, ClientEngineSpec)):
            raise TypeError(f"engine should be of type ServerEngineSpec or ClientEngineSpec, but got {type(engine)}")
        store = engine.get_component(self.job_store_id)
        if not isinstance(store, StorageSpec):
            raise TypeError(f"engine should have a job store component of type StorageSpec, but got {type(store)}")
        return store

    def job_uri(self, jid: str):
        """Return the store URI of job ``jid`` after validating the ID with ``check_job_id``."""
        check_job_id(jid)
        return os.path.join(self.uri_root, jid)

    def submit_record_uri(self, study: str, submitter, submit_token: str):
        """Return the store URI of the submit record scoped by study, submitter and token.

        Each scope component appears in the path only as the hash produced by
        ``submit_record_scope_hashes``.
        """
        study_hash, submitter_hash, submit_token_hash = submit_record_scope_hashes(study, submitter, submit_token)
        return os.path.join(self.submit_record_uri_root, study_hash, submitter_hash, submit_token_hash)

    def _submit_record_uri_from_record(self, record: dict):
        """Return the store URI of ``record``, computed from its study, submitter and token fields."""
        submitter = {
            "name": record.get(SubmitRecordKey.SUBMITTER_NAME.value, ""),
            "org": record.get(SubmitRecordKey.SUBMITTER_ORG.value, ""),
            "role": record.get(SubmitRecordKey.SUBMITTER_ROLE.value, ""),
        }
        return self.submit_record_uri(
            record.get(SubmitRecordKey.STUDY.value, ""),
            submitter,
            record.get(SubmitRecordKey.SUBMIT_TOKEN.value),
        )

    def _submit_record_job_index_uri(self, job_id: str) -> str:
        """Return the URI of the index object that lists the submit records of ``job_id``."""
        return os.path.join(self.submit_record_job_index_uri_root, canonical_json_hash(job_id or ""))

    def _upsert_submit_record_job_index(self, store: StorageSpec, record: dict):
        """Add the URI of ``record`` to its job's index object, creating the index when absent.

        Records without a job ID are not indexed. Both callers in this class invoke this
        while holding ``_submit_record_lock``.
        """
        job_id = record.get(SubmitRecordKey.JOB_ID.value)
        if not job_id:
            return
        index_uri = self._submit_record_job_index_uri(job_id)
        record_uri = self._submit_record_uri_from_record(record)

        def merged_meta(existing_meta: dict) -> dict:
            # Index meta = job ID + existing record URIs, with record_uri appended once.
            uris = list(existing_meta.get(_SUBMIT_RECORD_URIS_KEY, []))
            if record_uri not in uris:
                uris.append(record_uri)
            return {SubmitRecordKey.JOB_ID.value: job_id, _SUBMIT_RECORD_URIS_KEY: uris}

        try:
            index_meta = store.get_meta(index_uri) or {}
        except StorageException:
            index_meta = {}
        if index_meta:
            updated_meta = merged_meta(index_meta)
            # skip the write when the index already has exactly this content
            if updated_meta != index_meta:
                store.update_meta(index_uri, updated_meta, replace=True)
            return
        try:
            store.create_object(index_uri, b"", merged_meta({}), overwrite_existing=False)
        except StorageException:
            # The index object already exists, e.g. it was written after our read. Merge into it.
            store.update_meta(index_uri, merged_meta(store.get_meta(index_uri) or {}), replace=True)

    def get_job_content_hash(self, uploaded_content: Union[str, bytes]) -> str:
        """Return the canonical content hash of an uploaded job."""
        return canonical_job_content_hash(uploaded_content)

    def get_submit_record(self, study: str, submitter, submit_token: str, fl_ctx: FLContext) -> Optional[dict]:
        """Return the submit record's meta, or None when the store cannot read it."""
        store = self._get_job_store(fl_ctx)
        try:
            return store.get_meta(self.submit_record_uri(study, submitter, submit_token))
        except StorageException:
            return None

    def create_submit_record(self, record: dict, fl_ctx: FLContext) -> bool:
        """Create a submit record and index it under its job ID.

        Returns:
            True if the record was created; False if a record already exists at its URI.

        Raises:
            StorageException: if creation fails and no existing record can be read back.
        """
        store = self._get_job_store(fl_ctx)
        uri = self._submit_record_uri_from_record(record)
        with self._submit_record_lock:
            try:
                store.create_object(uri, b"", record, overwrite_existing=False)
            except StorageException:
                try:
                    if store.get_meta(uri):
                        return False
                except StorageException:
                    pass
                raise
            self._upsert_submit_record_job_index(store, record)
        return True

    def update_submit_record(self, record: dict, fl_ctx: FLContext) -> dict:
        """Replace the stored meta of ``record``, refresh its job index entry, and return it."""
        store = self._get_job_store(fl_ctx)
        uri = self._submit_record_uri_from_record(record)
        with self._submit_record_lock:
            store.update_meta(uri, record, replace=True)
            self._upsert_submit_record_job_index(store, record)
        return record

    def mark_submit_records_job_deleted(self, job_id: str, deleted_by, fl_ctx: FLContext) -> List[dict]:
        """Mark every indexed submit record of ``job_id`` as JOB_DELETED.

        Records that are unreadable, belong to another job, or are already marked are skipped.

        Returns:
            The records that were updated by this call (empty if the job has no index).
        """
        store = self._get_job_store(fl_ctx)
        index_uri = self._submit_record_job_index_uri(job_id)
        deleted_by_info = submitter_to_dict(deleted_by)
        deleted_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
        updated_records = []
        with self._submit_record_lock:
            try:
                index_meta = store.get_meta(index_uri) or {}
            except StorageException:
                return []
            for record_uri in index_meta.get(_SUBMIT_RECORD_URIS_KEY, []):
                try:
                    record = store.get_meta(record_uri)
                except StorageException:
                    continue
                if not record or record.get(SubmitRecordKey.JOB_ID.value) != job_id:
                    continue
                if record.get(SubmitRecordKey.STATE.value) == SubmitRecordState.JOB_DELETED.value:
                    continue
                record[SubmitRecordKey.STATE.value] = SubmitRecordState.JOB_DELETED.value
                record[SubmitRecordKey.DELETED_TIME.value] = deleted_time
                record[SubmitRecordKey.DELETED_BY.value] = deleted_by_info
                store.update_meta(record_uri, record, replace=True)
                updated_records.append(record)
        return updated_records

    def get_job_by_submit_token(self, study: str, submitter, submit_token: str, fl_ctx: FLContext) -> Optional[Job]:
        """Return the job referenced by a submit record, or None when there is no such record/job ID."""
        record = self.get_submit_record(study, submitter, submit_token, fl_ctx)
        if not record:
            return None
        jid = record.get(SubmitRecordKey.JOB_ID.value)
        if not jid:
            return None
        return self.get_job(jid, fl_ctx)

    @staticmethod
    def new_submit_record(
        study: str,
        submitter,
        submit_token: str,
        job_content_hash: str,
        job_name: str = "",
        job_folder_name: str = "",
        job_id: str = None,
        state: str = SubmitRecordState.CREATING.value,
    ) -> dict:
        """Build a new (schema version 1) submit record dict; a fresh job ID is generated if none is given."""
        submitter_info = submitter_to_dict(submitter)
        return {
            SubmitRecordKey.SCHEMA_VERSION.value: 1,
            SubmitRecordKey.STATE.value: state,
            SubmitRecordKey.SUBMIT_TOKEN.value: submit_token,
            SubmitRecordKey.JOB_ID.value: job_id or new_job_id(),
            SubmitRecordKey.STUDY.value: study,
            SubmitRecordKey.SUBMITTER_NAME.value: submitter_info["name"],
            SubmitRecordKey.SUBMITTER_ORG.value: submitter_info["org"],
            SubmitRecordKey.SUBMITTER_ROLE.value: submitter_info["role"],
            SubmitRecordKey.JOB_NAME.value: job_name,
            SubmitRecordKey.JOB_FOLDER_NAME.value: job_folder_name,
            SubmitRecordKey.JOB_CONTENT_HASH.value: job_content_hash,
            SubmitRecordKey.SUBMIT_TIME.value: datetime.datetime.now().astimezone().isoformat(),
        }

    def create(self, meta: dict, uploaded_content: Union[str, bytes], fl_ctx: FLContext) -> Dict[str, Any]:
        """Store a new SUBMITTED job and return its (mutated) meta.

        The submit token is removed from ``meta``. A job ID is generated when ``meta`` has none.
        Submit time, start time, duration, storage format and status are filled in.
        """
        meta.pop(SubmitRecordKey.SUBMIT_TOKEN.value, None)
        # validate meta to make sure it has:
        jid = meta.get(JobMetaKey.JOB_ID.value, None)
        if not jid:
            jid = new_job_id()
            meta[JobMetaKey.JOB_ID.value] = jid
        else:
            check_job_id(jid)
        now = time.time()
        meta[JobMetaKey.SUBMIT_TIME.value] = now
        meta[JobMetaKey.SUBMIT_TIME_ISO.value] = datetime.datetime.fromtimestamp(now).astimezone().isoformat()
        meta[JobMetaKey.START_TIME.value] = ""
        meta[JobMetaKey.DURATION.value] = "N/A"
        # format 2: the uploaded content is stored as-is (see get_content)
        meta[JobMetaKey.DATA_STORAGE_FORMAT.value] = 2
        meta[JobMetaKey.STATUS.value] = RunStatus.SUBMITTED.value
        # write it to the store
        store = self._get_job_store(fl_ctx)
        store.create_object(self.job_uri(jid), uploaded_content, meta, overwrite_existing=False)
        return meta

    def clone(self, from_jid: str, meta: dict, fl_ctx: FLContext) -> Dict[str, Any]:
        """Clone job ``from_jid`` into a new SUBMITTED job described by ``meta``; return that meta."""
        check_job_id(from_jid)
        jid = meta.get(JobMetaKey.JOB_ID.value, None)
        if not jid:
            jid = new_job_id()
            meta[JobMetaKey.JOB_ID.value] = jid
        else:
            check_job_id(jid)
        now = time.time()
        meta[JobMetaKey.SUBMIT_TIME.value] = now
        meta[JobMetaKey.SUBMIT_TIME_ISO.value] = datetime.datetime.fromtimestamp(now).astimezone().isoformat()
        meta[JobMetaKey.START_TIME.value] = ""
        meta[JobMetaKey.DURATION.value] = "N/A"
        meta[JobMetaKey.STATUS.value] = RunStatus.SUBMITTED.value
        # write it to the store
        store = self._get_job_store(fl_ctx)
        store.clone_object(
            from_uri=self.job_uri(from_jid), to_uri=self.job_uri(jid), meta=meta, overwrite_existing=False
        )
        return meta

    def delete(self, jid: str, fl_ctx: FLContext):
        """Delete the job object of ``jid`` from the store."""
        store = self._get_job_store(fl_ctx)
        store.delete_object(self.job_uri(jid))

    def _validate_meta(self, meta):
        """Validate meta (not implemented yet; currently a no-op).

        Args:
            meta: meta to validate
        """
        pass

    def _validate_uploaded_content(self, uploaded_content) -> bool:
        """Validate uploaded content for creating a run config. (THIS NEEDS TO HAPPEN BEFORE CONTENT IS PROVIDED NOW)

        Not implemented yet; currently a no-op. Intended for use by create and update:

        1. check all sites in deployment are in resources
        2. each site in deployment need to have resources (each site in resource need to be in deployment ???)
        """
        pass

    def get_job(self, jid: str, fl_ctx: FLContext) -> Optional[Job]:
        """Return the job built from its stored meta, or None when the store cannot read it."""
        store = self._get_job_store(fl_ctx)
        try:
            job_meta = store.get_meta(self.job_uri(jid))
            return job_from_meta(job_meta)
        except StorageException:
            return None

    def set_results_uri(self, jid: str, result_uri: str, fl_ctx: FLContext):
        """Record the result location in the job's meta and return the refreshed job."""
        store = self._get_job_store(fl_ctx)
        updated_meta = {JobMetaKey.RESULT_LOCATION.value: result_uri}
        store.update_meta(self.job_uri(jid), updated_meta, replace=False)
        return self.get_job(jid, fl_ctx)

    def get_app(self, job: Job, app_name: str, fl_ctx: FLContext) -> bytes:
        """Return the folder of app ``app_name`` inside ``job`` as zip bytes.

        Raises:
            ValueError: if the job folder or the app path resolves (after following symlinks
                and ``..``) outside the extracted job data.
        """
        check_job_id(job.job_id)
        check_job_app_name(app_name)
        with tempfile.TemporaryDirectory() as temp_dir:
            job_id_dir = self._load_job_data_from_store(job, temp_dir, fl_ctx)
            job_folder = os.path.join(job_id_dir, job.meta[JobMetaKey.JOB_FOLDER_NAME.value])
            fullpath_src = os.path.join(job_folder, app_name)
            # Compare resolved real paths so symlinks/".." inside the job data cannot escape.
            job_id_dir_real = os.path.realpath(job_id_dir)
            job_folder_real = os.path.realpath(job_folder)
            fullpath_src_real = os.path.realpath(fullpath_src)
            if os.path.commonpath([job_id_dir_real, job_folder_real]) != job_id_dir_real:
                raise ValueError(f"job folder for app '{app_name}' escapes job data folder")
            if os.path.commonpath([job_folder_real, fullpath_src_real]) != job_folder_real:
                raise ValueError(f"app '{app_name}' escapes job folder")
            result = zip_directory_to_bytes(fullpath_src_real, "")
        return result

    def _load_job_data_from_store(self, job: Job, temp_dir: str, fl_ctx: FLContext):
        """Unzip the job's content into ``<temp_dir>/<job_id>`` (recreated fresh) and return that path.

        Raises:
            RuntimeError: if the store returns no content for the job.
        """
        check_job_id(job.job_id)
        data_bytes = self.get_content(job.meta, fl_ctx)
        if data_bytes is None:
            # get_content returns None when the store has no data for the job. Fail here
            # with a clear message instead of passing None to the unzip helper.
            raise RuntimeError(f"no content found in job store for job {job.job_id}")
        job_id_dir = os.path.join(temp_dir, job.job_id)
        if os.path.exists(job_id_dir):
            shutil.rmtree(job_id_dir)
        os.mkdir(job_id_dir)
        unzip_all_from_bytes(data_bytes, job_id_dir)
        return job_id_dir

    def get_content(self, meta: dict, fl_ctx: FLContext) -> Optional[bytes]:
        """Return the job's uploaded content, or None when the store cannot read it.

        Jobs whose meta sets ``DATA_STORAGE_FORMAT`` store the content raw. In the old format
        the stored data is fobs-serialized and the content is its ``JOB_DATA`` entry.

        Raises:
            RuntimeError: if ``meta`` has no job ID.
        """
        store = self._get_job_store(fl_ctx)
        jid = meta.get(JobMetaKey.JOB_ID.value)
        if not jid:
            raise RuntimeError("no Job ID in meta")
        try:
            stored_data = store.get_data(self.job_uri(jid))
            storage_format = meta.get(JobMetaKey.DATA_STORAGE_FORMAT.value)
            if storage_format:
                # new format
                return stored_data
            else:
                # old format
                return fobs.loads(stored_data).get(JobDataKey.JOB_DATA.value)
        except StorageException:
            return None

    def set_client_data(self, jid: str, data: Union[bytes, str], client_name: str, data_type: str, fl_ctx: FLContext):
        """Store per-client data as the ``<data_type>_<client_name>`` component of the job."""
        store = self._get_job_store(fl_ctx)
        data_object_type = f"{data_type}_{client_name}"
        store.update_object(self.job_uri(jid), data, data_object_type)

    def get_client_data(self, jid: str, client_name: str, data_type: str, fl_ctx: FLContext) -> Optional[bytes]:
        """Return the ``<data_type>_<client_name>`` component of the job, or None if unreadable."""
        store = self._get_job_store(fl_ctx)
        data_object_type = f"{data_type}_{client_name}"
        try:
            data_data = store.get_data(self.job_uri(jid), data_object_type)
            return data_data
        except StorageException:
            return None

    def list_components(self, jid: str, fl_ctx: FLContext) -> List[str]:
        """Return the component names stored for the job."""
        store = self._get_job_store(fl_ctx)
        # Query the store once and reuse the result for both the debug log and the return value.
        components = store.list_components_of_object(self.job_uri(jid))
        self.log_debug(fl_ctx, f"list_components called for {jid}: {components}")
        return components

    @staticmethod
    def _parse_start_time(value) -> Optional[datetime.datetime]:
        """Parse a START_TIME meta value written by ``set_status``; return None when it is empty.

        START_TIME is written as ``str(datetime)``, which omits the fractional seconds when
        the microsecond is 0, so both forms are accepted.

        Raises:
            ValueError: if the value matches neither form.
        """
        if not value:
            return None
        try:
            return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
        except ValueError:
            return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")

    def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
        """Update the job's status in its meta.

        Moving to RUNNING records START_TIME. Moving to a FINISHED_* status records DURATION,
        measured from START_TIME, when a start time was recorded.
        """
        meta = {JobMetaKey.STATUS.value: status.value}
        store = self._get_job_store(fl_ctx)
        # ``status`` is a RunStatus member (its ``.value`` is read above), so compare members directly.
        if status == RunStatus.RUNNING:
            meta[JobMetaKey.START_TIME.value] = str(datetime.datetime.now())
        elif status in (
            RunStatus.FINISHED_ABORTED,
            RunStatus.FINISHED_COMPLETED,
            RunStatus.FINISHED_EXECUTION_EXCEPTION,
            RunStatus.FINISHED_CANT_SCHEDULE,
        ):
            job_meta = store.get_meta(self.job_uri(jid)) or {}
            start_time = self._parse_start_time(job_meta.get(JobMetaKey.START_TIME.value))
            if start_time is not None:
                meta[JobMetaKey.DURATION.value] = str(datetime.datetime.now() - start_time)
        store.update_meta(uri=self.job_uri(jid), meta=meta, replace=False)

    def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
        """Return every job in the store."""
        job_filter = _AllJobsFilter()
        self._scan(job_filter, fl_ctx)
        return job_filter.result

    def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]:
        """Return SUBMITTED jobs, skipping job objects already tagged as scheduled."""
        job_filter = _ScheduleJobFilter(self._get_job_store(fl_ctx))
        self._scan(job_filter, fl_ctx, skip_tag=_OBJ_TAG_SCHEDULED)
        return job_filter.result

    def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext, skip_tag=None):
        """Feed each job object under ``uri_root`` (minus those tagged ``skip_tag``) to ``job_filter``.

        Jobs whose meta is empty or unreadable are skipped. The scan stops early when the
        filter returns a falsy value.
        """
        store = self._get_job_store(fl_ctx)
        obj_uris = store.list_objects(self.uri_root, without_tag=skip_tag) or []
        self.log_debug(fl_ctx, f"objects to scan: {len(obj_uris)}")
        for uri in obj_uris:
            jid = pathlib.PurePath(uri).name
            job_uri = self.job_uri(jid)
            try:
                meta = store.get_meta(job_uri)
            except StorageException:
                # e.g. the job was deleted after list_objects returned; one unreadable
                # job should not abort the whole scan
                continue
            if meta and not job_filter.filter_job(JobInfo(meta, jid, job_uri)):
                break

    def get_jobs_by_status(self, status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]:
        """Get jobs that are in the specified status

        Args:
            status: a single status value or a list of status values
            fl_ctx: the FL context

        Returns: list of jobs that are in specified status
        """
        job_filter = _StatusFilter(status)
        self._scan(job_filter, fl_ctx)
        return job_filter.result

    def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Job]:
        """Return jobs with no approvals entry for ``reviewer_name``."""
        job_filter = _ReviewerFilter(reviewer_name)
        self._scan(job_filter, fl_ctx)
        return job_filter.result

    def set_approval(
        self, jid: str, reviewer_name: str, approved: bool, note: str, fl_ctx: FLContext
    ) -> Dict[str, Any]:
        """Record ``(approved, note)`` for ``reviewer_name`` in the job's approvals.

        Returns:
            The job's meta with the updated approvals, or None when the job cannot be read.
        """
        job = self.get_job(jid, fl_ctx)
        # get_job returns None for a missing/unreadable job; previously this raised AttributeError.
        meta = job.meta if job else None
        if meta:
            approvals = meta.get(JobMetaKey.APPROVALS.value)
            if not approvals:
                approvals = {}
                meta[JobMetaKey.APPROVALS.value] = approvals
            approvals[reviewer_name] = (approved, note)
            updated_meta = {JobMetaKey.APPROVALS.value: approvals}
            store = self._get_job_store(fl_ctx)
            store.update_meta(self.job_uri(jid), updated_meta, replace=False)
        return meta

    def save_workspace(self, jid: str, data: Union[bytes, str, List[str]], fl_ctx: FLContext):
        """Store ``data`` as the job's WORKSPACE component and return the store's result."""
        store = self._get_job_store(fl_ctx)
        return store.update_object(self.job_uri(jid), data, WORKSPACE)

    def get_storage_component(self, jid: str, component: str, fl_ctx: FLContext):
        """Return the raw data of the named component of the job."""
        store = self._get_job_store(fl_ctx)
        return store.get_data(self.job_uri(jid), component)

    def get_storage_for_download(
        self, jid: str, download_dir: str, component: str, download_file: str, fl_ctx: FLContext
    ):
        """Prepares the specified component of the job for download at the specified directory

        The component is prepared for download at download_dir/jid/download_file.

        Args:
            jid: job ID
            download_dir: directory to download the component to
            component: component name
            download_file: file name to save the downloaded component
            fl_ctx: FLContext
        """
        store = self._get_job_store(fl_ctx)
        # job_uri validates jid before it is used to build a local path below
        job_uri = self.job_uri(jid)
        os.makedirs(os.path.join(download_dir, jid), exist_ok=True)
        destination_file = os.path.join(download_dir, jid, download_file)
        store.get_data_for_download(job_uri, component, destination_file)