# Source code for nvflare.private.fed.server.job_runner

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import threading
import time
from typing import Dict, List, Tuple

from nvflare.apis.client import Client
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemComponents
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import ALL_SITES, Job, JobMetaKey, RunStatus
from nvflare.apis.job_scheduler_spec import DispatchInfo
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.zip_utils import zip_directory_to_file
from nvflare.lighter.utils import verify_folder_signature
from nvflare.private.admin_defs import Message, MsgHeader, ReturnCode
from nvflare.private.defs import RequestHeader, TrainingTopic
from nvflare.private.fed.server.admin import check_client_replies
from nvflare.private.fed.server.server_state import HotState
from nvflare.private.fed.utils.app_deployer import AppDeployer
from nvflare.private.fed.utils.fed_utils import set_message_security_data
from nvflare.security.logging import secure_format_exception


def _send_to_clients(admin_server, client_sites: List[str], engine, message, timeout=None, optional=False):
    clients, invalid_inputs = engine.validate_targets(client_sites)
    if invalid_inputs:
        raise RuntimeError(f"unknown clients: {invalid_inputs}.")
    requests = {}
    for c in clients:
        requests.update({c.token: message})

    if timeout is None:
        timeout = admin_server.timeout
    with admin_server.sai.new_context() as fl_ctx:
        replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout, optional=optional)
    return replies


def _get_active_job_participants(connected_clients: Dict[str, Client], participants: Dict[str, Client]) -> List[str]:
    """Gets active job participants.

        Some clients might be dropped/dead during job execution.
        No need to abort those clients.

    Args:
        connected_clients: Clients that are currently connected.
        participants: Clients that were participating when the job started.

    Returns:
        A list of active job participants name.
    """
    # keep only participants whose token is still registered as connected
    return [client.name for token, client in participants.items() if token in connected_clients]


class JobRunner(FLComponent):
    """Schedules, deploys, runs, and finalizes FL jobs on the server.

    The runner polls the job manager for approved jobs, asks the scheduler to pick
    one, deploys the job's apps to the server and client sites, starts the run, and
    tracks/finalizes running jobs (including abnormal terminations).
    """

    def __init__(self, workspace_root: str) -> None:
        """Initialize the runner.

        Args:
            workspace_root: root directory of the server workspace.
        """
        super().__init__()
        self.workspace_root = workspace_root
        self.ask_to_stop = False        # set True to make run()/_job_complete_process exit
        self.scheduler = None           # resolved from the engine at SYSTEM_START
        self.running_jobs = {}          # job_id -> Job, guarded by self.lock
        self.lock = threading.Lock()

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """React to system lifecycle events: wire up the scheduler, archive finished
        job workspaces, and stop the runner at system end."""
        if event_type == EventType.SYSTEM_START:
            engine = fl_ctx.get_engine()
            self.scheduler = engine.get_component(SystemComponents.JOB_SCHEDULER)
        elif event_type in [EventType.JOB_COMPLETED, EventType.END_RUN]:
            self._save_workspace(fl_ctx)
        elif event_type == EventType.SYSTEM_END:
            self.stop()

    @staticmethod
    def _make_deploy_message(job: Job, app_data, app_name, fl_ctx):
        """Build the DEPLOY message carrying the app bytes plus job/app headers
        and the job's security data."""
        message = Message(topic=TrainingTopic.DEPLOY, body=app_data)
        message.set_header(RequestHeader.REQUIRE_AUTHZ, "true")
        message.set_header(RequestHeader.ADMIN_COMMAND, AdminCommandNames.SUBMIT_JOB)
        message.set_header(RequestHeader.JOB_ID, job.job_id)
        message.set_header(RequestHeader.APP_NAME, app_name)
        set_message_security_data(message, job, fl_ctx)
        return message

    def _deploy_job(self, job: Job, sites: dict, fl_ctx: FLContext) -> Tuple[str, list]:
        """Deploy the application to the list of participants

        Args:
            job: job to be deployed
            sites: participating sites
            fl_ctx: FLContext

        Returns: job id, failed_clients

        Raises:
            RuntimeError: if server-side deployment/signature verification fails,
                if any target client is unknown, or if too few / required sites
                fail to deploy (abort condition).
        """
        fl_ctx.remove_prop(FLContextKey.JOB_RUN_NUMBER)
        fl_ctx.remove_prop(FLContextKey.JOB_DEPLOY_DETAIL)
        engine = fl_ctx.get_engine()
        run_number = job.job_id
        fl_ctx.set_prop(FLContextKey.JOB_RUN_NUMBER, run_number)
        workspace = Workspace(root_dir=self.workspace_root, site_name="server")

        client_deploy_requests = {}
        client_token_to_name = {}
        client_token_to_reply = {}
        deploy_detail = []
        fl_ctx.set_prop(FLContextKey.JOB_DEPLOY_DETAIL, deploy_detail)

        for app_name, participants in job.get_deployment().items():
            app_data = job.get_application(app_name, fl_ctx)

            # "ALL_SITES" expands to the server plus every currently registered client
            if len(participants) == 1 and participants[0].upper() == ALL_SITES:
                participants = ["server"]
                participants.extend([client.name for client in engine.get_clients()])

            client_sites = []
            for p in participants:
                if p == "server":
                    self.fire_event(EventType.DEPLOY_JOB_TO_SERVER, fl_ctx)
                    app_deployer = AppDeployer()
                    err = app_deployer.deploy(
                        app_name=app_name,
                        workspace=workspace,
                        job_id=job.job_id,
                        job_meta=job.meta,
                        app_data=app_data,
                        fl_ctx=fl_ctx,
                    )
                    if err:
                        deploy_detail.append(f"server: {err}")
                        raise RuntimeError(f"Failed to deploy app '{app_name}': {err}")

                    kv_list = parse_vars(engine.args.set)
                    secure_train = kv_list.get("secure_train", True)
                    from_hub_site = job.meta.get(JobMetaKey.FROM_HUB_SITE.value)
                    # in secure mode, verify the deployed app folder's signature
                    # against the root CA (skipped for jobs coming from a hub site)
                    if secure_train and not from_hub_site:
                        app_path = workspace.get_app_dir(job.job_id)
                        root_ca_path = os.path.join(workspace.get_startup_kit_dir(), "rootCA.pem")
                        if not verify_folder_signature(app_path, root_ca_path):
                            err = "job signature verification failed"
                            deploy_detail.append(f"server: {err}")
                            raise RuntimeError(f"Failed to verify app '{app_name}': {err}")

                    self.log_info(
                        fl_ctx, f"Application {app_name} deployed to the server for job: {run_number}", fire_event=False
                    )
                    deploy_detail.append("server: OK")
                else:
                    if p in sites:
                        client_sites.append(p)

            if client_sites:
                self.fire_event(EventType.DEPLOY_JOB_TO_CLIENT, fl_ctx)
                message = self._make_deploy_message(job, app_data, app_name, fl_ctx)
                clients, invalid_inputs = engine.validate_targets(client_sites)

                if invalid_inputs:
                    deploy_detail.append("invalid_clients: {}".format(",".join(invalid_inputs)))
                    raise RuntimeError(f"unknown clients: {invalid_inputs}.")

                for c in clients:
                    assert isinstance(c, Client)
                    client_token_to_name[c.token] = c.name
                    client_deploy_requests[c.token] = message
                    client_token_to_reply[c.token] = None
                display_sites = ",".join(client_sites)

                self.log_info(
                    fl_ctx,
                    f"App {app_name} to be deployed to the clients: {display_sites} for run: {run_number}",
                    fire_event=False,
                )

        abort_job = False
        failed_clients = []
        if client_deploy_requests:
            engine = fl_ctx.get_engine()
            admin_server = engine.server.admin_server
            client_token_to_reply = admin_server.send_requests_and_get_reply_dict(
                client_deploy_requests, timeout_secs=admin_server.timeout
            )

            # check replies and see whether required clients are okay
            for client_token, reply in client_token_to_reply.items():
                client_name = client_token_to_name[client_token]
                if reply:
                    assert isinstance(reply, Message)
                    rc = reply.get_header(MsgHeader.RETURN_CODE, ReturnCode.OK)
                    if rc != ReturnCode.OK:
                        failed_clients.append(client_name)
                        deploy_detail.append(f"{client_name}: {reply.body}")
                    else:
                        deploy_detail.append(f"{client_name}: OK")
                else:
                    # no reply received within the timeout
                    deploy_detail.append(f"{client_name}: unknown")

            # see whether any of the failed clients are required
            if failed_clients:
                num_ok_sites = len(client_deploy_requests) - len(failed_clients)
                if job.min_sites and num_ok_sites < job.min_sites:
                    abort_job = True
                    deploy_detail.append(f"num_ok_sites {num_ok_sites} < required_min_sites {job.min_sites}")
                elif job.required_sites:
                    for c in failed_clients:
                        if c in job.required_sites:
                            abort_job = True
                            deploy_detail.append(f"failed to deploy to required client {c}")

        if abort_job:
            raise RuntimeError("deploy failure", deploy_detail)

        self.fire_event(EventType.JOB_DEPLOYED, fl_ctx)
        return run_number, failed_clients

    def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo], fl_ctx: FLContext):
        """Start the application

        Args:
            job_id: job_id
            job: the Job to start
            client_sites: participating sites
            fl_ctx: FLContext

        Raises:
            RuntimeError: if the server app could not be started.
        """
        engine = fl_ctx.get_engine()
        job_clients = engine.get_job_clients(client_sites)
        err = engine.start_app_on_server(job_id, job=job, job_clients=job_clients)
        if err:
            raise RuntimeError(f"Could not start the server App for job: {job_id}.")

        replies = engine.start_client_job(job_id, client_sites, fl_ctx)
        client_sites_names = list(client_sites.keys())
        check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})")
        display_sites = ",".join(client_sites_names)

        self.log_info(fl_ctx, f"Started run: {job_id} for clients: {display_sites}")
        self.fire_event(EventType.JOB_STARTED, fl_ctx)

    def _stop_run(self, job_id, fl_ctx: FLContext):
        """Stop the application

        Args:
            job_id: job_id to be stopped
            fl_ctx: FLContext
        """
        engine = fl_ctx.get_engine()
        run_process = engine.run_processes.get(job_id)
        if run_process:
            # only abort clients that are still connected; dropped clients are skipped
            participants: Dict[str, Client] = run_process.get(RunProcessKey.PARTICIPANTS)
            active_client_sites_names = _get_active_job_participants(
                connected_clients=engine.client_manager.clients, participants=participants
            )

            self.abort_client_run(job_id, active_client_sites_names, fl_ctx)

            err = engine.abort_app_on_server(job_id)
            if err:
                self.log_error(fl_ctx, f"Failed to abort the server for run: {job_id}: {err}")

    def abort_client_run(self, job_id, client_sites: List[str], fl_ctx):
        """Send the abort run command to the clients

        Args:
            job_id: job_id
            client_sites: Clients to be aborted
            fl_ctx: FLContext
        """
        engine = fl_ctx.get_engine()
        admin_server = engine.server.admin_server
        message = Message(topic=TrainingTopic.ABORT, body="")
        message.set_header(RequestHeader.JOB_ID, str(job_id))
        self.log_debug(fl_ctx, f"Send abort command to the clients for run: {job_id}")
        try:
            # best-effort send with a short timeout; replies are intentionally ignored
            _ = _send_to_clients(admin_server, client_sites, engine, message, timeout=2.0, optional=True)
            # There isn't much we can do here if a client didn't get the message or send a reply
            # check_client_replies(replies=replies, client_sites=client_sites, command="abort the run")
        except RuntimeError as e:
            self.log_error(fl_ctx, f"Failed to abort run ({job_id}) on the clients: {secure_format_exception(e)}")

    def _delete_run(self, job_id, client_sites: List[str], fl_ctx: FLContext):
        """Deletes the run workspace

        Args:
            job_id: job_id
            client_sites: participating sites
            fl_ctx: FLContext
        """
        engine = fl_ctx.get_engine()
        admin_server = engine.server.admin_server
        message = Message(topic=TrainingTopic.DELETE_RUN, body="")
        message.set_header(RequestHeader.JOB_ID, str(job_id))
        self.log_debug(fl_ctx, f"Send delete_run command to the clients for run: {job_id}")
        try:
            replies = _send_to_clients(admin_server, client_sites, engine, message)
            check_client_replies(replies=replies, client_sites=client_sites, command="send delete_run command")
        except RuntimeError as e:
            self.log_error(
                fl_ctx, f"Failed to execute delete run ({job_id}) on the clients: {secure_format_exception(e)}"
            )

        err = engine.delete_job_id(job_id)
        if err:
            self.log_error(fl_ctx, f"Failed to delete_run the server for run: {job_id}")

    def _job_complete_process(self, fl_ctx: FLContext):
        """Background loop: detect jobs whose run process has ended, update their
        final status (unless the run was explicitly aborted), remove them from the
        running table, and fire JOB_COMPLETED. Polls once per second until stop()."""
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        while not self.ask_to_stop:
            for job_id in list(self.running_jobs.keys()):
                # a job no longer in engine.run_processes has finished (or died)
                if job_id not in engine.run_processes.keys():
                    job = self.running_jobs.get(job_id)
                    if job:
                        if not job.run_aborted:
                            self._update_job_status(engine, job, job_manager, fl_ctx)
                        with self.lock:
                            del self.running_jobs[job_id]
                        fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job.job_id)
                        self.fire_event(EventType.JOB_COMPLETED, fl_ctx)
                        self.log_debug(fl_ctx, f"Finished running job:{job.job_id}")
                    engine.remove_exception_process(job_id)
            time.sleep(1.0)

    def _update_job_status(self, engine, job, job_manager, fl_ctx):
        """Determine and persist the final RunStatus of a finished job, aborting
        still-connected clients first when the job's process ended with an exception."""
        exception_run_processes = engine.exception_run_processes
        if job.job_id in exception_run_processes:
            self.log_info(fl_ctx, f"Try to abort job ({job.job_id}) on clients ...")
            run_process = exception_run_processes[job.job_id]

            # stop client run
            participants: Dict[str, Client] = run_process.get(RunProcessKey.PARTICIPANTS)
            active_client_sites_names = _get_active_job_participants(
                connected_clients=engine.client_manager.clients, participants=participants
            )
            self.abort_client_run(job.job_id, active_client_sites_names, fl_ctx)

            finished = run_process.get(RunProcessKey.PROCESS_FINISHED, False)
            if finished:
                # job status is already reported from the Job cell!
                exe_err = run_process.get(RunProcessKey.PROCESS_EXE_ERROR, False)
                if exe_err:
                    status = RunStatus.FINISHED_EXECUTION_EXCEPTION
                else:
                    status = RunStatus.FINISHED_COMPLETED
            else:
                # never got job status report from job cell
                process_return_code = run_process.get(RunProcessKey.PROCESS_RETURN_CODE)
                # -9 is treated as a killed process (presumably SIGKILL) -> abnormal finish
                if process_return_code == -9:
                    status = RunStatus.FINISHED_ABNORMAL
                else:
                    status = RunStatus.FINISHED_EXECUTION_EXCEPTION
        else:
            status = RunStatus.FINISHED_COMPLETED
        job_manager.set_status(job.job_id, status, fl_ctx)

    def _save_workspace(self, fl_ctx: FLContext):
        """Zip the job's run directory, store it via the job manager, then delete
        the run directory from disk."""
        job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
        workspace = Workspace(root_dir=self.workspace_root)
        run_dir = workspace.get_run_dir(job_id)
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        with tempfile.TemporaryDirectory() as td:
            output_file = os.path.join(td, "workspace")
            zip_directory_to_file(run_dir, "", output_file)
            job_manager.save_workspace(job_id, output_file, fl_ctx)
            self.log_debug(fl_ctx, f"Workspace zipped to {output_file}")
        shutil.rmtree(run_dir)

    def run(self, fl_ctx: FLContext):
        """Starts job runner.

        Main scheduling loop: waits until the server is in HotState and clients are
        registered, then repeatedly asks the scheduler for a ready job, deploys it,
        and starts the run. Any failure rolls the job back to FAILED_TO_RUN.
        """
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        if job_manager:
            # background thread that finalizes jobs whose processes have ended
            thread = threading.Thread(target=self._job_complete_process, args=[fl_ctx])
            thread.start()

            while not self.ask_to_stop:
                time.sleep(1.0)

                if not isinstance(engine.server.server_state, HotState):
                    continue
                if not engine.get_clients():
                    # no clients registered yet - don't try to schedule!
                    continue

                approved_jobs = job_manager.get_jobs_to_schedule(fl_ctx)
                self.log_debug(
                    fl_ctx, f"{fl_ctx.get_identity_name()} Got approved_jobs: {approved_jobs} from the job_manager"
                )

                if self.scheduler:
                    ready_job, sites = self.scheduler.schedule_job(
                        job_manager=job_manager, job_candidates=approved_jobs, fl_ctx=fl_ctx
                    )

                    if ready_job:
                        # re-check persisted status: the job may have been changed
                        # (e.g. aborted) since it was scheduled
                        if self._check_job_status(job_manager, ready_job.job_id, RunStatus.SUBMITTED, fl_ctx):
                            self.log_info(fl_ctx, f"Job: {ready_job.job_id} is not in SUBMITTED. It won't be deployed.")
                            continue
                        client_sites = {k: v for k, v in sites.items() if k != "server"}
                        job_id = None
                        try:
                            self.log_info(fl_ctx, f"Got the job: {ready_job.job_id} from the scheduler to run")
                            fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, ready_job.job_id)
                            job_id, failed_clients = self._deploy_job(ready_job, sites, fl_ctx)
                            job_manager.set_status(ready_job.job_id, RunStatus.DISPATCHED, fl_ctx)

                            # persist deployment details and schedule history in the job meta
                            deploy_detail = fl_ctx.get_prop(FLContextKey.JOB_DEPLOY_DETAIL)
                            if deploy_detail:
                                job_manager.update_meta(
                                    ready_job.job_id,
                                    {
                                        JobMetaKey.JOB_DEPLOY_DETAIL.value: deploy_detail,
                                        JobMetaKey.SCHEDULE_COUNT.value: ready_job.meta[
                                            JobMetaKey.SCHEDULE_COUNT.value
                                        ],
                                        JobMetaKey.LAST_SCHEDULE_TIME.value: ready_job.meta[
                                            JobMetaKey.LAST_SCHEDULE_TIME.value
                                        ],
                                        JobMetaKey.SCHEDULE_HISTORY.value: ready_job.meta[
                                            JobMetaKey.SCHEDULE_HISTORY.value
                                        ],
                                    },
                                    fl_ctx,
                                )
                                self.log_info(fl_ctx, f"Updated the schedule history of Job: {job_id}")

                            # start only on clients that deployed successfully
                            if failed_clients:
                                deployable_clients = {k: v for k, v in client_sites.items() if k not in failed_clients}
                            else:
                                deployable_clients = client_sites

                            if self._check_job_status(job_manager, ready_job.job_id, RunStatus.DISPATCHED, fl_ctx):
                                self.log_info(
                                    fl_ctx, f"Job: {ready_job.job_id} is not in DISPATCHED. It won't be start to run."
                                )
                                continue
                            self._start_run(
                                job_id=job_id,
                                job=ready_job,
                                client_sites=deployable_clients,
                                fl_ctx=fl_ctx,
                            )
                            with self.lock:
                                self.running_jobs[job_id] = ready_job
                            job_manager.set_status(ready_job.job_id, RunStatus.RUNNING, fl_ctx)
                            self.log_info(fl_ctx, f"Job: {job_id} started to run, status changed to RUNNING.")
                        except Exception as e:
                            # roll back: stop anything started, mark job FAILED_TO_RUN
                            if job_id:
                                if job_id in self.running_jobs:
                                    with self.lock:
                                        del self.running_jobs[job_id]
                                self._stop_run(job_id, fl_ctx)
                            job_manager.set_status(ready_job.job_id, RunStatus.FAILED_TO_RUN, fl_ctx)

                            deploy_detail = fl_ctx.get_prop(FLContextKey.JOB_DEPLOY_DETAIL)
                            if deploy_detail:
                                job_manager.update_meta(
                                    ready_job.job_id, {JobMetaKey.JOB_DEPLOY_DETAIL.value: deploy_detail}, fl_ctx
                                )

                            self.fire_event(EventType.JOB_ABORTED, fl_ctx)
                            self.log_error(
                                fl_ctx, f"Failed to run the Job ({ready_job.job_id}): {secure_format_exception(e)}"
                            )

            thread.join()
        else:
            self.log_error(fl_ctx, "There's no Job Manager defined. Won't be able to run the jobs.")

    @staticmethod
    def _check_job_status(job_manager, job_id, job_run_status, fl_ctx: FLContext):
        """Reload the job and return True if its persisted status DIFFERS from
        job_run_status (i.e. the job moved on and should not proceed)."""
        reload_job = job_manager.get_job(job_id, fl_ctx)
        return reload_job.meta.get(JobMetaKey.STATUS) != job_run_status

    def stop(self):
        """Signal the scheduling and completion loops to exit."""
        self.ask_to_stop = True

    def restore_running_job(self, run_number: str, job_id: str, job_clients, snapshot, fl_ctx: FLContext):
        """Restore a previously-running job (e.g. after server restart) from a
        snapshot: restart the server app, re-register the job as running, and
        tell the scheduler it is scheduled again. Failures are logged, not raised."""
        engine = fl_ctx.get_engine()

        try:
            job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
            job = job_manager.get_job(jid=job_id, fl_ctx=fl_ctx)
            err = engine.start_app_on_server(run_number, job=job, job_clients=job_clients, snapshot=snapshot)
            if err:
                raise RuntimeError(f"Could not restore the server App for job: {job_id}.")
            with self.lock:
                self.running_jobs[job_id] = job
            self.scheduler.restore_scheduled_job(job_id)
        except Exception as e:
            self.log_error(
                fl_ctx, f"Failed to restore the job: {job_id} to the running job table: {secure_format_exception(e)}."
            )

    def update_abnormal_finished_jobs(self, running_job_ids, fl_ctx: FLContext):
        """Mark every RUNNING/DISPATCHED job that is NOT in running_job_ids as
        FINISHED_ABNORMAL (used to clean up after a restart)."""
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        all_jobs = self._get_all_running_jobs(job_manager, fl_ctx)
        for job in all_jobs:
            if job.job_id not in running_job_ids:
                try:
                    job_manager.set_status(job.job_id, RunStatus.FINISHED_ABNORMAL, fl_ctx)
                    self.logger.info(f"Update the previous running job: {job.job_id} to {RunStatus.FINISHED_ABNORMAL}.")
                except Exception as e:
                    self.log_error(
                        fl_ctx,
                        f"Failed to update the job: {job.job_id} to {RunStatus.FINISHED_ABNORMAL}: "
                        f"{secure_format_exception(e)}.",
                    )

    def update_unfinished_jobs(self, fl_ctx: FLContext):
        """Mark every RUNNING/DISPATCHED job as ABANDONED."""
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
        all_jobs = self._get_all_running_jobs(job_manager, fl_ctx)
        for job in all_jobs:
            try:
                job_manager.set_status(job.job_id, RunStatus.ABANDONED, fl_ctx)
                self.logger.info(f"Update the previous running job: {job.job_id} to {RunStatus.ABANDONED}.")
            except Exception as e:
                self.log_error(
                    fl_ctx,
                    f"Failed to update the job: {job.job_id} to {RunStatus.ABANDONED}: {secure_format_exception(e)}.",
                )

    @staticmethod
    def _get_all_running_jobs(job_manager, fl_ctx):
        """Return all jobs whose persisted status is RUNNING or DISPATCHED."""
        return job_manager.get_jobs_by_status([RunStatus.RUNNING, RunStatus.DISPATCHED], fl_ctx)

    def stop_run(self, job_id: str, fl_ctx: FLContext):
        """Stop one running job: abort it on server and clients, mark it
        FINISHED_ABORTED, and fire JOB_ABORTED.

        Returns:
            "" on success; an error message string if the job is not running.
        """
        engine = fl_ctx.get_engine()
        job_manager = engine.get_component(SystemComponents.JOB_MANAGER)

        self._stop_run(job_id, fl_ctx)
        job = self.running_jobs.get(job_id)
        if job:
            self.log_info(fl_ctx, f"Stop the job run: {job_id}")
            fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job.job_id)
            # flag so _job_complete_process doesn't overwrite the ABORTED status
            job.run_aborted = True
            job_manager.set_status(job.job_id, RunStatus.FINISHED_ABORTED, fl_ctx)
            self.fire_event(EventType.JOB_ABORTED, fl_ctx)
            return ""
        else:
            self.log_error(fl_ctx, f"Job {job_id} is not running. It can not be stopped.")
            return f"Job {job_id} is not running."

    def stop_all_runs(self, fl_ctx: FLContext):
        """Stop every running job, then stop the runner itself."""
        engine = fl_ctx.get_engine()
        for job_id in engine.run_processes.keys():
            self.stop_run(job_id, fl_ctx)

        self.log_info(fl_ctx, "Stop all the running jobs.")
        # also stop the job runner
        self.ask_to_stop = True

    def remove_running_job(self, job_id: str):
        """Drop a job from the running table and from the scheduler's records."""
        with self.lock:
            if job_id in self.running_jobs:
                del self.running_jobs[job_id]
        self.scheduler.remove_scheduled_job(job_id)