# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FL Admin commands."""
import copy
import time
from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, ReservedKey, ServerCommandKey, ServerCommandNames
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.fuel.utils import fobs
from nvflare.widgets.widget import WidgetID
[docs]class CommandProcessor(object):
"""The CommandProcessor is responsible for processing a command from parent process."""
[docs] def get_command_name(self) -> str:
"""Get command name that this processor will handle.
Returns: name of the command
"""
pass
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the specified command.
Args:
data: process data
fl_ctx: FLContext
Return: reply message
"""
pass
[docs]class AbortCommand(CommandProcessor):
"""To implement the abort command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: AdminCommandNames.ABORT
"""
return AdminCommandNames.ABORT
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: abort command message
"""
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
server_runner.abort(fl_ctx)
# wait for the runner process gracefully abort the run.
time.sleep(3.0)
return "Aborted the run"
[docs]class GetRunInfoCommand(CommandProcessor):
"""To implement the abort command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.GET_RUN_INFO
"""
return ServerCommandNames.GET_RUN_INFO
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: Engine run_info
"""
engine = fl_ctx.get_engine()
return engine.get_run_info()
[docs]class GetTaskCommand(CommandProcessor):
"""To implement the server GetTask command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.GET_TASK
"""
return ServerCommandNames.GET_TASK
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: task data
"""
shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
client = data.get_header(ServerCommandKey.FL_CLIENT)
fl_ctx.set_peer_context(shared_fl_ctx)
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
taskname, task_id, shareable = server_runner.process_task_request(client, fl_ctx)
data = {
ServerCommandKey.TASK_NAME: taskname,
ServerCommandKey.TASK_ID: task_id,
ServerCommandKey.SHAREABLE: shareable,
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props),
}
return fobs.dumps(data)
[docs]class SubmitUpdateCommand(CommandProcessor):
"""To implement the server GetTask command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.SUBMIT_UPDATE
"""
return ServerCommandNames.SUBMIT_UPDATE
[docs] def process(self, data: dict, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns:
"""
shareable = data.get(ReservedKey.SHAREABLE)
shared_fl_ctx = data.get(ReservedKey.SHARED_FL_CONTEXT)
client = shareable.get_header(ServerCommandKey.FL_CLIENT)
fl_ctx.set_peer_context(shared_fl_ctx)
contribution_task_name = shareable.get_header(ServerCommandKey.TASK_NAME)
task_id = shareable.get_cookie(FLContextKey.TASK_ID)
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
server_runner.process_submission(client, contribution_task_name, task_id, shareable, fl_ctx)
return ""
[docs]class AuxCommunicateCommand(CommandProcessor):
"""To implement the server GetTask command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.AUX_COMMUNICATE
"""
return ServerCommandNames.AUX_COMMUNICATE
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: task data
"""
shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
topic = data.get_header(ServerCommandKey.TOPIC)
shareable = data.get_header(ServerCommandKey.SHAREABLE)
fl_ctx.set_peer_context(shared_fl_ctx)
engine = fl_ctx.get_engine()
reply = engine.dispatch(topic=topic, request=shareable, fl_ctx=fl_ctx)
data = {
ServerCommandKey.AUX_REPLY: reply,
ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props),
}
return data
[docs]class ShowStatsCommand(CommandProcessor):
"""To implement the show_stats command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.SHOW_STATS
"""
return ServerCommandNames.SHOW_STATS
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: Engine run_info
"""
engine = fl_ctx.get_engine()
collector = engine.get_widget(WidgetID.INFO_COLLECTOR)
return collector.get_run_stats()
[docs]class GetErrorsCommand(CommandProcessor):
"""To implement the show_errors command."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: ServerCommandNames.GET_ERRORS
"""
return ServerCommandNames.GET_ERRORS
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the abort command.
Args:
data: process data
fl_ctx: FLContext
Returns: Engine run_info
"""
engine = fl_ctx.get_engine()
collector = engine.get_widget(WidgetID.INFO_COLLECTOR)
errors = collector.get_errors()
if not errors:
errors = "No Error"
return errors
[docs]class ByeCommand(CommandProcessor):
"""To implement the ShutdownCommand."""
[docs] def get_command_name(self) -> str:
"""To get the command name.
Returns: AdminCommandNames.SHUTDOWN
"""
return AdminCommandNames.SHUTDOWN
[docs] def process(self, data: Shareable, fl_ctx: FLContext):
"""Called to process the Shutdown command.
Args:
data: process data
fl_ctx: FLContext
Returns: Shutdown command message
"""
return None
[docs]class ServerCommands(object):
"""AdminCommands contains all the commands for processing the commands from the parent process."""
commands = [
AbortCommand(),
ByeCommand(),
GetRunInfoCommand(),
GetTaskCommand(),
SubmitUpdateCommand(),
AuxCommunicateCommand(),
ShowStatsCommand(),
GetErrorsCommand(),
]
[docs] @staticmethod
def get_command(command_name):
"""Call to return the AdminCommand object.
Args:
command_name: AdminCommand name
Returns: AdminCommand object
"""
for command in ServerCommands.commands:
if command_name == command.get_command_name():
return command
return None
[docs] @staticmethod
def register_command(command_processor: CommandProcessor):
"""Call to register the AdminCommand processor.
Args:
command_processor: AdminCommand processor
"""
if not isinstance(command_processor, CommandProcessor):
raise TypeError(
"command_processor must be an instance of CommandProcessor, but got {}".format(type(command_processor))
)
ServerCommands.commands.append(command_processor)