Source code for nvflare.private.fed.server.server_state

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from abc import ABC, abstractmethod

from nvflare.apis.fl_context import FLContext
from nvflare.apis.overseer_spec import SP

# Keys of the response dictionaries defined on ServerState
# (NOT_IN_SERVICE, ABORT_CURRENT_RUN, IN_SERVICE).
ACTION = "_action"
MESSAGE = "_message"

# Values stored under the ACTION key of those response dictionaries.
NIS = "Not In Service"
ABORT_RUN = "Abort Run"
SERVICE = "In Service"


class ServiceSession:
    """Plain record of a host, a service port and a service session id.

    Attributes are stored exactly as passed in; all of them default to
    the empty string.
    """

    def __init__(self, host: str = "", port: str = "", ssid: str = "") -> None:
        """Store the given values on the instance.

        Args:
            host: stored as ``self.host``.
            port: stored as ``self.service_port``.
            ssid: stored as ``self.ssid``.
        """
        self.host, self.service_port, self.ssid = host, port, ssid
class ServerState(ABC):
    """Abstract base of the server states defined in this module.

    Each request-handling method returns one of the three response
    dictionaries declared below; ``handle_sd_callback`` returns the state
    object that should be used after the callback (possibly ``self``).
    """

    # Canned responses shared by every state. Each maps ACTION to an action
    # value and MESSAGE to a human-readable description.
    NOT_IN_SERVICE = {
        ACTION: NIS,
        MESSAGE: "Server not in service",
    }
    ABORT_CURRENT_RUN = {
        ACTION: ABORT_RUN,
        MESSAGE: "Abort current run",
    }
    IN_SERVICE = {
        ACTION: SERVICE,
        MESSAGE: "Server in service",
    }

    # One logger shared by all state classes.
    logger = logging.getLogger("ServerState")

    def __init__(self, host: str = "", port: str = "", ssid: str = "") -> None:
        """Record this server's identity; a new state is never primary.

        Args:
            host: stored as ``self.host``.
            port: stored as ``self.service_port``.
            ssid: stored as ``self.ssid``.
        """
        self.host, self.service_port, self.ssid = host, port, ssid
        self.primary = False

    @abstractmethod
    def register(self, fl_ctx: FLContext) -> dict:
        """Return the response dictionary for a register request."""

    @abstractmethod
    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Return the response dictionary for a heartbeat request."""

    @abstractmethod
    def get_task(self, fl_ctx: FLContext) -> dict:
        """Return the response dictionary for a get-task request."""

    @abstractmethod
    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Return the response dictionary for a submit-result request."""

    @abstractmethod
    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Return the response dictionary for an aux-communicate request."""

    @abstractmethod
    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Return the state to use after a callback carrying ``sp``."""
class ColdState(ServerState):
    """State in which every request is answered with NOT_IN_SERVICE.

    The state is left (for ``Cold2HotState``) only when a callback reports
    a primary ``sp`` whose name and fl_port match this server's host and
    service port.
    """

    def register(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def get_task(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Decide the next state from the reported ``sp``.

        Returns:
            A new ``Cold2HotState`` when ``sp`` is primary and matches this
            server's host and service port; otherwise ``self``.
        """
        # Nothing to do unless a primary sp was reported.
        if not (sp and sp.primary is True):
            return self

        points_at_this_server = sp.name == self.host and sp.fl_port == self.service_port
        if not points_at_this_server:
            # Some other server is the primary one.
            self.primary = False
            return self

        self.primary = True
        self.ssid = sp.service_session_id
        self.logger.info(
            f"Got the primary sp: {sp.name} fl_port: {sp.fl_port} SSID: {sp.service_session_id}. "
            f"Turning to hot."
        )
        return Cold2HotState(host=self.host, port=self.service_port, ssid=sp.service_session_id)
class Cold2HotState(ServerState):
    """State entered from ``ColdState`` once this server is the primary.

    Register requests are accepted, heartbeats are answered with
    NOT_IN_SERVICE, and task / result / aux requests are answered with
    ABORT_CURRENT_RUN. Callbacks do not change the state.
    """

    def register(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def get_task(self, fl_ctx: FLContext) -> dict:
        """Answer ABORT_CURRENT_RUN."""
        return ServerState.ABORT_CURRENT_RUN

    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Answer ABORT_CURRENT_RUN."""
        return ServerState.ABORT_CURRENT_RUN

    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Answer ABORT_CURRENT_RUN."""
        return ServerState.ABORT_CURRENT_RUN

    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Ignore the callback and stay in this state."""
        return self
class HotState(ServerState):
    """State in which every request is answered with IN_SERVICE.

    A callback moves the server to ``Hot2ColdState`` when the reported
    primary ``sp`` is a different server, or is this server but with a
    service session id different from the one currently stored.
    """

    def register(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def get_task(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Answer IN_SERVICE."""
        return ServerState.IN_SERVICE

    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Decide the next state from the reported ``sp``.

        Returns:
            ``self`` when no primary ``sp`` is reported, or when it is this
            server with an unchanged session id; otherwise a new
            ``Hot2ColdState``.
        """
        # Nothing to do unless a primary sp was reported.
        if not (sp and sp.primary is True):
            return self

        if sp.name == self.host and sp.fl_port == self.service_port:
            # This server is still the primary one.
            self.primary = True
            session_changed = sp.service_session_id != self.ssid
            if not session_changed:
                return self

            self.ssid = sp.service_session_id
            self.logger.info(
                f"Primary sp changed to: {sp.name} fl_port: {sp.fl_port} SSID: {sp.service_session_id}. "
                f"Turning to Cold"
            )
            return Hot2ColdState(host=self.host, port=self.service_port, ssid=sp.service_session_id)

        # Another server became the primary one.
        self.primary = False
        self.logger.info(
            f"Primary sp changed to: {sp.name} fl_port: {sp.fl_port} SSID: {sp.service_session_id}. "
            f"Turning to Cold"
        )
        return Hot2ColdState(host=self.host, port=self.service_port)
class Hot2ColdState(ServerState):
    """State entered from ``HotState`` when the server stops being hot.

    Every request is answered with NOT_IN_SERVICE and callbacks do not
    change the state.
    """

    def register(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def get_task(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Ignore the callback and stay in this state."""
        return self
class ShutdownState(ServerState):
    """State in which every request is answered with NOT_IN_SERVICE.

    Callbacks never move the server out of this state.
    """

    def register(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def heartbeat(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def get_task(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def submit_result(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def aux_communicate(self, fl_ctx: FLContext) -> dict:
        """Answer NOT_IN_SERVICE."""
        return ServerState.NOT_IN_SERVICE

    def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState:
        """Ignore the callback and stay in this state."""
        return self