Source code for nvflare.private.fed.server.job_cmds

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import io
import json
import logging
from typing import Dict, List

from nvflare.apis.job_def import Job, JobMetaKey
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec, RunStatus
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModuleSpec, CommandSpec
from nvflare.fuel.hci.server.authz import AuthorizationService
from nvflare.fuel.hci.server.constants import ConnProps
from nvflare.fuel.hci.table import Table
from nvflare.fuel.utils.argument_utils import SafeArgumentParser
from nvflare.private.fed.server.server_engine import ServerEngine
from nvflare.security.security import Action

from .cmd_utils import CommandUtil
from .training_cmds import TrainingCommandModule

class JobCommandModule(TrainingCommandModule, CommandUtil):
    """Command module with commands for job management.

    Registers the ``list_jobs``, ``delete_job``, ``abort_job`` and ``clone_job``
    admin commands. Commands that target a single job are gated by
    :meth:`authorize_job`, which stores the (lower-cased) job id on the
    connection under the ``self.JOB_ID`` prop so the handler can read it back.
    """

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def get_spec(self):
        """Return the ``CommandModuleSpec`` describing the job management commands."""
        return CommandModuleSpec(
            name="job_mgmt",
            cmd_specs=[
                CommandSpec(
                    name="list_jobs",
                    description="list submitted jobs",
                    usage="list_jobs [-n name_prefix] [-d] [job_id_prefix]",
                    handler_func=self.list_jobs,
                ),
                CommandSpec(
                    name="delete_job",
                    description="delete a job and persisted workspace",
                    usage="delete_job job_id",
                    handler_func=self.delete_job,
                    authz_func=self.authorize_job,
                ),
                CommandSpec(
                    name="abort_job",
                    description="abort a job if it is running or dispatched",
                    usage="abort_job job_id",
                    handler_func=self.abort_job,  # see if running, if running, send abort command
                    authz_func=self.authorize_job,
                ),
                CommandSpec(
                    name="clone_job",
                    description="clone a job with a new job_id",
                    usage="clone_job job_id",
                    handler_func=self.clone_job,
                    authz_func=self.authorize_job,
                ),
            ],
        )

    def authorize_job(self, conn: Connection, args: List[str]):
        """Authorization hook for commands that operate on a single job.

        Args:
            conn: admin connection. The lower-cased job id is stored on it under
                the ``self.JOB_ID`` prop before the job is looked up.
            args: command tokens, expected to be ``[command, job_id]``.

        Returns:
            ``(False, None)`` on a syntax error or an unknown job id; otherwise
            whatever ``authorize_job_meta`` returns for the TRAIN action.
        """
        if len(args) != 2:
            conn.append_error("syntax error: missing job_id")
            return False, None

        job_id = args[1].lower()
        conn.set_prop(self.JOB_ID, job_id)
        engine = conn.app_ctx
        job_def_manager = engine.job_def_manager

        with engine.new_context() as fl_ctx:
            job = job_def_manager.get_job(job_id, fl_ctx)

        if not job:
            conn.append_error(f"Job with ID {job_id} doesn't exist")
            return False, None

        return self.authorize_job_meta(conn, job.meta, [Action.TRAIN])

    def list_jobs(self, conn: Connection, args: List[str]):
        """Handle ``list_jobs``: list the jobs the caller is allowed to view.

        Supports an optional job-id prefix positional argument, ``-n`` for a
        job-name prefix filter and ``-d`` for a detailed (JSON) listing. Jobs are
        sorted by submit time. Jobs hidden by authorization are reported
        collectively, without their ids.

        Args:
            conn: admin connection that receives the output.
            args: command tokens; ``args[0]`` is the command name.
        """
        try:
            parser = SafeArgumentParser(prog="list_jobs")
            parser.add_argument("job_id", nargs="?", help="Job ID prefix")
            parser.add_argument("-d", action="store_true", help="Show detailed list")
            parser.add_argument("-n", help="Filter by job name prefix")
            parsed_args = parser.parse_args(args[1:])

            engine = conn.app_ctx
            job_def_manager = engine.job_def_manager
            if not isinstance(job_def_manager, JobDefManagerSpec):
                raise TypeError(
                    f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}"
                )

            with engine.new_context() as fl_ctx:
                jobs = job_def_manager.get_all_jobs(fl_ctx)

            if jobs:
                id_prefix = parsed_args.job_id
                name_prefix = parsed_args.n
                filtered_jobs = [job for job in jobs if self._job_match(job.meta, id_prefix, name_prefix)]
                if not filtered_jobs:
                    conn.append_error("No jobs matching the searching criteria")
                    return

                # authz_func is per-command, but listing needs a per-job check.
                authorized_jobs = [job for job in filtered_jobs if self._job_authorized(conn, job)]
                authorized_jobs.sort(key=lambda job: job.meta.get(JobMetaKey.SUBMIT_TIME, 0.0))

                if parsed_args.d:
                    self._send_detail_list(conn, authorized_jobs)
                else:
                    self._send_summary_list(conn, authorized_jobs)

                hidden_ids = {job.job_id for job in filtered_jobs} - {job.job_id for job in authorized_jobs}
                if hidden_ids:
                    self.logger.debug("Following jobs are not authorized for listing: %s", hidden_ids)
                    conn.append_string("Some jobs are not listed due to permission restrictions")
            else:
                conn.append_string("No jobs.")
        except Exception as e:
            conn.append_error(str(e))
            return

        conn.append_success("")

    def delete_job(self, conn: Connection, args: List[str]):
        """Handle ``delete_job``: delete the job authorized by :meth:`authorize_job`.

        Refuses to delete a job whose status is DISPATCHED or RUNNING.

        Args:
            conn: admin connection; the job id is read from its ``self.JOB_ID`` prop.
            args: command tokens (unused here; parsed by :meth:`authorize_job`).
        """
        job_id = conn.get_prop(self.JOB_ID)
        engine = conn.app_ctx
        try:
            if not isinstance(engine, ServerEngine):
                raise TypeError(f"engine is not of type ServerEngine, but got {type(engine)}")
            job_def_manager = engine.job_def_manager
            if not isinstance(job_def_manager, JobDefManagerSpec):
                raise TypeError(
                    f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}"
                )
            with engine.new_context() as fl_ctx:
                job = job_def_manager.get_job(job_id, fl_ctx)
                if not job:
                    conn.append_error(f"job: {job_id} does not exist")
                    return
                if job.meta.get(JobMetaKey.STATUS, "") in [RunStatus.DISPATCHED.value, RunStatus.RUNNING.value]:
                    conn.append_error(f"job: {job_id} is running, could not be deleted at this time.")
                    return
                job_def_manager.delete(job_id, fl_ctx)
                conn.append_string("Job {} deleted.".format(job_id))
        except Exception as e:
            conn.append_error("exception occurred: " + str(e))
            return
        conn.append_success("")

    def abort_job(self, conn: Connection, args: List[str]):
        """Handle ``abort_job``: ask the job runner to stop the authorized job.

        Args:
            conn: admin connection; the job id is read from its ``self.JOB_ID`` prop.
            args: command tokens (unused here; parsed by :meth:`authorize_job`).
        """
        engine = conn.app_ctx
        try:
            job_runner = engine.job_runner
            job_id = conn.get_prop(self.JOB_ID)
            # Use the context as a context manager, like every other handler here,
            # so it is properly exited once stop_run returns.
            with engine.new_context() as fl_ctx:
                job_runner.stop_run(job_id, fl_ctx)
            conn.append_string("Abort signal has been sent to the server app.")
            conn.append_success("")
        except Exception as e:
            conn.append_error("Exception occurred trying to abort job: " + str(e))
            return

    def clone_job(self, conn: Connection, args: List[str]):
        """Handle ``clone_job``: create a new job from an existing job's meta and content.

        Args:
            conn: admin connection; the source job id is read from its
                ``self.JOB_ID`` prop.
            args: command tokens (unused here; parsed by :meth:`authorize_job`).
        """
        job_id = conn.get_prop(self.JOB_ID)
        engine = conn.app_ctx
        try:
            if not isinstance(engine, ServerEngine):
                raise TypeError(f"engine is not of type ServerEngine, but got {type(engine)}")
            job_def_manager = engine.job_def_manager
            if not isinstance(job_def_manager, JobDefManagerSpec):
                raise TypeError(
                    f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}"
                )
            with engine.new_context() as fl_ctx:
                job = job_def_manager.get_job(job_id, fl_ctx)
                # The job may have been deleted after authorize_job ran; report it
                # clearly instead of failing on ``None.meta``.
                if not job:
                    conn.append_error(f"job: {job_id} does not exist")
                    return
                data_bytes = job_def_manager.get_content(job_id, fl_ctx)
                meta = job_def_manager.create(job.meta, data_bytes, fl_ctx)
                conn.append_string("Cloned job {} as: {}".format(job_id, meta.get(JobMetaKey.JOB_ID)))
        except Exception as e:
            conn.append_error("Exception occurred trying to clone job: " + str(e))
            return
        conn.append_success("")

    @staticmethod
    def _job_match(job_meta: Dict, id_prefix: str, name_prefix: str) -> bool:
        """Return True if the job matches both (optional) case-insensitive prefixes.

        An empty/None prefix matches everything. A job whose ``job_id`` / ``name``
        is missing is treated as having an empty value, so it only matches when
        that prefix is not given.
        """
        if id_prefix:
            job_id = str(job_meta.get("job_id") or "")
            if not job_id.lower().startswith(id_prefix.lower()):
                return False
        if name_prefix:
            name = str(job_meta.get("name") or "")
            if not name.lower().startswith(name_prefix.lower()):
                return False
        return True

    @staticmethod
    def _send_detail_list(conn: Connection, jobs: List[Job]):
        """Append each job's meta to ``conn`` as indented JSON, after refreshing its duration."""
        for job in jobs:
            JobCommandModule._set_duration(job)
            conn.append_string(json.dumps(job.meta, indent=4))

    @staticmethod
    def _send_summary_list(conn: Connection, jobs: List[Job]):
        """Append a one-row-per-job summary table to ``conn``."""
        table = Table(["Job ID", "Name", "Status", "Submit Time", "Run Duration"])
        for job in jobs:
            JobCommandModule._set_duration(job)
            table.add_row(
                [
                    job.meta.get(JobMetaKey.JOB_ID, ""),
                    CommandUtil.get_job_name(job.meta),
                    job.meta.get(JobMetaKey.STATUS, ""),
                    job.meta.get(JobMetaKey.SUBMIT_TIME_ISO, ""),
                    str(job.meta.get(JobMetaKey.DURATION, "N/A")),
                ]
            )
        writer = io.StringIO()
        table.write(writer)
        conn.append_string(writer.getvalue())

    @staticmethod
    def _set_duration(job):
        """Store the elapsed run time of a RUNNING job in ``job.meta[JobMetaKey.DURATION]``.

        START_TIME is parsed with ``"%Y-%m-%d %H:%M:%S.%f"`` and subtracted from
        the naive ``datetime.datetime.now()``. NOTE(review): this presumes START_TIME
        was written from the same naive local clock; confirm. Jobs that are not
        running, or have no START_TIME, are left unchanged.
        """
        if job.meta.get(JobMetaKey.STATUS) != RunStatus.RUNNING.value:
            return
        start_time_str = job.meta.get(JobMetaKey.START_TIME)
        if not start_time_str:
            # Without a start time no duration can be computed; don't fail the listing.
            return
        start_time = datetime.datetime.strptime(start_time_str, "%Y-%m-%d %H:%M:%S.%f")
        duration = datetime.datetime.now() - start_time
        job.meta[JobMetaKey.DURATION] = str(duration)

    def _job_authorized(self, conn: Connection, job: Job) -> bool:
        """Return True if the connection's user may VIEW ``job``.

        Side effect: sets ``ConnProps.AUTHZ_CTX`` on ``conn`` to the context
        used for this job's check.
        """
        valid, authz_ctx = self.authorize_job_meta(conn, job.meta, [Action.VIEW])
        if not valid:
            return False

        authz_ctx.user_name = conn.get_prop(ConnProps.USER_NAME, "")
        conn.set_prop(ConnProps.AUTHZ_CTX, authz_ctx)
        authorizer = AuthorizationService.get_authorizer()
        authorized, err = authorizer.authorize(ctx=authz_ctx)
        if err:
            self.logger.debug("Authorization Error to view job {}: {}".format(job.job_id, err))
            return False
        if not authorized:
            self.logger.debug(f"View action for job {job.job_id} is not authorized")
            return False
        return True