# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import threading
import time
from abc import ABC, abstractmethod
from threading import Lock
from typing import Dict, List, Optional
from nvflare.apis.client import Client
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import (
FLContextKey,
MachineStatus,
RunProcessKey,
SecureTrainConst,
ServerCommandKey,
ServerCommandNames,
SnapshotKey,
WorkspaceConstants,
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import NotAuthenticated
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.core_cell import Message
from nvflare.fuel.f3.cellnet.core_cell import make_reply as make_cellnet_reply
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.defs import ReturnCode as F3ReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes
from nvflare.ha.overseer_agent import HttpOverseerAgent
from nvflare.private.defs import (
CellChannel,
CellChannelTopic,
CellMessageHeaderKeys,
JobFailureMsgKey,
new_cell_message,
)
from nvflare.private.fed.server.server_command_agent import ServerCommandAgent
from nvflare.private.fed.server.server_runner import ServerRunner
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.fed_event import ServerFedEventRunner
from .client_manager import ClientManager
from .run_manager import RunManager
from .server_engine import ServerEngine
from .server_state import (
ABORT_RUN,
ACTION,
MESSAGE,
NIS,
Cold2HotState,
ColdState,
Hot2ColdState,
HotState,
ServerState,
)
from .server_status import ServerStatus
class BaseServer(ABC):
    def __init__(
        self,
        project_name=None,
        min_num_clients=2,
        max_num_clients=10,
        heart_beat_timeout=600,
        handlers: Optional[List[FLComponent]] = None,
        shutdown_period=30.0,
    ):
        """Base server that provides the clients management and server deployment.

        Args:
            project_name: name of the FL project served by this server.
            min_num_clients: minimum number of clients (clamped to at least 1).
            max_num_clients: maximum number of clients (clamped to at least 1).
            heart_beat_timeout: seconds without a heartbeat before a client is
                considered dead (see remove_dead_clients).
            handlers: optional list of FLComponent handlers.
            shutdown_period: max seconds fl_shutdown waits for clients to quit.
        """
        self.project_name = project_name
        # Clamp so zero/negative configured values cannot disable the client checks.
        self.min_num_clients = max(min_num_clients, 1)
        self.max_num_clients = max(max_num_clients, 1)
        self.heart_beat_timeout = heart_beat_timeout
        self.handlers = handlers
        self.client_manager = ClientManager(
            project_name=self.project_name, min_num_clients=self.min_num_clients, max_num_clients=self.max_num_clients
        )
        self.cell = None  # communication cell; created in deploy()
        self.admin_server = None  # attached later via set_admin_server()
        self.lock = Lock()
        self.snapshot_lock = Lock()  # guards snapshot restore (used by FederatedServer._turn_to_hot)
        self.fl_ctx = FLContext()
        self.platform = None
        self.shutdown_period = shutdown_period
        self.shutdown = False  # set True to stop the cleanup thread / fl_shutdown
        self.status = ServerStatus.NOT_STARTED
        self.abort_signal = None
        self.executor = None
        self.logger = logging.getLogger(self.__class__.__name__)
[docs] def get_all_clients(self) -> Dict[str, Client]:
"""Get the list of registered clients.
Returns:
A dict of {client_token: client}
"""
return self.client_manager.get_clients()
    def get_cell(self):
        """Return the server's communication cell (None until deploy() has run)."""
        return self.cell
    @abstractmethod
    def remove_client_data(self, token):
        """Remove server-side data held for the client with this token.

        Called whenever a client is logged out; subclasses must implement.
        """
        pass
[docs] def close(self):
"""Shutdown the server."""
try:
if self.lock:
self.lock.release()
except RuntimeError:
self.logger.info("canceling sync locks")
self.logger.info("server off")
return 0
    def deploy(self, args, grpc_args=None, secure_train=False):
        """Start a grpc server and listening the designated port.

        Args:
            args: command-line args (not used here; used by subclass deploy).
            grpc_args: service config; "service.target" is host:port and
                "service.scheme" the transport scheme (defaults below).
            secure_train: if True, read TLS credentials from grpc_args.
        """
        target = grpc_args["service"].get("target", "0.0.0.0:6007")
        scheme = grpc_args["service"].get("scheme", "grpc")
        if secure_train:
            root_cert = grpc_args[SecureTrainConst.SSL_ROOT_CERT]
            ssl_cert = grpc_args[SecureTrainConst.SSL_CERT]
            private_key = grpc_args[SecureTrainConst.PRIVATE_KEY]
            credentials = {
                DriverParams.CA_CERT.value: root_cert,
                DriverParams.SERVER_CERT.value: ssl_cert,
                DriverParams.SERVER_KEY.value: private_key,
            }
        else:
            credentials = {}
        parent_url = None
        parts = target.split(":")
        if len(parts) > 1:
            # "0" means all interfaces for all protocols (ipv4 and ipv6)
            listen_target = "0:" + parts[1]
        else:
            listen_target = target
        my_fqcn = FQCN.ROOT_SERVER
        self.cell = Cell(
            fqcn=my_fqcn,
            root_url=scheme + "://" + listen_target,
            secure=secure_train,
            credentials=credentials,
            create_internal_listener=True,
            parent_url=parent_url,
        )
        self.cell.start()
        # Stop the cell when the main process monitor runs its cleanup callbacks.
        mpm.add_cleanup_cb(self.cell.stop)
        # Background thread that periodically evicts clients with expired heartbeats.
        cleanup_thread = threading.Thread(target=self.client_cleanup)
        cleanup_thread.start()
[docs] def client_cleanup(self):
last_remove_time = 0.0
remove_interval = 5.0
check_interval = 0.2
while not self.shutdown:
now = time.time()
if now - last_remove_time > remove_interval:
self.remove_dead_clients()
last_remove_time = now
# We want to sleep very little to check the self.shutdown quickly
# so that when self.shutdown is set we can return quickly.
# Without this, when the server parent cell ends, this thread will not end until 15 secs later.
# This will cause MPM's cleanup to fail!
time.sleep(check_interval)
    def set_admin_server(self, admin_server):
        """Attach the admin server used for client heartbeat/dead notifications."""
        self.admin_server = admin_server
[docs] def remove_dead_clients(self):
# Clean and remove the dead client without heartbeat.
self.logger.debug("trying to remove dead clients .......")
delete = []
for token, client in self.client_manager.get_clients().items():
if client.last_connect_time < time.time() - self.heart_beat_timeout:
delete.append(token)
for token in delete:
client = self.logout_client(token)
self.logger.info(
"Remove the dead Client. Name: {}\t Token: {}. Total clients: {}".format(
client.name, token, len(self.client_manager.get_clients())
)
)
[docs] def logout_client(self, token):
client = self.client_manager.remove_client(token)
self.remove_client_data(token)
if self.admin_server:
self.admin_server.client_dead(token)
self.notify_dead_client(client)
return client
    def notify_dead_client(self, client):
        """Called to do further processing of the dead client.

        Default is a no-op hook; FederatedServer overrides it to notify
        running jobs.

        Args:
            client: the dead client
        """
        pass
[docs] def fl_shutdown(self):
self.shutdown = True
start = time.time()
while self.client_manager.clients:
# Wait for the clients to shutdown and quite first.
time.sleep(0.1)
if time.time() - start > self.shutdown_period:
self.logger.info("There are still clients connected. But shutdown the server after timeout.")
break
self.close()
if self.executor:
self.executor.shutdown()
class FederatedServer(BaseServer):
    def __init__(
        self,
        project_name=None,
        min_num_clients=2,
        max_num_clients=10,
        cmd_modules=None,
        heart_beat_timeout=600,
        handlers: Optional[List[FLComponent]] = None,
        args=None,
        secure_train=False,
        snapshot_persistor=None,
        overseer_agent=None,
        shutdown_period=30.0,
        check_engine_frequency=3.0,
    ):
        """Federated server services.

        Args:
            project_name: server project name.
            min_num_clients: minimum number of contributors at each round.
            max_num_clients: maximum number of contributors at each round.
            cmd_modules: command modules.
            heart_beat_timeout: heartbeat timeout
            handlers: A list of handler
            args: arguments; args.workspace is used as the workspace root
                (assumes args is not None — TODO confirm all callers pass it).
            secure_train: whether to use secure communication
            snapshot_persistor: stores/restores run snapshots for HA restarts.
            overseer_agent: overseer agent; an HttpOverseerAgent enables HA mode.
            shutdown_period: max seconds to wait for clients during shutdown.
            check_engine_frequency: seconds between engine-status polls in start_run.
        """
        BaseServer.__init__(
            self,
            project_name=project_name,
            min_num_clients=min_num_clients,
            max_num_clients=max_num_clients,
            heart_beat_timeout=heart_beat_timeout,
            handlers=handlers,
            shutdown_period=shutdown_period,
        )
        self.contributed_clients = {}
        self.tokens = None  # {client_token: task meta info}; filled by reset_tokens
        self.round_started = time.time()
        with self.lock:
            self.reset_tokens()
        self.cmd_modules = cmd_modules
        self.builder = None
        self.engine = self._create_server_engine(args, snapshot_persistor)
        self.run_manager = None
        self.server_runner = None
        self.command_agent = None
        self.check_engine_frequency = check_engine_frequency
        self.processors = {}
        self.runner_config = None
        self.secure_train = secure_train
        self.workspace = args.workspace
        self.snapshot_location = None
        self.overseer_agent = overseer_agent
        # Servers come up cold; the overseer callback drives hot/cold transitions.
        self.server_state: ServerState = ColdState()
        self.snapshot_persistor = snapshot_persistor
        self.checking_server_state = False  # re-entrancy guard for overseer_callback
        self.ha_mode = False  # set True in deploy() when an HttpOverseerAgent is used
def _register_cellnet_cbs(self):
self.cell.register_request_cb(
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.Register,
cb=self.register_client,
)
self.cell.register_request_cb(
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.Quit,
cb=self.quit_client,
)
self.cell.register_request_cb(
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.HEART_BEAT,
cb=self.client_heartbeat,
)
self.cell.register_request_cb(
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.REPORT_JOB_FAILURE,
cb=self.process_job_failure,
)
self.cell.register_request_cb(
channel=CellChannel.SERVER_PARENT_LISTENER,
topic="*",
cb=self._listen_command,
)
    def _listen_command(self, request: Message) -> Message:
        """Handle commands sent by job (child) cells on SERVER_PARENT_LISTENER."""
        job_id = request.get_header(CellMessageHeaderKeys.JOB_ID)
        command = request.get_header(MessageHeaderKey.TOPIC)
        data = request.payload
        if command == ServerCommandNames.GET_CLIENTS:
            # Report the participants of the job (None when the job is unknown).
            if job_id in self.engine.run_processes:
                clients = self.engine.run_processes[job_id].get(RunProcessKey.PARTICIPANTS)
                return_data = {ServerCommandKey.CLIENTS: clients, ServerCommandKey.JOB_ID: job_id}
            else:
                return_data = {ServerCommandKey.CLIENTS: None, ServerCommandKey.JOB_ID: job_id}
            return make_cellnet_reply(F3ReturnCode.OK, "", return_data)
        elif command == ServerCommandNames.UPDATE_RUN_STATUS:
            execution_error = data.get("execution_error")
            # run_processes is shared with other threads; mutate under the lock.
            with self.lock:
                run_process_info = self.engine.run_processes.get(job_id)
                if run_process_info is not None:
                    if execution_error:
                        # Remember the failure so the job can be reported as abnormal.
                        run_process_info[RunProcessKey.PROCESS_EXE_ERROR] = True
                        self.engine.exception_run_processes[job_id] = run_process_info
                    run_process_info[RunProcessKey.PROCESS_FINISHED] = True
            reply = make_cellnet_reply(F3ReturnCode.OK, "", None)
            return reply
        elif command == ServerCommandNames.HEARTBEAT:
            return make_cellnet_reply(F3ReturnCode.OK, "", None)
        else:
            return make_cellnet_reply(F3ReturnCode.INVALID_REQUEST, "", None)
    def _create_server_engine(self, args, snapshot_persistor):
        # Build the engine that manages this server's jobs and run processes.
        return ServerEngine(
            server=self, args=args, client_manager=self.client_manager, snapshot_persistor=snapshot_persistor
        )
[docs] def create_job_cell(self, job_id, root_url, parent_url, secure_train, server_config) -> Cell:
my_fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
if secure_train:
root_cert = server_config[SecureTrainConst.SSL_ROOT_CERT]
ssl_cert = server_config[SecureTrainConst.SSL_CERT]
private_key = server_config[SecureTrainConst.PRIVATE_KEY]
credentials = {
DriverParams.CA_CERT.value: root_cert,
DriverParams.SERVER_CERT.value: ssl_cert,
DriverParams.SERVER_KEY.value: private_key,
}
else:
credentials = {}
cell = Cell(
fqcn=my_fqcn,
root_url=root_url,
secure=secure_train,
credentials=credentials,
create_internal_listener=False,
parent_url=parent_url,
)
cell.start()
net_agent = NetAgent(cell)
mpm.add_cleanup_cb(net_agent.close)
mpm.add_cleanup_cb(cell.stop)
self.command_agent = ServerCommandAgent(self.engine, cell)
self.command_agent.start()
return cell
    def remove_client_data(self, token):
        """Drop the task token for this client; unknown tokens are ignored."""
        self.tokens.pop(token, None)
[docs] def reset_tokens(self):
"""Reset the token set.
After resetting, each client can take a token
and start fetching the current global model.
This function is not thread-safe.
"""
self.tokens = dict()
for token, client in self.get_all_clients().items():
self.tokens[token] = self.task_meta_info(client.name)
def _before_service(self, fl_ctx: FLContext):
# before the service processing
fl_ctx.remove_prop(FLContextKey.COMMUNICATION_ERROR)
fl_ctx.remove_prop(FLContextKey.UNAUTHENTICATED)
def _generate_reply(self, headers, payload, fl_ctx: FLContext):
# process after the service processing
unauthenticated = fl_ctx.get_prop(FLContextKey.UNAUTHENTICATED)
if unauthenticated:
return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=unauthenticated)
error = fl_ctx.get_prop(FLContextKey.COMMUNICATION_ERROR)
if error:
return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error)
else:
return_message = new_cell_message(headers, payload)
return_message.set_header(MessageHeaderKey.RETURN_CODE, F3ReturnCode.OK)
return return_message
    def register_client(self, request: Message) -> Message:
        """Register new clients on the fly.

        Each client must get registered before getting the global model.
        The server will expect updates from the registered clients
        for multiple federated rounds.

        This function does not change min_num_clients and max_num_clients.
        """
        with self.engine.new_context() as fl_ctx:
            try:
                self._before_service(fl_ctx)
                # Refuse registration while the server is not in service or aborting.
                state_check = self.server_state.register(fl_ctx)
                error = self._handle_state_check(state_check, fl_ctx)
                if error is not None:
                    return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error)
                data = request.payload
                shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
                fl_ctx.set_peer_context(shared_fl_ctx)
                # Let event handlers veto the registration: any NotAuthenticated
                # they record in FLContextKey.EXCEPTIONS rejects the client.
                self.engine.fire_event(EventType.CLIENT_REGISTER_RECEIVED, fl_ctx=fl_ctx)
                exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS)
                if exceptions:
                    for _, exception in exceptions.items():
                        if isinstance(exception, NotAuthenticated):
                            raise exception
                client = self.client_manager.authenticate(request, fl_ctx)
                if client and client.token:
                    # Issue a task token and report the new client to the admin server.
                    self.tokens[client.token] = self.task_meta_info(client.name)
                    if self.admin_server:
                        self.admin_server.client_heartbeat(client.token, client.name)
                    headers = {
                        CellMessageHeaderKeys.TOKEN: client.token,
                        CellMessageHeaderKeys.SSID: self.server_state.ssid,
                    }
                else:
                    headers = {}
                self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx)
                return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx)
            except NotAuthenticated as e:
                self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}")
                return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error="register_client unauthenticated")
def _handle_state_check(self, state_check, fl_ctx: FLContext):
if state_check.get(ACTION) in [NIS, ABORT_RUN]:
fl_ctx.set_prop(FLContextKey.COMMUNICATION_ERROR, state_check.get(MESSAGE), sticky=False)
return state_check.get(MESSAGE)
return None
[docs] def quit_client(self, request: Message) -> Message:
"""Existing client quits the federated training process.
Server will stop sharing the global model with the client,
further contribution will be rejected.
This function does not change min_num_clients and max_num_clients.
"""
with self.engine.new_context() as fl_ctx:
client = self.client_manager.validate_client(request, fl_ctx)
if client:
token = client.get_token()
self.logout_client(token)
data = request.payload
shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
fl_ctx.set_peer_context(shared_fl_ctx)
self.engine.fire_event(EventType.CLIENT_QUIT, fl_ctx=fl_ctx)
headers = {CellMessageHeaderKeys.MESSAGE: "Removed client"}
return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx)
[docs] def process_job_failure(self, request: Message):
payload = request.payload
client = request.get_header(key=MessageHeaderKey.ORIGIN)
if not isinstance(payload, dict):
self.logger.error(
f"dropped bad Job Failure report from {client}: expect payload to be dict but got {type(payload)}"
)
return
job_id = payload.get(JobFailureMsgKey.JOB_ID)
if not job_id:
self.logger.error(f"dropped bad Job Failure report from {client}: no job_id")
return
code = payload.get(JobFailureMsgKey.CODE)
reason = payload.get(JobFailureMsgKey.REASON, "?")
if code == ProcessExitCode.UNSAFE_COMPONENT:
with self.engine.new_context() as fl_ctx:
self.logger.info(f"Aborting job {job_id} due to reported failure from {client}: {reason}")
self.engine.job_runner.stop_run(job_id, fl_ctx)
    def client_heartbeat(self, request: Message) -> Message:
        """Refresh a client's liveness and reconcile its running jobs with the server's."""
        with self.engine.new_context() as fl_ctx:
            self._before_service(fl_ctx)
            # Heartbeats are rejected while the server is not in service.
            state_check = self.server_state.heartbeat(fl_ctx)
            error = self._handle_state_check(state_check, fl_ctx)
            if error is not None:
                return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error)
            data = request.payload
            shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT)
            fl_ctx.set_peer_context(shared_fl_ctx)
            self.engine.fire_event(EventType.CLIENT_HEARTBEAT_RECEIVED, fl_ctx=fl_ctx)
            token = request.get_header(CellMessageHeaderKeys.TOKEN)
            client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME)
            if self.client_manager.heartbeat(token, client_name, fl_ctx):
                # Heartbeat accepted: refresh the task token for this client.
                self.tokens[token] = self.task_meta_info(client_name)
            if self.admin_server:
                self.admin_server.client_heartbeat(token, client_name)
            # Jobs running on the client but unknown to the server must be aborted.
            abort_runs = self._sync_client_jobs(request, token)
            reply = self._generate_reply(
                headers={CellMessageHeaderKeys.MESSAGE: "Heartbeat response"}, payload=None, fl_ctx=fl_ctx
            )
            if abort_runs:
                reply.set_header(CellMessageHeaderKeys.ABORT_JOBS, abort_runs)
                display_runs = ",".join(abort_runs)
                self.logger.debug(
                    f"These jobs: {display_runs} are not running on the server. "
                    f"Ask client: {client_name} to abort these runs."
                )
            self.engine.fire_event(EventType.CLIENT_HEARTBEAT_PROCESSED, fl_ctx=fl_ctx)
            return reply
def _sync_client_jobs(self, request, client_token):
# jobs that are running on client but not on server need to be aborted!
client_jobs = request.get_header(CellMessageHeaderKeys.JOB_IDS)
server_jobs = self.engine.run_processes.keys()
jobs_need_abort = list(set(client_jobs).difference(server_jobs))
# also check jobs that are running on server but not on the client
jobs_on_server_but_not_on_client = list(set(server_jobs).difference(client_jobs))
if jobs_on_server_but_not_on_client:
# notify all the participating clients these jobs are not running on server anymore
for job_id in jobs_on_server_but_not_on_client:
job_info = self.engine.run_processes[job_id]
participating_clients = job_info.get(RunProcessKey.PARTICIPANTS, None)
if participating_clients:
# this is a dict: token => nvflare.apis.client.Client
client = participating_clients.get(client_token, None)
if client:
self._notify_dead_job(client, job_id, "missing job on client")
return jobs_need_abort
    def _notify_dead_job(self, client, job_id: str, reason: str):
        # Best effort: the job's runner process may already be gone; log and move on.
        try:
            self.engine.notify_dead_job(job_id, client.name, reason)
        except Exception as ex:
            self.logger.info(
                f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}"
            )
[docs] def notify_dead_client(self, client):
"""Called to do further processing of the dead client
Args:
client: the dead client
Returns:
"""
# find all RUNs that this client is participating
if not self.engine.run_processes:
return
for job_id, process_info in self.engine.run_processes.items():
assert isinstance(process_info, dict)
participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None)
if participating_clients and client.token in participating_clients:
self._notify_dead_job(client, job_id, "client dead")
    def start_run(self, job_id, run_root, conf, args, snapshot):
        """Run a job to completion on this process.

        Sets up the run manager and server runner, starts the runner on its own
        thread, then polls until the engine reports STOPPED (or is asked to stop).

        Args:
            job_id: id of the job to run.
            run_root: the job's app root directory.
            conf: configurator with the parsed job configuration.
            args: command-line args; args.workspace and args.config_folder are used.
            snapshot: optional snapshot to restore components from (HA restart).
        """
        # Create the FL Engine
        workspace = Workspace(args.workspace, "server", args.config_folder)
        self.run_manager = self.create_run_manager(workspace, job_id)
        self.engine.set_run_manager(self.run_manager)
        self.engine.set_configurator(conf)
        self.engine.asked_to_stop = False
        self.run_manager.cell = self.cell
        fed_event_runner = ServerFedEventRunner()
        self.run_manager.add_handler(fed_event_runner)
        try:
            self.server_runner = ServerRunner(config=self.runner_config, job_id=job_id, engine=self.engine)
            self.run_manager.add_handler(self.server_runner)
            self.run_manager.add_component("_Server_Runner", self.server_runner)
            with self.engine.new_context() as fl_ctx:
                if snapshot:
                    # NOTE(review): a fresh FLContext() is passed here instead of the
                    # fl_ctx from the enclosing context manager — confirm intended.
                    self.engine.restore_components(snapshot=snapshot, fl_ctx=FLContext())
                fl_ctx.set_prop(FLContextKey.APP_ROOT, run_root, sticky=True)
                fl_ctx.set_prop(FLContextKey.CURRENT_RUN, job_id, private=False, sticky=True)
                fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True, sticky=True)
                fl_ctx.set_prop(FLContextKey.ARGS, args, private=True, sticky=True)
                fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True)
                fl_ctx.set_prop(FLContextKey.SECURE_MODE, self.secure_train, private=True, sticky=True)
                fl_ctx.set_prop(FLContextKey.RUNNER, self.server_runner, private=True, sticky=True)
            engine_thread = threading.Thread(target=self.run_engine)
            engine_thread.start()
            self.engine.engine_info.status = MachineStatus.STARTED
            # Poll until the runner thread finishes or the engine is asked to stop.
            while self.engine.engine_info.status != MachineStatus.STOPPED:
                if self.engine.asked_to_stop:
                    self.engine.engine_info.status = MachineStatus.STOPPED
                time.sleep(self.check_engine_frequency)
        finally:
            # Always mark the engine stopped and drop the run manager, even on error.
            self.engine.engine_info.status = MachineStatus.STOPPED
            self.run_manager = None
    def create_run_manager(self, workspace, job_id):
        """Build the RunManager for a job run from the loaded runner_config."""
        return RunManager(
            server_name=self.project_name,
            engine=self.engine,
            job_id=job_id,
            workspace=workspace,
            components=self.runner_config.components,
            client_manager=self.client_manager,
            handlers=self.runner_config.handlers,
        )
[docs] def authentication_check(self, request: Message, state_check):
error = None
# server_state = self.engine.server.server_state
if state_check.get(ACTION) in [NIS, ABORT_RUN]:
# return make_reply(ReturnCode.AUTHENTICATION_ERROR, state_check.get(MESSAGE), fobs.dumps(None))
error = state_check.get(MESSAGE)
client_ssid = request.get_header(CellMessageHeaderKeys.SSID, None)
if client_ssid != self.server_state.ssid:
# return make_reply(ReturnCode.AUTHENTICATION_ERROR, "Request from invalid client SSID",
# fobs.dumps(None))
error = "Request from unknown client SSID"
return error
[docs] def abort_run(self):
with self.engine.new_context() as fl_ctx:
if self.server_runner:
self.server_runner.abort(fl_ctx)
    def run_engine(self):
        """Drive the server runner to completion on the engine thread."""
        self.engine.engine_info.status = MachineStatus.STARTED
        try:
            self.server_runner.run()
        except Exception as e:
            self.logger.error(f"FL server execution exception: {secure_format_exception(e)}")
        finally:
            # Always stop the run-engine cell and mark the engine stopped,
            # even when the runner raised.
            self.stop_run_engine_cell()
            self.engine.engine_info.status = MachineStatus.STOPPED
    def stop_run_engine_cell(self):
        # Intentionally left as a no-op: the job cell is stopped via the MPM
        # cleanup callbacks registered in create_job_cell, not here.
        pass
    def deploy(self, args, grpc_args=None, secure_train=False):
        """Deploy the FL server: start the cell (base class), then hook up the overseer.

        Publishes host/port to the server state, initializes the overseer agent
        (an HttpOverseerAgent enables HA mode), registers cellnet callbacks, and
        starts overseer polling with overseer_callback.
        """
        super().deploy(args, grpc_args, secure_train)
        target = grpc_args["service"].get("target", "0.0.0.0:6007")
        with self.lock:
            self.server_state.host = target.split(":")[0]
            self.server_state.service_port = target.split(":")[1]
        self.overseer_agent = self._init_agent(args)
        if isinstance(self.overseer_agent, HttpOverseerAgent):
            # A real overseer means high-availability mode with hot/cold states.
            self.ha_mode = True
        if secure_train:
            if self.overseer_agent:
                self.overseer_agent.set_secure_context(
                    ca_path=grpc_args["ssl_root_cert"],
                    cert_path=grpc_args["ssl_cert"],
                    prv_key_path=grpc_args["ssl_private_key"],
                )
        self.engine.cell = self.cell
        self._register_cellnet_cbs()
        # Begin polling the overseer; state transitions happen in the callback.
        self.overseer_agent.start(self.overseer_callback)
def _init_agent(self, args=None):
kv_list = parse_vars(args.set)
sp = kv_list.get("sp")
if sp:
with self.engine.new_context() as fl_ctx:
fl_ctx.set_prop(FLContextKey.SP_END_POINT, sp)
self.overseer_agent.initialize(fl_ctx)
return self.overseer_agent
    def _check_server_state(self, overseer_agent):
        """React to the latest overseer info: shut down, or transition hot/cold."""
        # Ignore overseer info until the server app has fully started.
        if self.status != ServerStatus.STARTED:
            return
        if overseer_agent.is_shutdown():
            self.engine.shutdown_server()
            return
        sp = overseer_agent.get_primary_sp()
        old_state_name = self.server_state.__class__.__name__
        with self.lock:
            with self.engine.new_context() as fl_ctx:
                # The state object decides the transition from the primary SP info.
                self.server_state = self.server_state.handle_sd_callback(sp, fl_ctx)
        # Transitions run outside the lock: _turn_to_hot/_turn_to_cold re-acquire
        # self.lock, which is not reentrant.
        if isinstance(self.server_state, Cold2HotState):
            self._turn_to_hot()
        elif isinstance(self.server_state, Hot2ColdState):
            self._turn_to_cold(old_state_name)
def _notify_state_change(self, old_state_name):
new_state_name = self.server_state.__class__.__name__
if new_state_name != old_state_name:
self.logger.info(f"state changed from: {old_state_name} to: {new_state_name}")
keys = list(self.engine.run_processes.keys())
if keys:
target_fqcns = []
for job_id in keys:
target_fqcns.append(FQCN.join([FQCN.ROOT_SERVER, job_id]))
cell_msg = new_cell_message(headers={}, payload=self.server_state)
self.cell.broadcast_request(
channel=CellChannel.SERVER_COMMAND,
topic=ServerCommandNames.SERVER_STATE,
request=cell_msg,
targets=target_fqcns,
timeout=5.0,
optional=True,
)
    def overseer_callback(self, overseer_agent):
        """Overseer polling callback: check server state, skipping overlapping calls."""
        # Re-entrancy guard: callbacks can arrive faster than a state check
        # completes; overlapping invocations are dropped.
        if self.checking_server_state:
            self.logger.debug("busy checking server state")
            return
        self.checking_server_state = True
        try:
            self._check_server_state(overseer_agent)
        except Exception as ex:
            self.logger.error(f"exception in checking server state: {secure_format_exception(ex)}")
        finally:
            self.checking_server_state = False
    def _turn_to_hot(self):
        """Promote this server to hot, restoring unfinished jobs from snapshots in HA mode."""
        # Restore Snapshot
        if self.ha_mode:
            restored_job_ids = []
            with self.snapshot_lock:
                fl_snapshot = self.snapshot_persistor.retrieve()
                if fl_snapshot:
                    for run_number, snapshot in fl_snapshot.run_snapshots.items():
                        if snapshot and not snapshot.completed:
                            # Restore the workspace: wipe any stale run dir first,
                            # then unzip the snapshotted workspace content into it.
                            workspace_data = snapshot.get_component_snapshot(SnapshotKey.WORKSPACE).get("content")
                            dst = os.path.join(self.workspace, WorkspaceConstants.WORKSPACE_PREFIX + str(run_number))
                            if os.path.exists(dst):
                                shutil.rmtree(dst, ignore_errors=True)
                            os.makedirs(dst, exist_ok=True)
                            unzip_all_from_bytes(workspace_data, dst)
                            job_id = snapshot.get_component_snapshot(SnapshotKey.JOB_INFO).get(SnapshotKey.JOB_ID)
                            job_clients = snapshot.get_component_snapshot(SnapshotKey.JOB_INFO).get(
                                SnapshotKey.JOB_CLIENTS
                            )
                            self.logger.info(f"Restore the previous snapshot. Run_number: {run_number}")
                            with self.engine.new_context() as fl_ctx:
                                self.engine.job_runner.restore_running_job(
                                    run_number=run_number,
                                    job_id=job_id,
                                    job_clients=job_clients,
                                    snapshot=snapshot,
                                    fl_ctx=fl_ctx,
                                )
                            restored_job_ids.append(job_id)
            # Jobs that had snapshots but were not restored are marked abnormal.
            with self.engine.new_context() as fl_ctx:
                self.engine.job_runner.update_abnormal_finished_jobs(restored_job_ids, fl_ctx=fl_ctx)
        else:
            # Non-HA: discard any stale snapshots and reconcile unfinished jobs.
            with self.engine.new_context() as fl_ctx:
                self.snapshot_persistor.delete()
                self.engine.job_runner.update_unfinished_jobs(fl_ctx=fl_ctx)
        with self.lock:
            self.server_state = HotState(
                host=self.server_state.host, port=self.server_state.service_port, ssid=self.server_state.ssid
            )
    def _turn_to_cold(self, old_state_name):
        # Demote this server to cold: keep host/port, tell running job cells
        # about the state change, then pause (not stop) the server's jobs.
        with self.lock:
            self.server_state = ColdState(host=self.server_state.host, port=self.server_state.service_port)
        self._notify_state_change(old_state_name)
        self.engine.pause_server_jobs()
    def stop_training(self):
        """Mark the server app as stopped."""
        self.status = ServerStatus.STOPPED
        self.logger.info("Server app stopped.\n\n")
    def fl_shutdown(self):
        """Stop all jobs and announce SYSTEM_END, then run the base-class shutdown."""
        self.engine.stop_all_jobs()
        self.engine.fire_event(EventType.SYSTEM_END, self.engine.new_context())
        super().fl_shutdown()
    def close(self):
        """Shutdown the server."""
        self.logger.info("shutting down server")
        self.shutdown = True
        # Stop overseer polling before closing the base server.
        if self.overseer_agent:
            self.overseer_agent.end()
        return super().close()