Source code for nvflare.app_common.job_schedulers.job_scheduler

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Dict, List, Optional, Tuple

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import ALL_SITES, Job, JobMetaKey
from nvflare.apis.job_scheduler_spec import DispatchInfo, JobSchedulerSpec
from nvflare.apis.server_engine_spec import ServerEngineSpec

SERVER_SITE_NAME = "server"


class DefaultJobScheduler(JobSchedulerSpec, FLComponent):
    def __init__(
        self,
        max_jobs: int = 1,
    ):
        super().__init__()
        self.max_jobs = max_jobs
        self.scheduled_jobs = []
        self.lock = threading.Lock()

    def _check_client_resources(self, resource_reqs: Dict[str, dict], fl_ctx: FLContext) -> Dict[str, Tuple[bool, str]]:
        """Checks resources on each site.

        Args:
            resource_reqs (dict): {client_name: resource_requirements}

        Returns:
            A dict of {client_name: client_check_result} where client_check_result
            is a tuple of (client check OK, resource reserve token if any)
        """
        engine = fl_ctx.get_engine()
        if not isinstance(engine, ServerEngineSpec):
            raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")
        result = engine.check_client_resources(resource_reqs)
        self.log_debug(fl_ctx, f"check client resources result: {result}")
        return result

    def _cancel_resources(
        self, resource_reqs: Dict[str, dict], resource_check_results: Dict[str, Tuple[bool, str]], fl_ctx: FLContext
    ):
        """Cancels any reserved resources based on resource check results.

        Args:
            resource_reqs (dict): {client_name: resource_requirements}
            resource_check_results: A dict of {client_name: client_check_result} where client_check_result
                is a tuple of (client check OK, resource reserve token if any)
            fl_ctx: FL context
        """
        engine = fl_ctx.get_engine()
        if not isinstance(engine, ServerEngineSpec):
            raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")
        engine.cancel_client_resources(resource_check_results, resource_reqs)
        self.log_debug(fl_ctx, f"cancel client resources using check results: {resource_check_results}")
        return False, None

    def _try_job(self, job: Job, fl_ctx) -> (bool, Optional[Dict[str, DispatchInfo]]):
        engine = fl_ctx.get_engine()
        online_clients = engine.get_clients()
        online_site_names = [x.name for x in online_clients]
        if not job.deploy_map:
            raise RuntimeError(f"Job ({job.job_id}) does not have deploy_map, can't be scheduled.")

        applicable_sites = []
        sites_to_app = {}
        for app_name in job.deploy_map:
            for site_name in job.deploy_map[app_name]:
                if site_name.upper() == ALL_SITES:
                    # deploy_map: {"app_name": ["ALL_SITES"]} will be treated as deploying to all online clients
                    applicable_sites = online_site_names
                    sites_to_app = {x: app_name for x in online_site_names}
                    sites_to_app[SERVER_SITE_NAME] = app_name
                elif site_name in online_site_names:
                    applicable_sites.append(site_name)
                    sites_to_app[site_name] = app_name
                elif site_name == SERVER_SITE_NAME:
                    sites_to_app[SERVER_SITE_NAME] = app_name
        self.log_debug(fl_ctx, f"Job {job.job_id} is checking against applicable sites: {applicable_sites}")

        required_sites = job.required_sites if job.required_sites else []
        if required_sites:
            for s in required_sites:
                if s not in applicable_sites:
                    self.log_debug(fl_ctx, f"Job {job.job_id} can't be scheduled: required site {s} is not connected.")
                    return False, None

        if job.min_sites and len(applicable_sites) < job.min_sites:
            self.log_debug(
                fl_ctx,
                f"Job {job.job_id} can't be scheduled: connected sites ({len(applicable_sites)}) "
                f"are less than min_sites ({job.min_sites}).",
            )
            return False, None

        # we are assuming server resource is sufficient
        resource_reqs = {}
        for site_name in applicable_sites:
            if site_name in job.resource_spec:
                resource_reqs[site_name] = job.resource_spec[site_name]
            else:
                resource_reqs[site_name] = {}

        resource_check_results = self._check_client_resources(resource_reqs=resource_reqs, fl_ctx=fl_ctx)

        if not resource_check_results:
            self.log_debug(fl_ctx, f"Job {job.job_id} can't be scheduled: resource check results is None or empty.")
            return False, None

        required_sites_not_enough_resource = list(required_sites)
        num_sites_ok = 0
        sites_dispatch_info = {}
        for site_name, check_result in resource_check_results.items():
            if check_result[0]:
                sites_dispatch_info[site_name] = DispatchInfo(
                    app_name=sites_to_app[site_name],
                    resource_requirements=resource_reqs[site_name],
                    token=check_result[1],
                )
                num_sites_ok += 1
                if site_name in required_sites:
                    required_sites_not_enough_resource.remove(site_name)

        if num_sites_ok < job.min_sites:
            self.log_debug(fl_ctx, f"Job {job.job_id} can't be scheduled: not enough sites have enough resources.")
            return self._cancel_resources(
                resource_reqs=job.resource_spec, resource_check_results=resource_check_results, fl_ctx=fl_ctx
            )

        if required_sites_not_enough_resource:
            self.log_debug(
                fl_ctx,
                f"Job {job.job_id} can't be scheduled: required sites: {required_sites_not_enough_resource}"
                f" don't have enough resources.",
            )
            return self._cancel_resources(
                resource_reqs=job.resource_spec, resource_check_results=resource_check_results, fl_ctx=fl_ctx
            )

        # add server dispatch info
        sites_dispatch_info[SERVER_SITE_NAME] = DispatchInfo(
            app_name=sites_to_app[SERVER_SITE_NAME], resource_requirements={}, token=None
        )

        return True, sites_dispatch_info

    def _exceed_max_jobs(self, fl_ctx: FLContext) -> bool:
        exceed_limit = False
        with self.lock:
            if len(self.scheduled_jobs) >= self.max_jobs:
                self.log_debug(
                    fl_ctx,
                    f"Skipping schedule job because scheduled_jobs ({len(self.scheduled_jobs)}) "
                    f"is greater than max_jobs ({self.max_jobs})",
                )
                exceed_limit = True
        return exceed_limit

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.JOB_STARTED:
            with self.lock:
                job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
                if job_id not in self.scheduled_jobs:
                    self.scheduled_jobs.append(job_id)
        elif event_type == EventType.JOB_COMPLETED or event_type == EventType.JOB_ABORTED:
            with self.lock:
                job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
                if job_id in self.scheduled_jobs:
                    self.scheduled_jobs.remove(job_id)

    def schedule_job(
        self, job_candidates: List[Job], fl_ctx: FLContext
    ) -> (Optional[Job], Optional[Dict[str, DispatchInfo]]):
        self.log_debug(fl_ctx, f"Current scheduled_jobs is {self.scheduled_jobs}")
        if self._exceed_max_jobs(fl_ctx=fl_ctx):
            return None, None

        # sort by submitted time
        job_candidates.sort(key=lambda j: j.meta.get(JobMetaKey.SUBMIT_TIME, 0.0))

        for job in job_candidates:
            ok, sites_dispatch_info = self._try_job(job, fl_ctx)
            self.log_debug(fl_ctx, f"Try to schedule job {job.job_id}, get result: {ok}, {sites_dispatch_info}.")
            if ok:
                return job, sites_dispatch_info
        self.log_debug(fl_ctx, "No job is scheduled.")
        return None, None
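
As a quick orientation to the listing above: schedule_job is called with the queued job candidates, sorts them by submission time, and returns the first job whose deploy map, min_sites/required_sites constraints, and client resource checks all pass, together with a {site_name: DispatchInfo} map; otherwise it returns (None, None). The sketch below only shows direct instantiation, assuming nvflare is installed; in a real deployment the scheduler is registered as a server component and driven by the engine, which supplies the Job candidates and FLContext and fires the JOB_STARTED / JOB_COMPLETED / JOB_ABORTED events that keep scheduled_jobs in sync.

from nvflare.app_common.job_schedulers.job_scheduler import DefaultJobScheduler

# Raise the concurrency cap from the default of 1 so that up to 4 jobs
# may hold "scheduled" status at the same time.
scheduler = DefaultJobScheduler(max_jobs=4)

# In a running FL server the engine, not user code, would then invoke:
#   job, dispatch_infos = scheduler.schedule_job(job_candidates, fl_ctx)
# where job_candidates is the list of submitted Job objects and fl_ctx is an
# FLContext whose engine implements ServerEngineSpec. A (None, None) result
# means nothing could be scheduled in this round.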