# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from enum import Enum
from typing import Dict, List, Optional
from nvflare.apis.fl_constant import SystemComponents
from nvflare.apis.fl_context import FLContext
# Special site name: when it appears in a job's deploy_map it is treated as
# "all online sites" rather than as the name of one specific site.
ALL_SITES = "@ALL"
# Site name "server"; presumably the deploy_map entry that denotes the FL
# server (confirm against callers).
SERVER_SITE_NAME = "server"
class RunStatus(str, Enum):
    """String-valued enumeration of the run statuses a job can be in.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value. Every member from ``FINISHED_COMPLETED`` onwards
    (including ``FAILED_TO_RUN`` and ``ABANDONED``) has a value that starts
    with the ``"FINISHED:"`` prefix.
    """

    SUBMITTED = "SUBMITTED"
    APPROVED = "APPROVED"
    DISPATCHED = "DISPATCHED"
    RUNNING = "RUNNING"
    FINISHED_COMPLETED = "FINISHED:COMPLETED"
    FINISHED_ABORTED = "FINISHED:ABORTED"
    FINISHED_EXECUTION_EXCEPTION = "FINISHED:EXECUTION_EXCEPTION"
    FINISHED_ABNORMAL = "FINISHED:ABNORMAL"
    # NOTE: the member name says CANT_SCHEDULE but the stored value is CAN_NOT_SCHEDULE.
    FINISHED_CANT_SCHEDULE = "FINISHED:CAN_NOT_SCHEDULE"
    FAILED_TO_RUN = "FINISHED:FAILED_TO_RUN"
    ABANDONED = "FINISHED:ABANDONED"
class JobDataKey(str, Enum):
    """String-valued keys related to stored job data.

    Each member compares equal to its plain string value. The two members
    whose values end with an underscore look like key prefixes.
    NOTE(review): confirm how ``JOB_DATA`` / ``WORKSPACE_DATA`` are combined
    with other text by their callers.
    """

    DATA = "data"
    META = "meta"
    JOB_DATA = "job_data_"
    WORKSPACE_DATA = "workspace_data_"
class TopDir:
    """Plain string constants ``"job"`` and ``"workspace"``.

    NOTE(review): presumably the names of top-level directories in a job's
    stored layout; confirm against the code that uses them.
    """

    JOB = "job"
    WORKSPACE = "workspace"
class Job:
    """In-memory description of a job: its ID, deployment map, resources and metadata.

    Two ``Job`` objects compare equal when their ``job_id`` values are equal,
    and they hash on ``job_id`` so that equality and hashing stay consistent.
    """

    def __init__(
        self,
        job_id: str,
        resource_spec: Dict[str, Dict],
        deploy_map: Dict[str, List[str]],
        meta,
        min_sites: int = 1,
        required_sites: Optional[List[str]] = None,
    ):
        """Job object containing the job metadata.

        Args:
            job_id: Job ID
            resource_spec: Resource specification with information on the resources of each client
            deploy_map: Deploy map specifying each app and the sites that it should be deployed to
            meta: full contents of the persisted metadata for the job for persistent storage
            min_sites (int): minimum number of sites
            required_sites: A list of required site names
        """
        self.job_id = job_id
        self.resource_spec = resource_spec  # should be {site name: resource}
        self.deploy_map = deploy_map  # should be {app name: a list of sites}
        self.meta = meta
        self.min_sites = min_sites
        # None (or any other falsy value) is normalized to a fresh empty list.
        self.required_sites = required_sites or []

        # Dispatch/run bookkeeping; all of it starts out unset here.
        self.dispatcher_id = None
        self.dispatch_time = None
        self.submit_time = None
        self.run_record = None  # job id, dispatched time/UUID, finished time, completion code (normal, aborted)
        self.run_aborted = False

    def get_deployment(self) -> Dict[str, List[str]]:
        """Returns the deployment configuration.

        ::

            "deploy_map": {
                "hello-numpy-sag-server": [
                    "server"
                ],
                "hello-numpy-sag-client": [
                    "client1",
                    "client2"
                ],
                "hello-numpy-sag-client3": [
                    "client3"
                ]
            },

        Returns:
            Contents of deploy_map as a dictionary of strings of app names with their corresponding sites
        """
        return self.deploy_map

    def get_application(self, app_name, fl_ctx: "FLContext") -> bytes:
        """Get the content of the named application via the job manager component.

        Args:
            app_name: name of the application to fetch
            fl_ctx: FLContext whose engine supplies the job manager component

        Returns:
            Whatever the job manager's ``get_app(job, app_name, fl_ctx)`` returns
            (annotated as bytes).
        """
        engine = fl_ctx.get_engine()
        # The component is expected to be a JobDefManagerSpec; that type is not
        # checked here.
        job_def_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        return job_def_manager.get_app(self, app_name, fl_ctx)

    def get_application_name(self, participant):
        """Get the application name for the specified participant.

        Args:
            participant: site name to look up in the deploy map

        Returns:
            The first app (in deploy_map iteration order) whose site list
            contains ``participant``, or None when no app lists it.
        """
        for app_name, sites in self.deploy_map.items():
            if participant in sites:
                return app_name
        return None

    def get_resource_requirements(self):
        """Returns app resource requirements.

        Returns:
            A dict of {site_name: resource}
        """
        return self.resource_spec

    def __eq__(self, other):
        """Jobs are equal when their job IDs are equal."""
        if not isinstance(other, Job):
            # Let Python try the reflected comparison instead of failing with
            # AttributeError on objects that have no ``job_id``.
            return NotImplemented
        return self.job_id == other.job_id

    def __hash__(self):
        """Hash on ``job_id`` so hashing agrees with ``__eq__``."""
        return hash(self.job_id)
def new_job_id() -> str:
    """Generate a new job ID.

    Returns:
        The string form of a freshly generated random (version 4) UUID.
    """
    return str(uuid.uuid4())
def is_valid_job_id(jid: str) -> bool:
    """Check whether ``jid`` is a well-formed version 4 UUID string.

    Args:
        jid: candidate job ID

    Returns:
        True only when ``jid`` is a str that parses as a UUID and whose hex
        digits (dashes removed) are exactly the hex of the parsed version 4
        UUID; False otherwise, including for non-str input.
    """
    if not isinstance(jid, str):
        return False

    try:
        parsed = uuid.UUID(jid, version=4)
    except ValueError:
        return False

    # If the jid string is a valid hex code but an invalid uuid4, UUID.__init__
    # silently rewrites the version/variant bits to produce a valid uuid4.
    # Comparing the normalized hex with the input detects that rewrite.
    return parsed.hex == jid.replace("-", "")
def get_custom_prop(meta: dict, prop_key: str, default=None):
    """Look up one custom property in a job's meta dict.

    NOTE(review): ``JobMetaKey`` is neither defined nor imported in the part
    of this module visible here; confirm it is in scope, otherwise every call
    raises NameError.

    Args:
        meta: job meta dict
        prop_key: name of the custom property to read
        default: value returned when the property cannot be found

    Returns:
        The value stored under ``prop_key`` in the custom-props mapping, or
        ``default`` when that mapping is missing or empty, or when it has no
        such key.
    """
    custom_props = meta.get(JobMetaKey.CUSTOM_PROPS)
    if not custom_props:
        return default
    return custom_props.get(prop_key, default)
def get_custom_props(meta: dict, default=None):
    """Return the whole custom-props entry of a job's meta dict.

    NOTE(review): ``JobMetaKey`` is neither defined nor imported in the part
    of this module visible here; confirm it is in scope, otherwise every call
    raises NameError.

    Args:
        meta: job meta dict
        default: value returned when meta has no custom-props entry

    Returns:
        The value stored under ``JobMetaKey.CUSTOM_PROPS``, or ``default``
        when that key is absent.
    """
    return meta.get(JobMetaKey.CUSTOM_PROPS, default)