# Source code for nvflare.private.fed.client.admin_commands

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FL Admin commands."""

from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.private.fed.client.client_status import get_status_message
from nvflare.widgets.info_collector import InfoCollector
from nvflare.widgets.widget import WidgetID


class CommandProcessor:
    """Base class for handlers of commands sent by the parent process.

    A concrete processor overrides ``get_command_name`` to declare the command
    it handles and ``process`` to carry that command out.
    """

    def get_command_name(self) -> str:
        """Get the name of the command this processor handles.

        The base implementation does nothing and yields None; subclasses are
        expected to override it.

        Returns:
            name of the command
        """
        return None

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Handle the command.

        The base implementation does nothing and yields None; subclasses are
        expected to override it.

        Args:
            data: process data
            fl_ctx: FLContext

        Returns:
            reply message
        """
        return None
class CheckStatusCommand(CommandProcessor):
    """Processor for the check_status command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.CHECK_STATUS
        """
        return AdminCommandNames.CHECK_STATUS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Report the status of the federated client.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext, used to reach the engine and its client

        Returns:
            the message produced by ``get_status_message`` for the client's status
        """
        fed_client = fl_ctx.get_engine().client
        return get_status_message(fed_client.status)
class AbortCommand(CommandProcessor):
    """Processor for the abort command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.ABORT
        """
        return AdminCommandNames.ABORT

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Abort the job run by the client runner stored in the context.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext; the client runner is read from FLContextKey.RUNNER

        Returns:
            the value returned by ``client_runner.abort``, or None when the
            context holds no client runner.
        """
        client_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if client_runner is None:
            # No runner registered, so there is nothing to abort. Previously this
            # raised AttributeError on None; AbortTaskCommand already guards the
            # same lookup in the same way.
            return None
        return client_runner.abort(msg="Received command to abort job")
class AbortTaskCommand(CommandProcessor):
    """Processor for the abort_task command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.ABORT_TASK
        """
        return AdminCommandNames.ABORT_TASK

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Ask the client runner, if one is present, to abort its current task.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext; the client runner is read from FLContextKey.RUNNER

        Returns:
            None in all cases.
        """
        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if not runner:
            # No runner in the context: silently do nothing.
            return None
        runner.abort_task()
        return None
class ShowStatsCommand(CommandProcessor):
    """Processor for the show_stats command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.SHOW_STATS
        """
        return AdminCommandNames.SHOW_STATS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Called to process the show_stats command.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext, used to look up the engine's info-collector widget

        Returns:
            ``{"error": "no info collector"}`` when the engine has no info
            collector; otherwise the collector's run stats, or the string
            "No stats info" when those stats are empty/falsy.

        Raises:
            TypeError: if the registered widget is not an InfoCollector.
        """
        engine = fl_ctx.get_engine()
        collector = engine.get_widget(WidgetID.INFO_COLLECTOR)
        if not collector:
            return {"error": "no info collector"}
        if not isinstance(collector, InfoCollector):
            raise TypeError("collector must be an instance of InfoCollector, but got {}".format(type(collector)))
        result = collector.get_run_stats()
        # Replace empty/falsy stats with a placeholder string rather than sending nothing back.
        if not result:
            result = "No stats info"
        return result
class ShowErrorsCommand(CommandProcessor):
    """Processor for the show_errors command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.SHOW_ERRORS
        """
        return AdminCommandNames.SHOW_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Return the errors recorded by the engine's info collector.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext, used to look up the engine's info-collector widget

        Returns:
            ``{"error": "no info collector"}`` when the engine has no info
            collector; otherwise the collector's errors, or "No Errors" when
            the collector reports None.

        Raises:
            TypeError: if the registered widget is not an InfoCollector.
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        if not info_collector:
            return {"error": "no info collector"}
        if not isinstance(info_collector, InfoCollector):
            raise TypeError("collector must be an instance of InfoCollector, but got {}".format(type(info_collector)))
        errors = info_collector.get_errors()
        # CommandAgent is expecting data, could not be None
        return "No Errors" if errors is None else errors
class ResetErrorsCommand(CommandProcessor):
    """Processor for the reset_errors command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.RESET_ERRORS
        """
        return AdminCommandNames.RESET_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Clear the errors held by the engine.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext, used to reach the engine

        Returns:
            None; no reply message is produced.
        """
        fl_ctx.get_engine().reset_errors()
class ByeCommand(CommandProcessor):
    """Processor for the shutdown command."""

    def get_command_name(self) -> str:
        """Get the command name handled by this processor.

        Returns:
            AdminCommandNames.SHUTDOWN
        """
        return AdminCommandNames.SHUTDOWN

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Handle the shutdown command; this processor performs no action itself.

        Args:
            data: process data (not used by this command)
            fl_ctx: FLContext (not used by this command)

        Returns:
            None
        """
        reply = None
        return reply
class AdminCommands:
    """Registry of the processors for commands coming from the parent process.

    ``commands`` is a class-level list shared by all users of this class;
    ``register_command`` appends to it.
    """

    commands = [
        CheckStatusCommand(),
        AbortCommand(),
        AbortTaskCommand(),
        ByeCommand(),
        ShowStatsCommand(),
        ShowErrorsCommand(),
        ResetErrorsCommand(),
    ]

    @staticmethod
    def get_command(command_name):
        """Look up the processor registered for a command name.

        The registry is scanned in order and the first match is returned, so a
        processor registered later under an already-used name is never found.

        Args:
            command_name: AdminCommand name

        Returns:
            the matching AdminCommand processor, or None if none matches
        """
        return next(
            (processor for processor in AdminCommands.commands if command_name == processor.get_command_name()),
            None,
        )

    @staticmethod
    def register_command(command_processor: CommandProcessor):
        """Append a processor to the shared registry.

        Args:
            command_processor: AdminCommand processor

        Raises:
            TypeError: if ``command_processor`` is not a CommandProcessor.
        """
        if isinstance(command_processor, CommandProcessor):
            AdminCommands.commands.append(command_processor)
            return
        raise TypeError(
            "command_processor must be an instance of CommandProcessor, but got {}".format(type(command_processor))
        )