Source code for nvflare.private.fed.server.server_commands

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FL Admin commands."""

import logging
import time
from abc import ABC, abstractmethod
from typing import List

from nvflare.apis.fl_constant import (
    AdminCommandNames,
    FLContextKey,
    MachineStatus,
    ServerCommandKey,
    ServerCommandNames,
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.private.defs import SpecialTaskName, TaskConstant
from nvflare.widgets.widget import WidgetID

# Placeholder reply returned by GetRunInfoCommand.process when the engine
# reports no run info.
NO_OP_REPLY = "__no_op_reply"


class CommandProcessor(ABC):
    """Abstract handler for a single command received from the parent process.

    Concrete subclasses name the command they handle via ``get_command_name``
    and implement the handling logic in ``process``.
    """

    def __init__(self) -> None:
        # Each concrete processor logs under its own class name.
        self.logger = logging.getLogger(type(self).__name__)

    @abstractmethod
    def get_command_name(self) -> str:
        """Return the name of the command this processor handles.

        Returns:
            name of the command
        """

    @abstractmethod
    def process(self, data: Shareable, fl_ctx: FLContext):
        """Process the command data.

        Args:
            data: process data
            fl_ctx: FLContext

        Returns:
            A reply message
        """
class ServerStateCheck(ABC):
    """Interface for server commands that require a server state check."""

    @abstractmethod
    def get_state_check(self, fl_ctx: FLContext) -> dict:
        """Return the state check data for the server command.

        Args:
            fl_ctx: FLContext

        Returns:
            server state check dict data
        """
class AbortCommand(CommandProcessor):
    """Command processor for the abort command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            AdminCommandNames.ABORT
        """
        return AdminCommandNames.ABORT

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Abort the run and wait for the engine to report that it stopped.

        Args:
            data: process data; its TURN_TO_COLD header (default False) is
                forwarded to the runner's abort call
            fl_ctx: FLContext

        Returns:
            abort command message
        """
        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        # TURN_TO_COLD is used for the HA server switch over.
        cold_switch = data.get_header(ServerCommandKey.TURN_TO_COLD, False)

        if not runner:
            return "Aborted the run"

        runner.abort(fl_ctx=fl_ctx, turn_to_cold=cold_switch)

        # Give the runner a chance to abort gracefully: poll the engine status
        # once per second and give up once more than 30 seconds have elapsed.
        engine = fl_ctx.get_engine()
        wait_started = time.time()
        while engine.engine_info.status != MachineStatus.STOPPED:
            time.sleep(1.0)
            waited = time.time() - wait_started
            if waited > 30.0:
                break
        return "Aborted the run"
class GetRunInfoCommand(CommandProcessor):
    """Command processor for the GET_RUN_INFO command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.GET_RUN_INFO
        """
        return ServerCommandNames.GET_RUN_INFO

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Return the engine's run info.

        Args:
            data: process data (not used)
            fl_ctx: FLContext

        Returns:
            the engine's run info when it is truthy, otherwise NO_OP_REPLY
        """
        current_run = fl_ctx.get_engine().get_run_info()
        return current_run or NO_OP_REPLY
class GetTaskCommand(CommandProcessor, ServerStateCheck):
    """Command processor for the server GetTask command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.GET_TASK
        """
        return ServerCommandNames.GET_TASK

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Process a client's GET_TASK request.

        Args:
            data: process data; carries the client's peer FL context and the
                requesting client in its headers
            fl_ctx: FLContext

        Returns:
            task data
        """
        began = time.time()

        # Take the peer context out of the request and leave an empty
        # FLContext in its place.
        peer_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
        data.set_header(ServerCommandKey.PEER_FL_CONTEXT, FLContext())

        client = data.get_header(ServerCommandKey.FL_CLIENT)
        self.logger.debug(f"Got the GET_TASK request from client: {client.name}")
        fl_ctx.set_peer_context(peer_ctx)

        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if runner:
            task_name, task_id, reply = runner.process_task_request(client, fl_ctx)
        else:
            # this is possible only when the client request is received before the
            # server_app_runner.start_server_app is called in runner_process.py
            # We ask the client to try again later.
            task_name, task_id = SpecialTaskName.TRY_AGAIN, ""
            reply = Shareable()
            reply.set_header(TaskConstant.WAIT_TIME, 1.0)

        if not reply:
            reply = Shareable()

        # TASK_ID must come back to us as a cookie ...
        reply.add_cookie(name=FLContextKey.TASK_ID, data=task_id)
        # ... and must also be available to the client as a header.
        reply.set_header(key=FLContextKey.TASK_ID, value=task_id)
        reply.set_header(key=ServerCommandKey.TASK_NAME, value=task_name)
        reply.set_header(key=FLContextKey.PEER_CONTEXT, value=gen_new_peer_ctx(fl_ctx))

        if task_name != SpecialTaskName.TRY_AGAIN:
            self.logger.info(
                f"return task to client. client_name: {client.name} task_name: {task_name} task_id: {task_id} "
                f"sharable_header_task_id: {reply.get_header(key=FLContextKey.TASK_ID)}"
            )
        self.logger.debug(f"Get_task processing time: {time.time()-began} for client: {client.name}")
        return reply

    def get_state_check(self, fl_ctx: FLContext) -> dict:
        """Return the server state's check data for a get-task request.

        Args:
            fl_ctx: FLContext

        Returns:
            result of the current server state's ``get_task`` call
        """
        current_state = fl_ctx.get_engine().server.server_state
        return current_state.get_task(fl_ctx)
class SubmitUpdateCommand(CommandProcessor, ServerStateCheck):
    """Command processor for the server SubmitUpdate command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.SUBMIT_UPDATE
        """
        return ServerCommandNames.SUBMIT_UPDATE

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Process a client's SUBMIT_UPDATE request.

        Args:
            data: process data; carries the client's peer FL context, the
                submitting client, the task name and the TASK_ID cookie
            fl_ctx: FLContext

        Returns:
            an empty string
        """
        began = time.time()

        # Take the peer context out of the request, leave an empty FLContext
        # in its place, and attach the submission to the peer context.
        peer_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
        data.set_header(ServerCommandKey.PEER_FL_CONTEXT, FLContext())
        peer_ctx.set_prop(FLContextKey.SHAREABLE, data, private=True)

        client = data.get_header(ServerCommandKey.FL_CLIENT)
        fl_ctx.set_peer_context(peer_ctx)

        submitted_task = data.get_header(FLContextKey.TASK_NAME)
        task_id = data.get_cookie(FLContextKey.TASK_ID)

        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        runner.process_submission(client, submitted_task, task_id, data, fl_ctx)

        self.logger.info(f"submit_update process. client_name:{client.name} task_id:{task_id}")
        self.logger.debug(f"Submit_result processing time: {time.time()-began} for client: {client.name}")
        return ""

    def get_state_check(self, fl_ctx: FLContext) -> dict:
        """Return the server state's check data for a submit-result request.

        Args:
            fl_ctx: FLContext

        Returns:
            result of the current server state's ``submit_result`` call
        """
        current_state = fl_ctx.get_engine().server.server_state
        return current_state.submit_result(fl_ctx)
class HandleDeadJobCommand(CommandProcessor):
    """Command processor for the server HandleDeadJob command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.HANDLE_DEAD_JOB
        """
        return ServerCommandNames.HANDLE_DEAD_JOB

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Forward a dead job notification to the server runner, if any.

        Args:
            data: process data; carries the client name and the reason in its headers
            fl_ctx: FLContext

        Returns:
            an empty string
        """
        client_name = data.get_header(ServerCommandKey.FL_CLIENT)
        # Keep this local named "reason": the f-string below prints the name.
        reason = data.get_header(ServerCommandKey.REASON)
        self.logger.warning(f"received dead job notification: {reason=}")

        runner = fl_ctx.get_prop(FLContextKey.RUNNER)
        if not runner:
            return ""
        runner.handle_dead_job(client_name, fl_ctx)
        return ""
class ShowStatsCommand(CommandProcessor):
    """Command processor for the show_stats command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.SHOW_STATS
        """
        return ServerCommandNames.SHOW_STATS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Return the run stats gathered by the engine's info collector widget.

        Args:
            data: process data (not used)
            fl_ctx: FLContext

        Returns:
            the info collector's run stats
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        return info_collector.get_run_stats()
class GetErrorsCommand(CommandProcessor):
    """Command processor for the GET_ERRORS command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.GET_ERRORS
        """
        return ServerCommandNames.GET_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Return the errors gathered by the engine's info collector widget.

        Args:
            data: process data (not used)
            fl_ctx: FLContext

        Returns:
            the collected errors, or the string "No Error" when there are none
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        return info_collector.get_errors() or "No Error"
class ResetErrorsCommand(CommandProcessor):
    """Command processor for the RESET_ERRORS command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.RESET_ERRORS
        """
        return ServerCommandNames.RESET_ERRORS

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Reset the errors held by the engine's info collector widget.

        Args:
            data: process data (not used)
            fl_ctx: FLContext

        Returns:
            None
        """
        info_collector = fl_ctx.get_engine().get_widget(WidgetID.INFO_COLLECTOR)
        info_collector.reset_errors()
        return None
class ByeCommand(CommandProcessor):
    """Command processor for the shutdown command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            AdminCommandNames.SHUTDOWN
        """
        return AdminCommandNames.SHUTDOWN

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Handle the shutdown command; this processor does no work itself.

        Args:
            data: process data (not used)
            fl_ctx: FLContext (not used)

        Returns:
            None
        """
        return None
class HeartbeatCommand(CommandProcessor):
    """Command processor for the HEARTBEAT command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.HEARTBEAT
        """
        return ServerCommandNames.HEARTBEAT

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Handle the HEARTBEAT command; this processor does no work itself.

        Args:
            data: process data (not used)
            fl_ctx: FLContext (not used)

        Returns:
            None
        """
        return None
class ServerStateCommand(CommandProcessor):
    """Command processor for the SERVER_STATE command."""

    def get_command_name(self) -> str:
        """Return the command name.

        Returns:
            ServerCommandNames.SERVER_STATE
        """
        return ServerCommandNames.SERVER_STATE

    def process(self, data: Shareable, fl_ctx: FLContext):
        """Store the received data as the server's current state.

        Args:
            data: ServerState object
            fl_ctx: FLContext

        Returns:
            the string "Success"
        """
        server = fl_ctx.get_engine().server
        server.server_state = data
        return "Success"
class ServerCommands(object):
    """Registry of the processors that handle commands from the parent process."""

    # Processors available by default; register_command appends to this
    # shared class-level list.
    commands: List[CommandProcessor] = [
        AbortCommand(),
        ByeCommand(),
        GetRunInfoCommand(),
        GetTaskCommand(),
        SubmitUpdateCommand(),
        HandleDeadJobCommand(),
        ShowStatsCommand(),
        GetErrorsCommand(),
        ResetErrorsCommand(),
        HeartbeatCommand(),
        ServerStateCommand(),
    ]

    # Names of the commands listed as client requests.
    client_request_commands_names = [
        ServerCommandNames.GET_TASK,
        ServerCommandNames.SUBMIT_UPDATE,
        # ServerCommandNames.AUX_COMMUNICATE,
    ]

    @staticmethod
    def get_command(command_name):
        """Look up the processor registered for a command name.

        Args:
            command_name: name of the command

        Returns:
            the first registered processor whose command name equals
            ``command_name``, or None when there is no match
        """
        matches = (
            processor for processor in ServerCommands.commands if command_name == processor.get_command_name()
        )
        return next(matches, None)

    @staticmethod
    def register_command(command_processor: CommandProcessor):
        """Append a command processor to the registry.

        Args:
            command_processor: the processor to register

        Raises:
            TypeError: if ``command_processor`` is not a CommandProcessor instance
        """
        if isinstance(command_processor, CommandProcessor):
            ServerCommands.commands.append(command_processor)
            return

        raise TypeError(
            "command_processor must be an instance of CommandProcessor, but got {}".format(type(command_processor))
        )