# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FL Admin commands."""
from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.private.fed.client.client_status import get_status_message
from nvflare.widgets.info_collector import InfoCollector
from nvflare.widgets.widget import WidgetID
class CommandProcessor:
    """Base class for handlers of commands sent by the parent process.

    A subclass declares which command it handles by overriding ``get_command_name``
    and performs the work by overriding ``process``. Both base implementations do
    nothing and return ``None``.
    """

    def get_command_name(self) -> str:
        """Return the name of the command this processor handles.

        Returns:
            The command name; the base implementation returns ``None``.
        """
        return None

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Handle one occurrence of the command.

        Args:
            data: process data
            fl_ctx: FLContext

        Returns:
            The reply message; the base implementation returns ``None``.
        """
        return None
class CheckStatusCommand(CommandProcessor):
    """Processor for the check_status command."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.CHECK_STATUS
        """
        return AdminCommandNames.CHECK_STATUS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Report the status of the federated client.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext whose engine exposes a ``client`` attribute

        Returns:
            The message ``get_status_message`` builds from the client's ``status``.
        """
        return get_status_message(fl_ctx.get_engine().client.status)
class AbortCommand(CommandProcessor):
    """Processor for the abort command, which asks the client runner to abort the job."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.ABORT
        """
        return AdminCommandNames.ABORT

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Process the abort command.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext; the client runner is read from its ``FLContextKey.RUNNER`` property

        Returns:
            Whatever ``client_runner.abort()`` returns, or ``None`` when the context holds no runner.
        """
        client_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if not client_runner:
            # No runner in the context means there is nothing to abort. Return None (the same
            # reply the abort_task command gives) instead of crashing with AttributeError on None.
            return None
        return client_runner.abort(msg="Received command to abort job")
class AbortTaskCommand(CommandProcessor):
    """Processor for the abort_task command, which aborts the task currently being run."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.ABORT_TASK
        """
        return AdminCommandNames.ABORT_TASK

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Ask the client runner, if any, to abort its current task.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext; the runner is read from its ``FLContextKey.RUNNER`` property

        Returns:
            Always ``None``; this command sends no reply data.
        """
        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if not runner:
            # Nothing to abort when the context carries no runner.
            return None
        runner.abort_task()
        return None
class ShowStatsCommand(CommandProcessor):
    """Processor for the show_stats command."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.SHOW_STATS
        """
        return AdminCommandNames.SHOW_STATS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Collect the run statistics from the engine's info-collector widget.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext whose engine provides the info-collector widget

        Returns:
            ``{"error": "no info collector"}`` when no collector is registered; otherwise the
            collector's run stats, or ``"No stats info"`` when those stats are empty.

        Raises:
            TypeError: if the registered widget is not an ``InfoCollector``.
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        if not info_collector:
            return {"error": "no info collector"}
        if not isinstance(info_collector, InfoCollector):
            raise TypeError("collector must be an instance of InfoCollector, but got {}".format(type(info_collector)))
        # Empty stats are replaced by a placeholder string so the reply is never empty.
        return info_collector.get_run_stats() or "No stats info"
class ShowErrorsCommand(CommandProcessor):
    """Processor for the show_errors command."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.SHOW_ERRORS
        """
        return AdminCommandNames.SHOW_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Collect the recorded errors from the engine's info-collector widget.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext whose engine provides the info-collector widget

        Returns:
            ``{"error": "no info collector"}`` when no collector is registered; otherwise the
            collector's errors, or ``"No Errors"`` when the collector returns ``None``.

        Raises:
            TypeError: if the registered widget is not an ``InfoCollector``.
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        if not info_collector:
            return {"error": "no info collector"}
        if not isinstance(info_collector, InfoCollector):
            raise TypeError("collector must be an instance of InfoCollector, but got {}".format(type(info_collector)))
        errors = info_collector.get_errors()
        # The CommandAgent expects reply data, so a None result is replaced by a placeholder.
        return "No Errors" if errors is None else errors
class ResetErrorsCommand(CommandProcessor):
    """Processor for the reset_errors command."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.RESET_ERRORS
        """
        return AdminCommandNames.RESET_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Ask the engine to reset its recorded errors.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext whose engine performs the reset

        Returns:
            Always ``None``.
        """
        fl_ctx.get_engine().reset_errors()
        return None
class ByeCommand(CommandProcessor):
    """Processor registered for the shutdown command; its ``process`` does no work."""

    def get_command_name(self) -> str:
        """Return the name of the command handled by this processor.

        Returns:
            AdminCommandNames.SHUTDOWN
        """
        return AdminCommandNames.SHUTDOWN

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Acknowledge the shutdown command without taking any action here.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext (not used by this command)

        Returns:
            Always ``None``.
        """
        return None
class AdminCommands(object):
    """Registry of the processors that handle commands coming from the parent process."""

    # Class-level list shared by every user of AdminCommands; register_command appends to it.
    commands = [
        CheckStatusCommand(),
        AbortCommand(),
        AbortTaskCommand(),
        ByeCommand(),
        ShowStatsCommand(),
        ShowErrorsCommand(),
        ResetErrorsCommand(),
    ]

    @staticmethod
    def get_command(command_name):
        """Look up the processor registered for a command.

        Args:
            command_name: name of the command to look up

        Returns:
            The first registered processor whose ``get_command_name()`` equals
            ``command_name`` (in registration order), or ``None`` if there is none.
        """
        matches = (
            processor for processor in AdminCommands.commands if command_name == processor.get_command_name()
        )
        return next(matches, None)

    @staticmethod
    def register_command(command_processor: CommandProcessor):
        """Add a processor to the end of the registry.

        Because ``get_command`` returns the first match, a processor registered under a name
        that is already present will not be returned by ``get_command``.

        Args:
            command_processor: the processor to register

        Raises:
            TypeError: if ``command_processor`` is not a ``CommandProcessor``.
        """
        if not isinstance(command_processor, CommandProcessor):
            raise TypeError(
                "command_processor must be an instance of CommandProcessor, but got {}".format(type(command_processor))
            )
        AdminCommands.commands.append(command_processor)