# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import time
from typing import List
from nvflare.apis.client import Client
from nvflare.apis.fl_constant import AdminCommandNames
from nvflare.fuel.hci.conn import Connection
from nvflare.fuel.hci.proto import ConfirmMethod, MetaKey, MetaStatusValue, make_meta
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec
from nvflare.private.admin_defs import MsgHeader, ReturnCode
from nvflare.private.defs import ClientStatusKey, ScopeInfoKey, TrainingTopic
from nvflare.private.fed.server.admin import new_message
from nvflare.private.fed.server.server_engine_internal_spec import ServerEngineInternalSpec
from nvflare.private.fed.utils.fed_utils import get_scope_info
from nvflare.security.logging import secure_format_exception
from .cmd_utils import CommandUtil
from .server_engine import ServerEngine
class TrainingCommandModule(CommandModule, CommandUtil):
    """Admin command module for FL server/client status and life-cycle commands.

    Registers the check_status, remove_client, admin_check_status, shutdown,
    restart and show_scopes commands (see ``get_spec``) and implements the
    handler for each of them. Handlers receive the admin ``Connection`` and the
    tokenized command line; ``args[1]`` is the target type (server/client/all).
    """

    def __init__(self):
        """A class for training commands."""
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def get_spec(self):
        """Return the CommandModuleSpec describing the commands of this module."""
        return CommandModuleSpec(
            name="training",
            cmd_specs=[
                CommandSpec(
                    name=AdminCommandNames.CHECK_STATUS,
                    description="check status of the FL server/client",
                    usage="check_status server|client",
                    handler_func=self.check_status,
                    authz_func=self.authorize_server_operation,
                    visible=True,
                ),
                CommandSpec(
                    name=AdminCommandNames.REMOVE_CLIENT,
                    description="remove a FL client",
                    usage="remove_client <client-name>",
                    handler_func=self.remove_client,
                    authz_func=self.authorize_client_operation,
                    visible=True,
                    confirm=ConfirmMethod.AUTH,
                ),
                CommandSpec(
                    name=AdminCommandNames.ADMIN_CHECK_STATUS,
                    description="check status for project admin",
                    usage="admin_check_status server|client",
                    handler_func=self.check_status,
                    authz_func=self.must_be_project_admin,
                    visible=False,
                ),
                CommandSpec(
                    name=AdminCommandNames.SHUTDOWN,
                    description="shutdown the FL server/client",
                    usage="shutdown server|client|all",
                    handler_func=self.shutdown,
                    authz_func=self.authorize_server_operation,
                    visible=True,
                    confirm=ConfirmMethod.AUTH,
                ),
                CommandSpec(
                    name=AdminCommandNames.RESTART,
                    description="restart FL server and/or clients",
                    usage="restart server|client|all [clients]",
                    handler_func=self.restart,
                    authz_func=self.authorize_server_operation,
                    visible=True,
                    confirm=ConfirmMethod.AUTH,
                ),
                CommandSpec(
                    name=AdminCommandNames.SHOW_SCOPES,
                    description="show configured scope names on server/client",
                    usage="show_scopes server|client|all ...",
                    handler_func=self.show_scopes,
                    authz_func=self.authorize_server_operation,
                    visible=True,
                ),
            ],
        )

    # Shutdown
    def _shutdown_app_on_server(self, conn: Connection) -> str:
        """Ask the engine to shut down the server and report the outcome on conn.

        Args:
            conn: the admin connection; its ``app_ctx`` is used as the engine.

        Returns:
            The error returned by ``engine.shutdown_server()``, or "" on success.
        """
        engine = conn.app_ctx
        err = engine.shutdown_server()
        if err:
            conn.append_error(err)
            return err

        conn.append_string("FL app has been shutdown.")
        conn.append_shutdown("Goodbye!")
        return ""

    def _shutdown_app_on_clients(self, conn: Connection) -> bool:
        """Send the SHUTDOWN request to the targeted clients.

        Args:
            conn: the admin connection holding the TARGET_CLIENT_TOKENS prop.

        Returns:
            True if there were no clients to shut down, or if no targeted client
            replied with an ERROR return code; False otherwise.
        """
        message = new_message(conn, topic=TrainingTopic.SHUTDOWN, body="", require_authz=True)
        clients = conn.get_prop(self.TARGET_CLIENT_TOKENS, None)
        if not clients:
            # no clients to shut down - this is okay
            return True

        replies = self.send_request_to_clients(conn, message)
        self.process_replies_to_table(conn, replies)

        # Collect the tokens of clients that reported an error. A set
        # intersection (rather than set.remove) cannot raise KeyError when a
        # reply carries an unexpected or repeated token.
        failed_tokens = {
            r.client_token
            for r in (replies or [])
            if r.reply and r.reply.get_header(MsgHeader.RETURN_CODE) == ReturnCode.ERROR
        }
        # any overlap means some clients can not be shutdown
        return not (failed_tokens & set(clients))

    def shutdown(self, conn: Connection, args: List[str]):
        """Handle the shutdown command for server, clients, or all.

        Refuses to proceed while any running job is not aborted, and refuses a
        server-only shutdown while clients are still registered. For "all",
        clients are shut down before the server.

        Args:
            conn: the admin connection; results and errors are appended to it.
            args: command tokens; ``args[1]`` is the target type.

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngine.
        """
        target_type = args[1]
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngine):
            raise TypeError("engine must be ServerEngine but got {}".format(type(engine)))

        for job in engine.job_runner.running_jobs.values():
            if not job.run_aborted:
                conn.append_error(
                    "There are still jobs running. Please let them finish or abort_job before shutdown.",
                    meta=make_meta(MetaStatusValue.JOB_RUNNING, info=job.job_id),
                )
                return

        if target_type == self.TARGET_TYPE_SERVER:
            if engine.get_clients():
                conn.append_error(
                    "There are still active clients. Shutdown all clients first.",
                    meta=make_meta(MetaStatusValue.CLIENTS_RUNNING),
                )
                return

        if target_type in [self.TARGET_TYPE_CLIENT, self.TARGET_TYPE_ALL]:
            # must shut down clients first
            success = self._shutdown_app_on_clients(conn)
            if not success:
                conn.update_meta(make_meta(MetaStatusValue.ERROR, "failed to shut down all clients"))
                return

        if target_type in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]:
            # shut down the server
            err = self._shutdown_app_on_server(conn)
            if err:
                conn.update_meta(make_meta(MetaStatusValue.ERROR, info=err))
                return

        conn.append_success("")

    # Remove Clients
    def remove_client(self, conn: Connection, args: List[str]):
        """Handle the remove_client command.

        Removes the clients whose tokens are stored in the TARGET_CLIENT_TOKENS
        prop of the connection.

        Args:
            conn: the admin connection; results and errors are appended to it.
            args: command tokens (not read by this handler).

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngineInternalSpec.
        """
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngineInternalSpec):
            raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine)))

        clients = conn.get_prop(self.TARGET_CLIENT_TOKENS)
        err = engine.remove_clients(clients)
        if err:
            conn.append_error(err)
            return

        conn.append_success("")

    # Restart
    def _restart_clients(self, conn) -> str:
        """Send the RESTART request to the targeted clients.

        Args:
            conn: the admin connection used to build and send the request.

        Returns:
            The client replies rendered as a string by ``_process_replies_to_string``.

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngineInternalSpec.
        """
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngineInternalSpec):
            raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine)))

        message = new_message(conn, topic=TrainingTopic.RESTART, body="", require_authz=True)
        replies = self.send_request_to_clients(conn, message)
        # engine.remove_clients(clients)
        return self._process_replies_to_string(conn, replies)

    def restart(self, conn: Connection, args: List[str]):
        """Handle the restart command for server, clients, or all.

        Refuses to proceed while the job runner has running jobs. For targets
        "server" and "all", every registered client is asked to restart before
        the server restart is scheduled.

        Args:
            conn: the admin connection; results and errors are appended to it.
            args: command tokens; ``args[1]`` is the target type.

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngine.
        """
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngine):
            raise TypeError("engine must be ServerEngine but got {}".format(type(engine)))

        if engine.job_runner.running_jobs:
            msg = "There are still jobs running. Please let them finish or abort_job before restart."
            conn.append_error(msg, meta=make_meta(MetaStatusValue.JOB_RUNNING, msg))
            return

        target_type = args[1]
        if target_type in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]:
            clients = engine.get_clients()
            if clients:
                conn.append_string("Trying to restart all clients before restarting server...")
                tokens = [c.token for c in clients]
                conn.set_prop(
                    self.TARGET_CLIENT_TOKENS, tokens
                )  # need this because not set in validate_command_targets when target_type == self.TARGET_TYPE_SERVER
                response = self._restart_clients(conn)
                conn.append_string(response)
                # check with Isaac - no need to wait!
                # time.sleep(5)

            err = engine.restart_server()
            if err:
                conn.append_error(err, meta={MetaKey.SERVER_STATUS: MetaStatusValue.ERROR, MetaKey.INFO: err})
            else:
                conn.append_string("Server scheduled for restart", meta={MetaKey.SERVER_STATUS: MetaStatusValue.OK})
                # ask the admin client to shut down since its current session will become invalid after
                # the server is restarted.
                # conn.append_shutdown("Goodbye!")
        elif target_type == self.TARGET_TYPE_CLIENT:
            clients = conn.get_prop(self.TARGET_CLIENT_TOKENS)
            if not clients:
                conn.append_error("no clients available", meta=make_meta(MetaStatusValue.NO_CLIENTS, "no clients"))
                return

            response = self._restart_clients(conn)
            conn.append_string(response)

        conn.append_success("")

    # Check status
    def check_status(self, conn: Connection, args: List[str]):
        """Handle the check_status / admin_check_status commands.

        For the server target, reports engine status, the job/app table and the
        registered clients. For the client target, queries the clients and
        tabulates their replies. Any other target yields a syntax error.

        Args:
            conn: the admin connection; results and errors are appended to it.
            args: command tokens; ``args[1]`` is the target type.

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngineInternalSpec, or
                if the engine returns a client that is not a Client.
        """
        # TODO:: Need more discussion on what status to be shown
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngineInternalSpec):
            raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine)))

        dst = args[1]
        if dst not in [self.TARGET_TYPE_ALL, self.TARGET_TYPE_CLIENT, self.TARGET_TYPE_SERVER]:
            conn.append_error(
                f"invalid target type {dst}. Usage: check_status server|client ...",
                meta=make_meta(MetaStatusValue.SYNTAX_ERROR, f"invalid target type {dst}"),
            )
            return

        if dst in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]:
            engine_info = engine.get_engine_info()
            conn.append_string(
                f"Engine status: {engine_info.status.value}",
                meta=make_meta(
                    MetaStatusValue.OK,
                    extra={
                        MetaKey.SERVER_STATUS: engine_info.status.value,
                        MetaKey.SERVER_START_TIME: engine_info.start_time,
                    },
                ),
            )

            table = conn.append_table(["job_id", "app name"], name=MetaKey.JOBS)
            for job_id, app_name in engine_info.app_names.items():
                table.add_row([job_id, app_name], meta={MetaKey.APP_NAME: app_name, MetaKey.JOB_ID: job_id})

            clients = engine.get_clients()
            conn.append_string("Registered clients: {} ".format(len(clients)))

            if clients:
                table = conn.append_table(["client", "token", "last connect time"], name=MetaKey.CLIENTS)
                for c in clients:
                    if not isinstance(c, Client):
                        raise TypeError("c must be Client but got {}".format(type(c)))
                    table.add_row(
                        [c.name, str(c.token), time.asctime(time.localtime(c.last_connect_time))],
                        meta={MetaKey.CLIENT_NAME: c.name, MetaKey.CLIENT_LAST_CONNECT_TIME: c.last_connect_time},
                    )

        if dst in [self.TARGET_TYPE_CLIENT, self.TARGET_TYPE_ALL]:
            message = new_message(conn, topic=TrainingTopic.CHECK_STATUS, body="", require_authz=True)
            replies = self.send_request_to_clients(conn, message)
            self._process_client_status_replies(conn, replies)

    def _process_client_status_replies(self, conn, replies):
        """Tabulate the clients' CHECK_STATUS replies on the connection.

        Every reply produces at least one table row: one per running job, or a
        single "No Jobs", "No Reply", error or "Bad Reply" row.

        Args:
            conn: the admin connection the table (or error) is appended to.
            replies: the client replies; each exposes ``client_name`` and ``reply``.
        """
        if not replies:
            conn.append_error("no responses from clients")
            return

        table = conn.append_table(["client", "app_name", "job_id", "status"], name=MetaKey.CLIENT_STATUS)
        for r in replies:
            client_name = r.client_name

            if not r.reply:
                table.add_row(
                    [client_name, "?", "?", "No Reply"],
                    meta={MetaKey.CLIENT_NAME: client_name, MetaKey.STATUS: MetaStatusValue.NO_REPLY},
                )
                continue

            if r.reply.get_header(MsgHeader.RETURN_CODE) == ReturnCode.ERROR:
                table.add_row(
                    [client_name, "?", "?", r.reply.body],
                    meta={MetaKey.CLIENT_NAME: client_name, MetaKey.STATUS: MetaStatusValue.ERROR},
                )
                continue

            # Parse the whole reply before touching the table, so that a
            # malformed entry cannot leave partially added rows behind.
            try:
                body = json.loads(r.reply.body)
                if not isinstance(body, dict):
                    raise TypeError(f"expect dict but got {type(body)}")

                job_rows = []
                for job in body.get(ClientStatusKey.RUNNING_JOBS) or []:
                    app_name = job.get(ClientStatusKey.APP_NAME, "?")
                    job_id = job.get(ClientStatusKey.JOB_ID, "?")
                    status = job.get(ClientStatusKey.STATUS, "?")
                    job_rows.append(
                        (
                            [client_name, app_name, job_id, status],
                            {
                                MetaKey.CLIENT_NAME: client_name,
                                MetaKey.APP_NAME: app_name,
                                MetaKey.JOB_ID: job_id,
                                MetaKey.STATUS: status,
                            },
                        )
                    )
            except Exception as e:
                # The reply content comes from a remote client: report it as a
                # bad reply instead of silently omitting the client.
                self.logger.error(f"Bad reply from client: {secure_format_exception(e)}")
                table.add_row(
                    [client_name, "?", "?", "Bad Reply"],
                    meta={MetaKey.CLIENT_NAME: client_name, MetaKey.STATUS: MetaStatusValue.ERROR},
                )
                continue

            if job_rows:
                for row, row_meta in job_rows:
                    table.add_row(row, meta=row_meta)
            else:
                table.add_row(
                    [client_name, "?", "?", "No Jobs"],
                    meta={MetaKey.CLIENT_NAME: client_name, MetaKey.STATUS: MetaStatusValue.NO_JOBS},
                )

    def _add_scope_info(self, table, site_name, scope_names: List[str], default_scope: str):
        """Add one (site, scopes, default) row to the scope table.

        Args:
            table: the table to add the row to.
            site_name: name shown in the "site" column.
            scope_names: scope names, joined with ", ". A bare string is shown
                as is rather than being joined character by character.
            default_scope: value shown in the "default" column.
        """
        if not scope_names:
            names = ""
        elif isinstance(scope_names, str):
            names = scope_names
        else:
            names = ", ".join(scope_names)
        table.add_row([site_name, names, default_scope])

    def _process_scope_replies(self, table, conn, replies):
        """Add the clients' GET_SCOPES replies to the scope table.

        Args:
            table: the scope table rows are added to.
            conn: the admin connection; errors are appended to it.
            replies: the client replies; each exposes ``client_name`` and ``reply``.
        """
        if not replies:
            conn.append_error("no responses from clients")
            return

        for r in replies:
            client_name = r.client_name

            if not r.reply:
                self._add_scope_info(table, client_name, [], "no reply")
                continue

            if r.reply.get_header(MsgHeader.RETURN_CODE) == ReturnCode.ERROR:
                # Show the error text in the "scopes" column. It is wrapped in a
                # list so that it is displayed whole, not joined per character.
                error_text = r.reply.body
                self._add_scope_info(table, client_name, [str(error_text)] if error_text else [], "")
                continue

            try:
                body = json.loads(r.reply.body)
                if isinstance(body, dict):
                    scope_names = body.get(ScopeInfoKey.SCOPE_NAMES)
                    default_scope = body.get(ScopeInfoKey.DEFAULT_SCOPE)
                    self._add_scope_info(table, client_name, scope_names, default_scope)
                else:
                    conn.append_error(f"bad response from client {client_name}: expect dict but got {type(body)}")
            except Exception as e:
                self.logger.error(f"Bad reply from client: {secure_format_exception(e)}")
                conn.append_error(f"bad response from client {client_name}: {secure_format_exception(e)}")

    def show_scopes(self, conn: Connection, args: List[str]):
        """Handle the show_scopes command.

        Appends a (site, scopes, default) table to the connection and fills it
        with the server's scope info and/or the scope info reported by clients,
        depending on the target type.

        Args:
            conn: the admin connection; results and errors are appended to it.
            args: command tokens; ``args[1]`` is the target type.

        Raises:
            TypeError: if ``conn.app_ctx`` is not a ServerEngineInternalSpec.
        """
        engine = conn.app_ctx
        if not isinstance(engine, ServerEngineInternalSpec):
            raise TypeError("engine must be ServerEngineInternalSpec but got {}".format(type(engine)))

        dst = args[1]
        table = conn.append_table(["site", "scopes", "default"])

        if dst in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]:
            # get the server's scope info
            scope_names, default_scope_name = get_scope_info()
            self._add_scope_info(table, "server", scope_names, default_scope_name)

        if dst in [self.TARGET_TYPE_CLIENT, self.TARGET_TYPE_ALL]:
            message = new_message(conn, topic=TrainingTopic.GET_SCOPES, body="", require_authz=True)
            replies = self.send_request_to_clients(conn, message)
            self._process_scope_replies(table, conn, replies)