Source code for nvflare.private.fed.server.training_cmds

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import time
from typing import List

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import AdminCommandNames, WorkspaceConstants
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.private.defs import ClientStatusKey, RequestHeader, TrainingTopic
from nvflare.private.fed.server.admin import new_message
from nvflare.private.fed.server.server_engine_internal_spec import ServerEngineInternalSpec
from nvflare.security.security import Action, FLAuthzContext

from .app_authz import AppAuthzService
from .cmd_utils import CommandUtil
from .server_engine import ServerEngine


[docs]class TrainingCommandModule(CommandModule, CommandUtil): APP_STAGING_PATH = "app_staging_path" def __init__(self): """A class for training commands.""" super().__init__() self.logger = logging.getLogger(self.__class__.__name__)
[docs] def get_spec(self): return CommandModuleSpec( name="training", cmd_specs=[ CommandSpec( name=AdminCommandNames.DELETE_WORKSPACE, description="delete the workspace of a job", usage="delete_workspace job_id", handler_func=self.delete_job_id, authz_func=self.authorize_set_job_id, visible=False, confirm="auth", ), CommandSpec( name=AdminCommandNames.DEPLOY_APP, description="deploy FL app to client/server", usage="deploy_app job_id app server|client <client-name>|all", handler_func=self.deploy_app, authz_func=self.authorize_deploy_app, visible=False, ), CommandSpec( name=AdminCommandNames.START_APP, description="start the FL app", usage="start_app job_id server|client|all", handler_func=self.start_app, authz_func=self.authorize_train, visible=True, ), CommandSpec( name=AdminCommandNames.CHECK_STATUS, description="check status of the FL server/client", usage="check_status server|client", handler_func=self.check_status, authz_func=self.authorize_view, visible=True, ), CommandSpec( name=AdminCommandNames.ABORT, description="abort the FL app", usage="abort job_id server|client|all", handler_func=self.abort_app, authz_func=self.authorize_train, visible=False, ), CommandSpec( name=AdminCommandNames.ABORT_TASK, description="abort the client current task execution", usage="abort_task job_id <client-name>", handler_func=self.abort_task, authz_func=self.authorize_abort_client, visible=True, ), CommandSpec( name=AdminCommandNames.REMOVE_CLIENT, description="remove a FL client", usage="remove_client <client-name>", handler_func=self.remove_client, authz_func=self.authorize_remove_client, visible=True, confirm="auth", ), CommandSpec( name=AdminCommandNames.SHUTDOWN, description="shutdown the FL server/client", usage="shutdown server|client|all", handler_func=self.shutdown, authz_func=self.authorize_operate, visible=True, confirm="auth", ), CommandSpec( name=AdminCommandNames.RESTART, description="restart the FL server/client", usage="restart server|client|all", handler_func=self.restart, authz_func=self.authorize_operate, visible=True, confirm="auth", ), CommandSpec( name=AdminCommandNames.SET_TIMEOUT, description="set the admin commands timeout", usage="set_timeout seconds ", handler_func=self.set_timeout, authz_func=self.authorize_set_timeout, visible=True, ), ], )
[docs] def authorize_set_job_id(self, conn: Connection, args: List[str]): if len(args) < 2: conn.append_error("syntax error: missing job id") return False, None return True, FLAuthzContext.new_authz_context(site_names=[self.SITE_SERVER], actions=[Action.TRAIN])
def _set_job_id_clients(self, conn: Connection, job_id) -> bool: engine = conn.app_ctx clients = engine.get_clients() if clients: valid_tokens = [] for c in clients: valid_tokens.append(c.token) conn.set_prop(self.TARGET_CLIENT_TOKENS, valid_tokens) message = new_message(conn, topic=TrainingTopic.SET_JOB_ID, body="") message.set_header(RequestHeader.JOB_ID, str(job_id)) replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) return True
[docs] def delete_job_id(self, conn: Connection, args: List[str]): job_id = args[1] engine = conn.app_ctx if not isinstance(engine, ServerEngine): raise TypeError("engine must be ServerEngine but got {}".format(type(engine))) if job_id in engine.run_processes.keys(): conn.append_error(f"Current running run_{job_id} can not be deleted.") return err = engine.delete_job_id(job_id) if err: conn.append_error(err) return # ask clients to delete this RUN message = new_message(conn, topic=TrainingTopic.DELETE_RUN, body="") message.set_header(RequestHeader.JOB_ID, str(job_id)) clients = engine.get_clients() if clients: conn.set_prop(self.TARGET_CLIENT_TOKENS, [x.token for x in clients]) replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) conn.append_success("")
# Deploy
[docs] def authorize_deploy_app(self, conn: Connection, args: List[str]): if len(args) < 4: conn.append_error("syntax error: missing job_id and target") return False, None engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) err = self.validate_command_targets(conn, args[3:]) if err: conn.append_error(err) return False, None run_destination = args[1].lower() if not run_destination.startswith(WorkspaceConstants.WORKSPACE_PREFIX): conn.append_error("syntax error: run_destination must be run_XXX") return False, None destination = run_destination[len(WorkspaceConstants.WORKSPACE_PREFIX) :] conn.set_prop(self.JOB_ID, destination) app_name = args[2] app_staging_path = engine.get_staging_path_of_app(app_name) if not app_staging_path: conn.append_error("App {} does not exist. Please upload it first".format(app_name)) return False, None conn.set_prop(self.APP_STAGING_PATH, app_staging_path) target_type = args[3] if target_type == self.TARGET_TYPE_SERVER: sites = [self.SITE_SERVER] else: sites = [] client_names = conn.get_prop(self.TARGET_CLIENT_NAMES) if client_names: sites.extend(client_names) if target_type == self.TARGET_TYPE_ALL: sites.append(self.SITE_SERVER) err, authz_ctx = AppAuthzService.authorize_deploy(app_staging_path, sites) if err: conn.append_error(err) return False, None else: return True, authz_ctx
def _deploy_to_clients(self, conn: Connection, app_name, job_id) -> bool: # return True if successful engine = conn.app_ctx err, app_data = engine.get_app_data(app_name) if err: conn.append_error(err) return False message = new_message(conn, topic=TrainingTopic.DEPLOY, body=app_data) message.set_header(RequestHeader.JOB_ID, str(job_id)) message.set_header(RequestHeader.APP_NAME, app_name) replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) return True def _deploy_to_server(self, conn, job_id, app_name, app_staging_path) -> bool: # return True if successful engine = conn.app_ctx err = engine.deploy_app_to_server(job_id, app_name, app_staging_path) if not err: conn.append_string('deployed app "{}" to Server'.format(app_name)) return True else: conn.append_error(err) return False
[docs] def deploy_app(self, conn: Connection, args: List[str]): app_name = args[2] job_id = conn.get_prop(self.JOB_ID) target_type = conn.get_prop(self.TARGET_TYPE) app_staging_path = conn.get_prop(self.APP_STAGING_PATH) if target_type == self.TARGET_TYPE_SERVER: if not self._deploy_to_server(conn, job_id, app_name, app_staging_path): return elif target_type == self.TARGET_TYPE_CLIENT: if not self._deploy_to_clients(conn, app_name, job_id): return else: # all success = self._deploy_to_server(conn, job_id, app_name, app_staging_path) if success: client_names = conn.get_prop(self.TARGET_CLIENT_NAMES, None) if client_names: if not self._deploy_to_clients(conn, app_name, job_id): return else: return conn.append_success("")
# Start App def _start_app_on_server(self, conn: Connection, job_id: str) -> bool: engine = conn.app_ctx err = engine.start_app_on_server(job_id) if err: conn.append_error(err) return False else: conn.append_string("Server app is starting....") return True def _start_app_on_clients(self, conn: Connection, job_id: str) -> bool: engine = conn.app_ctx err = engine.check_app_start_readiness(job_id) if err: conn.append_error(err) return False # run_info = engine.get_run_info() message = new_message(conn, topic=TrainingTopic.START, body="") # message.set_header(RequestHeader.JOB_ID, str(run_info.job_id)) message.set_header(RequestHeader.JOB_ID, job_id) replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) return True
[docs] def start_app(self, conn: Connection, args: List[str]): engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) job_id = conn.get_prop(self.JOB_ID) target_type = args[2] if target_type == self.TARGET_TYPE_SERVER: if not self._start_app_on_server(conn, job_id): return elif target_type == self.TARGET_TYPE_CLIENT: if not self._start_app_on_clients(conn, job_id): return else: # all success = self._start_app_on_server(conn, job_id) if success: client_names = conn.get_prop(self.TARGET_CLIENT_NAMES, None) if client_names: if not self._start_app_on_clients(conn, job_id): return conn.append_success("")
# Abort App def _abort_clients(self, conn, clients: List[str], job_id) -> bool: engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) err = engine.abort_app_on_clients(clients) if err: conn.append_error(err) return False # run_info = engine.get_app_run_info(job_id) message = new_message(conn, topic=TrainingTopic.ABORT, body="") # if run_info: message.set_header(RequestHeader.JOB_ID, str(job_id)) # conn.set_prop(self.TARGET_CLIENT_NAMES, client_names) replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) return True
[docs] def abort_app(self, conn: Connection, args: List[str]): engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) job_id = conn.get_prop(self.JOB_ID) target_type = args[2] if target_type == self.TARGET_TYPE_SERVER or target_type == self.TARGET_TYPE_ALL: conn.append_string("Trying to abort all clients before abort server ...") clients = engine.get_clients() if clients: tokens = [c.token for c in clients] conn.set_prop( self.TARGET_CLIENT_TOKENS, tokens ) # need this because not set in validate_command_targets when target_type == self.TARGET_TYPE_SERVER if not self._abort_clients(conn, clients=[c.token for c in clients], job_id=job_id): return err = engine.abort_app_on_server(job_id) if err: conn.append_error(err) return conn.append_string("Abort signal has been sent to the server app.") elif target_type == self.TARGET_TYPE_CLIENT: clients = conn.get_prop(self.TARGET_CLIENT_TOKENS) if not clients: conn.append_string("No clients to abort") return if not self._abort_clients(conn, clients, job_id): return conn.append_success("")
[docs] def abort_task(self, conn, clients: List[str]) -> str: engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) job_id = conn.get_prop(self.JOB_ID) # run_info = engine.get_app_run_info() message = new_message(conn, topic=TrainingTopic.ABORT_TASK, body="") # if run_info: message.set_header(RequestHeader.JOB_ID, str(job_id)) # conn.set_prop(self.TARGET_CLIENT_NAMES, client_names) replies = self.send_request_to_clients(conn, message) return self.process_replies_to_table(conn, replies)
# Shutdown def _shutdown_app_on_server(self, conn: Connection) -> bool: engine = conn.app_ctx err = engine.shutdown_server() if err: conn.append_error(err) return False else: conn.append_string("FL app has been shutdown.") conn.append_shutdown("Bye bye") return True def _shutdown_app_on_clients(self, conn: Connection) -> bool: engine = conn.app_ctx message = new_message(conn, topic=TrainingTopic.SHUTDOWN, body="") clients = conn.get_prop(self.TARGET_CLIENT_TOKENS, None) if not clients: conn.append_error("no clients to shutdown") return False replies = self.send_request_to_clients(conn, message) self.process_replies_to_table(conn, replies) err = engine.remove_clients(clients) if err: conn.append_error(err) return False return True
[docs] def shutdown(self, conn: Connection, args: List[str]): target_type = args[1] engine = conn.app_ctx if not isinstance(engine, ServerEngine): raise TypeError("engine must be ServerEngine but got {}".format(type(engine))) if engine.job_runner.running_jobs: conn.append_error("There are still jobs running. Please let them finish or abort_job before shutdown.") return if target_type == self.TARGET_TYPE_SERVER: if engine.get_clients(): conn.append_error("There are still active clients. Shutdown all clients first.") return if not self._shutdown_app_on_server(conn): return elif target_type == self.TARGET_TYPE_CLIENT: if not self._shutdown_app_on_clients(conn): return else: # all if engine.get_clients(): conn.append_string("Trying to shutdown clients before server...") success = self._shutdown_app_on_clients(conn) if success: if not self._shutdown_app_on_server(conn): return else: if not self._shutdown_app_on_server(conn): return conn.append_success("")
# Remove Clients
[docs] def authorize_remove_client(self, conn: Connection, args: List[str]): if len(args) < 2: conn.append_error("syntax error: missing site names") return False, None auth_args = [args[0], self.TARGET_TYPE_CLIENT] auth_args.extend(args[1:]) return self.authorize_operate(conn, auth_args)
[docs] def authorize_abort_client(self, conn: Connection, args: List[str]): if len(args) < 3: conn.append_error("syntax error: missing job_id and target") return False, None run_destination = args[1].lower() if not run_destination.startswith(WorkspaceConstants.WORKSPACE_PREFIX): conn.append_error("syntax error: run_destination must be run_XXX") return False, None job_id = run_destination[len(WorkspaceConstants.WORKSPACE_PREFIX) :] conn.set_prop(self.JOB_ID, job_id) auth_args = [args[0], self.TARGET_TYPE_CLIENT] auth_args.extend(args[2:]) return self.authorize_operate(conn, auth_args)
[docs] def remove_client(self, conn: Connection, args: List[str]): engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) clients = conn.get_prop(self.TARGET_CLIENT_TOKENS) err = engine.remove_clients(clients) if err: conn.append_error(err) return conn.append_success("")
# Restart def _restart_clients(self, conn, clients) -> str: engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) engine.remove_clients(clients) message = new_message(conn, topic=TrainingTopic.RESTART, body="") replies = self.send_request_to_clients(conn, message) return self._process_replies_to_string(conn, replies)
[docs] def restart(self, conn: Connection, args: List[str]): engine = conn.app_ctx if not isinstance(engine, ServerEngine): raise TypeError("engine must be ServerEngine but got {}".format(type(engine))) if engine.job_runner.running_jobs: conn.append_error("There are still jobs running. Please let them finish or abort_job before restart.") return target_type = args[1] if target_type == self.TARGET_TYPE_SERVER or target_type == self.TARGET_TYPE_ALL: clients = engine.get_clients() if clients: conn.append_string("Trying to restart all clients before restarting server...") tokens = [c.token for c in clients] conn.set_prop( self.TARGET_CLIENT_TOKENS, tokens ) # need this because not set in validate_command_targets when target_type == self.TARGET_TYPE_SERVER response = self._restart_clients(conn, tokens) conn.append_string(response) # check with Isaac - no need to wait! # time.sleep(5) err = engine.restart_server() if err: conn.append_error(err) else: conn.append_string("Server scheduled for restart") elif target_type == self.TARGET_TYPE_CLIENT: clients = conn.get_prop(self.TARGET_CLIENT_TOKENS) if not clients: conn.append_error("no clients available") return else: response = self._restart_clients(conn, clients) conn.append_string(response) conn.append_success("")
# Set Timeout
[docs] def authorize_set_timeout(self, conn: Connection, args: List[str]): if len(args) != 2: conn.append_error("syntax error: missing timeout") return False, None try: num = float(args[1]) except ValueError: conn.append_error("must provide the timeout value in seconds") return False, None if num <= 0: conn.append_error("timeout must be > 0") return False, None return True, FLAuthzContext.new_authz_context(site_names=[self.SITE_SERVER], actions=[Action.TRAIN])
[docs] def set_timeout(self, conn: Connection, args: List[str]): timeout = float(args[1]) server = conn.server server.timeout = timeout conn.append_string("admin command timeout has been set to: {}".format(timeout)) conn.append_success("")
# Check status
[docs] def check_status(self, conn: Connection, args: List[str]): # TODO:: Need more discussion on what status to be shown engine = conn.app_ctx if not isinstance(engine, ServerEngineInternalSpec): raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine))) dst = args[1] if dst == self.TARGET_TYPE_SERVER: engine_info = engine.get_engine_info() conn.append_string(f"Engine status: {engine_info.status.value}") table = conn.append_table(["Job_id", "App Name"]) for job_id, app_name in engine_info.app_names.items(): table.add_row([job_id, app_name]) clients = engine.get_clients() conn.append_string("Registered clients: {} ".format(len(clients))) if clients: table = conn.append_table(["Client", "Token", "Last Connect Time"]) for c in clients: if not isinstance(c, Client): raise TypeError("c must be Client but got {}".format(type(c))) table.add_row([c.name, str(c.token), time.asctime(time.localtime(c.last_connect_time))]) elif dst == self.TARGET_TYPE_CLIENT: message = new_message(conn, topic=TrainingTopic.CHECK_STATUS, body="") replies = self.send_request_to_clients(conn, message) self._process_status_replies(conn, replies) else: conn.append_error("invalid target type {}. Usage: check_status server|client ...".format(dst))
def _process_status_replies(self, conn, replies): if not replies: conn.append_error("no responses from clients") return engine = conn.app_ctx table = conn.append_table(["client", "app_name", "job_id", "status"]) for r in replies: job_id = "?" app_name = "?" client_name = engine.get_client_name_from_token(r.client_token) if r.reply: try: body = json.loads(r.reply.body) if r.reply and isinstance(body, dict): running_jobs = body.get(ClientStatusKey.RUNNING_JOBS) if running_jobs: for job in running_jobs: app_name = job.get(ClientStatusKey.APP_NAME, "?") job_id = job.get(ClientStatusKey.JOB_ID, "?") status = job.get(ClientStatusKey.STATUS, "?") table.add_row([client_name, app_name, job_id, status]) else: table.add_row([client_name, app_name, job_id, "No Jobs"]) except BaseException as ex: self.logger.error(f"Bad reply from client: {ex}") else: table.add_row([client_name, app_name, job_id, "No Reply"])