Source code for nvflare.private.fed.server.server_engine

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import multiprocessing
import os
import re
import shlex
import shutil
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.connection import Client as CommandClient
from multiprocessing.connection import Listener
from threading import Lock
from typing import Dict, List, Tuple

from nvflare.apis.client import Client
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import (
    AdminCommandNames,
    FLContextKey,
    MachineStatus,
    ReservedTopic,
    ReturnCode,
    RunProcessKey,
    ServerCommandKey,
    ServerCommandNames,
    SnapshotKey,
    WorkspaceConstants,
)
from nvflare.apis.fl_context import FLContext, FLContextManager
from nvflare.apis.fl_snapshot import RunSnapshot
from nvflare.apis.impl.job_def_manager import JobDefManagerSpec
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.apis.workspace import Workspace
from nvflare.fuel.hci.zip_utils import zip_directory_to_bytes
from nvflare.fuel.utils import fobs
from nvflare.private.admin_defs import Message
from nvflare.private.defs import RequestHeader, TrainingTopic
from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator
from nvflare.private.scheduler_constants import ShareableHeader
from nvflare.widgets.info_collector import InfoCollector
from nvflare.widgets.widget import Widget, WidgetID

from .admin import ClientReply
from .client_manager import ClientManager
from .job_runner import JobRunner
from .run_manager import RunManager
from .server_engine_internal_spec import EngineInfo, RunInfo, ServerEngineInternalSpec
from .server_status import ServerStatus


class ClientConnection:
    def __init__(self, client):
        self.client = client

    def send(self, data):
        try:
            buf = fobs.dumps(data)
        except BaseException as e:
            raise e
        self.client.send(buf)

    def recv(self):
        return self.client.recv()


class ServerEngine(ServerEngineInternalSpec):
    def __init__(self, server, args, client_manager: ClientManager, snapshot_persistor, workers=3):
        """Server engine.

        Args:
            server: server
            args: arguments
            client_manager (ClientManager): client manager.
            snapshot_persistor: snapshot persistor.
            workers: number of worker threads.
        """
        # TODO:: clean up the server function / requirement here should be BaseServer
        self.server = server
        self.args = args
        self.run_processes = {}
        self.execution_exception_run_processes = {}
        self.run_manager = None
        self.conf = None
        # TODO:: does this class need client manager?
        self.client_manager = client_manager
        self.widgets = {
            WidgetID.INFO_COLLECTOR: InfoCollector(),
            # WidgetID.FED_EVENT_RUNNER: ServerFedEventRunner()
        }

        self.engine_info = EngineInfo()

        if not workers >= 1:
            raise ValueError("workers must >= 1 but got {}".format(workers))

        self.executor = ThreadPoolExecutor(max_workers=workers)
        self.lock = Lock()
        self.logger = logging.getLogger(self.__class__.__name__)

        self.asked_to_stop = False
        self.snapshot_persistor = snapshot_persistor
        self.parent_conn = None
        self.parent_conn_lock = Lock()
        self.job_runner = None
        self.job_def_manager = None

        self.snapshot_lock = multiprocessing.Lock()

    def _get_server_app_folder(self):
        return WorkspaceConstants.APP_PREFIX + "server"

    def _get_client_app_folder(self, client_name):
        return WorkspaceConstants.APP_PREFIX + client_name

    def _get_run_folder(self, job_id):
        return os.path.join(self.args.workspace, WorkspaceConstants.WORKSPACE_PREFIX + str(job_id))

    def get_engine_info(self) -> EngineInfo:
        self.engine_info.app_names = {}
        if bool(self.run_processes):
            self.engine_info.status = MachineStatus.STARTED
        else:
            self.engine_info.status = MachineStatus.STOPPED

        for job_id, _ in self.run_processes.items():
            run_folder = os.path.join(self.args.workspace, WorkspaceConstants.WORKSPACE_PREFIX + str(job_id))
            app_file = os.path.join(run_folder, "fl_app.txt")
            if os.path.exists(app_file):
                with open(app_file, "r") as f:
                    self.engine_info.app_names[job_id] = f.readline().strip()
            else:
                self.engine_info.app_names[job_id] = "?"

        return self.engine_info

    def get_run_info(self) -> RunInfo:
        if self.run_manager:
            return self.run_manager.get_run_info()
        else:
            return None

    def create_parent_connection(self, port):
        while not self.parent_conn:
            try:
                address = ("localhost", port)
                self.parent_conn = CommandClient(address, authkey="parent process secret password".encode())
            except BaseException:
                time.sleep(1.0)
                pass

        threading.Thread(target=self.heartbeat_to_parent, args=[]).start()

    def heartbeat_to_parent(self):
        while True:
            try:
                with self.parent_conn_lock:
                    data = {ServerCommandKey.COMMAND: ServerCommandNames.HEARTBEAT, ServerCommandKey.DATA: {}}
                    self.parent_conn.send(data)
                time.sleep(1.0)
            except BaseException:
                # The parent process can not be reached. Terminate the child process.
                break

        # delay some time for the wrap up process before the child process self terminate.
        time.sleep(30)
        os.killpg(os.getpgid(os.getpid()), 9)

    def delete_job_id(self, num):
        job_id_folder = os.path.join(self.args.workspace, WorkspaceConstants.WORKSPACE_PREFIX + str(num))
        if os.path.exists(job_id_folder):
            shutil.rmtree(job_id_folder)
        return ""

    def get_clients(self) -> [Client]:
        return list(self.client_manager.get_clients().values())

    def validate_clients(self, client_names: List[str]) -> Tuple[List[Client], List[str]]:
        return self._get_all_clients_from_inputs(client_names)

    def start_app_on_server(self, run_number: str, job_id: str = None, job_clients=None, snapshot=None) -> str:
        if run_number in self.run_processes.keys():
            return f"Server run_{run_number} already started."
        else:
            app_root = os.path.join(self._get_run_folder(run_number), self._get_server_app_folder())
            if not os.path.exists(app_root):
                return "Server app does not exist. Please deploy the server app before starting."

            self.engine_info.status = MachineStatus.STARTING
            app_custom_folder = ""
            if self.server.enable_byoc:
                app_custom_folder = os.path.join(app_root, "custom")

            open_ports = get_open_ports(2)
            self._start_runner_process(
                self.args, app_root, run_number, app_custom_folder, open_ports, job_id, job_clients, snapshot
            )

            threading.Thread(target=self._listen_command, args=(open_ports[0], run_number)).start()

            self.engine_info.status = MachineStatus.STARTED

        return ""

    def _listen_command(self, listen_port, job_id):
        address = ("localhost", int(listen_port))
        listener = Listener(address, authkey="parent process secret password".encode())
        conn = listener.accept()

        while job_id in self.run_processes.keys():
            clients = self.run_processes.get(job_id).get(RunProcessKey.PARTICIPANTS)
            job_id = self.run_processes.get(job_id).get(RunProcessKey.JOB_ID)
            try:
                if conn.poll(0.1):
                    received_data = conn.recv()
                    command = received_data.get(ServerCommandKey.COMMAND)
                    data = received_data.get(ServerCommandKey.DATA)

                    if command == ServerCommandNames.GET_CLIENTS:
                        return_data = {ServerCommandKey.CLIENTS: clients, ServerCommandKey.JOB_ID: job_id}
                        conn.send(return_data)
                    elif command == ServerCommandNames.AUX_SEND:
                        targets = data.get("targets")
                        topic = data.get("topic")
                        request = data.get("request")
                        timeout = data.get("timeout")
                        fl_ctx = data.get("fl_ctx")
                        replies = self.aux_send(
                            targets=targets, topic=topic, request=request, timeout=timeout, fl_ctx=fl_ctx
                        )
                        conn.send(replies)
            except BaseException as e:
                self.logger.warning(f"Failed to process the child process command: {e}", exc_info=True)

    def wait_for_complete(self, job_id):
        while True:
            try:
                with self.lock:
                    command_conn = self.get_command_conn(job_id)
                    if command_conn:
                        data = {ServerCommandKey.COMMAND: ServerCommandNames.HEARTBEAT, ServerCommandKey.DATA: {}}
                        command_conn.send(data)
                time.sleep(1.0)
            except BaseException:
                with self.lock:
                    run_process_info = self.run_processes.pop(job_id)
                    return_code = run_process_info[RunProcessKey.CHILD_PROCESS].poll()
                    # if process exit but with Execution exception
                    if return_code and return_code != 0:
                        self.execution_exception_run_processes[job_id] = run_process_info
                self.engine_info.status = MachineStatus.STOPPED
                break

    def _start_runner_process(
        self, args, app_root, run_number, app_custom_folder, open_ports, job_id, job_clients, snapshot
    ):
        new_env = os.environ.copy()
        if app_custom_folder != "":
            new_env["PYTHONPATH"] = new_env.get("PYTHONPATH", "") + os.pathsep + app_custom_folder

        listen_port = open_ports[1]
        if snapshot:
            restore_snapshot = True
        else:
            restore_snapshot = False
        command_options = ""
        for t in args.set:
            command_options += " " + t
        command = (
            sys.executable
            + " -m nvflare.private.fed.app.server.runner_process -m "
            + args.workspace
            + " -s fed_server.json -r "
            + app_root
            + " -n "
            + str(run_number)
            + " -p "
            + str(listen_port)
            + " -c "
            + str(open_ports[0])
            + " --set"
            + command_options
            + " print_conf=True restore_snapshot="
            + str(restore_snapshot)
        )
        # use os.setsid to create new process group ID
        process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env)

        if not job_id:
            job_id = ""
        if not job_clients:
            job_clients = self.client_manager.clients

        with self.lock:
            self.run_processes[run_number] = {
                RunProcessKey.LISTEN_PORT: listen_port,
                RunProcessKey.CONNECTION: None,
                RunProcessKey.CHILD_PROCESS: process,
                RunProcessKey.JOB_ID: job_id,
                RunProcessKey.PARTICIPANTS: job_clients,
            }

        threading.Thread(target=self.wait_for_complete, args=[run_number]).start()
        return process

    def get_job_clients(self, client_sites):
        job_clients = {}
        if client_sites:
            for site, dispatch_info in client_sites.items():
                client = self.get_client_from_name(site)
                if client:
                    job_clients[client.token] = client
        return job_clients

    def remove_custom_path(self):
        regex = re.compile(".*/run_.*/custom")
        custom_paths = list(filter(regex.search, sys.path))
        for path in custom_paths:
            sys.path.remove(path)

    def abort_app_on_clients(self, clients: List[str]) -> str:
        status = self.engine_info.status
        if status == MachineStatus.STOPPED:
            return "Server app has not started."
        if status == MachineStatus.STARTING:
            return "Server app is starting, please wait for started before abort."
        return ""

    def abort_app_on_server(self, job_id: str) -> str:
        if job_id not in self.run_processes.keys():
            return "Server app has not started."

        self.logger.info("Abort the server app run.")

        try:
            with self.lock:
                command_conn = self.get_command_conn(job_id)
                if command_conn:
                    data = {ServerCommandKey.COMMAND: AdminCommandNames.ABORT, ServerCommandKey.DATA: {}}
                    command_conn.send(data)

                    status_message = command_conn.recv()
                    self.logger.info(f"Abort server: {status_message}")
        except BaseException:
            with self.lock:
                child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS, None)
                if child_process:
                    child_process.terminate()
        finally:
            with self.lock:
                self.run_processes.pop(job_id)

        self.engine_info.status = MachineStatus.STOPPED
        return ""

    def check_app_start_readiness(self, job_id: str) -> str:
        if job_id not in self.run_processes.keys():
            return f"Server app run_{job_id} has not started."
        return ""

    def shutdown_server(self) -> str:
        status = self.server.status
        if status == ServerStatus.STARTING:
            return "Server app is starting, please wait for started before shutdown."

        self.logger.info("FL server shutdown.")

        touch_file = os.path.join(self.args.workspace, "shutdown.fl")
        _ = self.executor.submit(lambda p: server_shutdown(*p), [self.server, touch_file])
        while self.server.status != ServerStatus.SHUTDOWN:
            time.sleep(1.0)
        return ""

    def restart_server(self) -> str:
        status = self.server.status
        if status == ServerStatus.STARTING:
            return "Server is starting, please wait for started before restart."

        self.logger.info("FL server restart.")

        touch_file = os.path.join(self.args.workspace, "restart.fl")
        _ = self.executor.submit(lambda p: server_shutdown(*p), [self.server, touch_file])
        while self.server.status != ServerStatus.SHUTDOWN:
            time.sleep(1.0)
        return ""

    def get_widget(self, widget_id: str) -> Widget:
        return self.widgets.get(widget_id)

    def get_client_name_from_token(self, token: str) -> str:
        client = self.server.client_manager.clients.get(token)
        if client:
            return client.name
        else:
            return ""

    def get_all_clients(self):
        return list(self.server.client_manager.clients.keys())

    def get_client_from_name(self, client_name):
        for c in self.get_clients():
            if client_name == c.name:
                return c
        return None

    def _get_all_clients_from_inputs(self, inputs):
        clients = []
        invalid_inputs = []
        for item in inputs:
            client = self.client_manager.clients.get(item)
            # if item in self.get_all_clients():
            if client:
                clients.append(client)
            else:
                client = self.get_client_from_name(item)
                if client:
                    clients.append(client)
                else:
                    invalid_inputs.append(item)
        return clients, invalid_inputs

    def get_app_data(self, app_name: str) -> Tuple[str, object]:
        fullpath_src = os.path.join(self.server.admin_server.file_upload_dir, app_name)
        if not os.path.exists(fullpath_src):
            return f"App folder '{app_name}' does not exist in staging area.", None

        data = zip_directory_to_bytes(fullpath_src, "")
        return "", data

    def get_app_run_info(self, job_id) -> RunInfo:
        run_info = None
        try:
            with self.lock:
                command_conn = self.get_command_conn(job_id)
                if command_conn:
                    data = {ServerCommandKey.COMMAND: ServerCommandNames.GET_RUN_INFO, ServerCommandKey.DATA: {}}
                    command_conn.send(data)
                    run_info = command_conn.recv()
        except BaseException:
            self.logger.error(f"Failed to get_app_run_info from run_{job_id}")

        return run_info

    def set_run_manager(self, run_manager: RunManager):
        self.run_manager = run_manager
        for _, widget in self.widgets.items():
            self.run_manager.add_handler(widget)

    def set_job_runner(self, job_runner: JobRunner, job_manager: JobDefManagerSpec):
        self.job_runner = job_runner
        self.job_def_manager = job_manager

    def set_configurator(self, conf: ServerJsonConfigurator):
        if not isinstance(conf, ServerJsonConfigurator):
            raise TypeError("conf must be ServerJsonConfigurator but got {}".format(type(conf)))
        self.conf = conf

    def build_component(self, config_dict):
        return self.conf.build_component(config_dict)

    def new_context(self) -> FLContext:
        if self.run_manager:
            return self.run_manager.new_context()
        else:
            # return FLContext()
            return FLContextManager(
                engine=self, identity_name=self.server.project_name, job_id="", public_stickers={}, private_stickers={}
            ).new_context()

    def get_component(self, component_id: str) -> object:
        return self.run_manager.get_component(component_id)

    def fire_event(self, event_type: str, fl_ctx: FLContext):
        self.run_manager.fire_event(event_type, fl_ctx)

    def get_staging_path_of_app(self, app_name: str) -> str:
        return os.path.join(self.server.admin_server.file_upload_dir, app_name)

    def deploy_app_to_server(self, run_destination: str, app_name: str, app_staging_path: str) -> str:
        return self.deploy_app(run_destination, app_name, WorkspaceConstants.APP_PREFIX + "server")

    def get_workspace(self) -> Workspace:
        return self.run_manager.get_workspace()

    def ask_to_stop(self):
        self.asked_to_stop = True

    def deploy_app(self, job_id, src, dst):
        fullpath_src = os.path.join(self.server.admin_server.file_upload_dir, src)
        fullpath_dst = os.path.join(self._get_run_folder(job_id), dst)
        if not os.path.exists(fullpath_src):
            return f"App folder '{src}' does not exist in staging area."
        if os.path.exists(fullpath_dst):
            shutil.rmtree(fullpath_dst)
        shutil.copytree(fullpath_src, fullpath_dst)

        app_file = os.path.join(self._get_run_folder(job_id), "fl_app.txt")
        if os.path.exists(app_file):
            os.remove(app_file)
        with open(app_file, "wt") as f:
            f.write(f"{src}")

        return ""

    def remove_clients(self, clients: List[str]) -> str:
        for client in clients:
            self._remove_dead_client(client)
        return ""

    def _remove_dead_client(self, token):
        _ = self.server.client_manager.remove_client(token)
        self.server.remove_client_data(token)
        if self.server.admin_server:
            self.server.admin_server.client_dead(token)

    def register_aux_message_handler(self, topic: str, message_handle_func):
        self.run_manager.aux_runner.register_aux_message_handler(topic, message_handle_func)

    def send_aux_request(self, targets: [], topic: str, request: Shareable, timeout: float, fl_ctx: FLContext) -> dict:
        try:
            if not targets:
                self.sync_clients_from_main_process()
                targets = []
                for t in self.get_clients():
                    targets.append(t.name)
            if targets:
                return self.run_manager.aux_runner.send_aux_request(
                    targets=targets, topic=topic, request=request, timeout=timeout, fl_ctx=fl_ctx
                )
            else:
                return {}
        except Exception as e:
            self.logger.error(f"Failed to send the aux_message: {topic} with exception: {e}.")

    def sync_clients_from_main_process(self):
        with self.parent_conn_lock:
            data = {ServerCommandKey.COMMAND: ServerCommandNames.GET_CLIENTS, ServerCommandKey.DATA: {}}
            self.parent_conn.send(data)
            return_data = self.parent_conn.recv()
        clients = return_data.get(ServerCommandKey.CLIENTS)
        self.client_manager.clients = clients

    def parent_aux_send(self, targets: [], topic: str, request: Shareable, timeout: float, fl_ctx: FLContext) -> dict:
        with self.parent_conn_lock:
            data = {
                ServerCommandKey.COMMAND: ServerCommandNames.AUX_SEND,
                ServerCommandKey.DATA: {
                    "targets": targets,
                    "topic": topic,
                    "request": request,
                    "timeout": timeout,
                    "fl_ctx": get_serializable_data(fl_ctx),
                },
            }
            self.parent_conn.send(data)
            return_data = self.parent_conn.recv()
        return return_data

    def aux_send(self, targets: [], topic: str, request: Shareable, timeout: float, fl_ctx: FLContext) -> dict:
        # Send the aux messages through admin_server
        request.set_peer_props(fl_ctx.get_all_public_props())

        message = Message(topic=ReservedTopic.AUX_COMMAND, body=fobs.dumps(request))
        message.set_header(RequestHeader.JOB_ID, str(fl_ctx.get_prop(FLContextKey.CURRENT_RUN)))
        requests = {}
        for n in targets:
            requests.update({n: message})

        replies = self.server.admin_server.send_requests(requests, timeout_secs=timeout)
        results = {}
        for r in replies:
            client_name = self.get_client_name_from_token(r.client_token)
            if r.reply:
                try:
                    results[client_name] = fobs.loads(r.reply.body)
                except BaseException:
                    results[client_name] = make_reply(ReturnCode.COMMUNICATION_ERROR)
                    self.logger.error(
                        f"Received unexpected reply from client: {client_name}, "
                        f"message body:{r.reply.body} processing topic:{topic}"
                    )
            else:
                results[client_name] = None

        return results

    def get_command_conn(self, job_id):
        # this function need to be called with self.lock
        port = self.run_processes.get(job_id, {}).get(RunProcessKey.LISTEN_PORT)
        command_conn = self.run_processes.get(job_id, {}).get(RunProcessKey.CONNECTION, None)

        if not command_conn:
            try:
                address = ("localhost", port)
                command_conn = CommandClient(address, authkey="client process secret password".encode())
                command_conn = ClientConnection(command_conn)
                self.run_processes[job_id][RunProcessKey.CONNECTION] = command_conn
            except Exception:
                pass
        return command_conn

    def persist_components(self, fl_ctx: FLContext, completed: bool):
        # Call the State Persistor to persist all the component states
        # 1. call every component to generate the component states data
        #    Make sure to include the current round number
        # 2. call persistence API to save the component states
        try:
            job_id = fl_ctx.get_job_id()
            snapshot = RunSnapshot(job_id)
            for component_id, component in self.run_manager.components.items():
                if isinstance(component, FLComponent):
                    snapshot.set_component_snapshot(
                        component_id=component_id, component_state=component.get_persist_state(fl_ctx)
                    )

            snapshot.set_component_snapshot(
                component_id=SnapshotKey.FL_CONTEXT, component_state=copy.deepcopy(get_serializable_data(fl_ctx).props)
            )

            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            data = zip_directory_to_bytes(workspace.get_run_dir(fl_ctx.get_prop(FLContextKey.CURRENT_RUN)), "")
            snapshot.set_component_snapshot(component_id=SnapshotKey.WORKSPACE, component_state={"content": data})

            job_info = fl_ctx.get_prop(FLContextKey.JOB_INFO)
            if not job_info:
                with self.parent_conn_lock:
                    data = {ServerCommandKey.COMMAND: ServerCommandNames.GET_CLIENTS, ServerCommandKey.DATA: {}}
                    self.parent_conn.send(data)
                    return_data = self.parent_conn.recv()
                    job_id = return_data.get(ServerCommandKey.JOB_ID)
                    job_clients = return_data.get(ServerCommandKey.CLIENTS)
                    fl_ctx.set_prop(FLContextKey.JOB_INFO, (job_id, job_clients))
            else:
                (job_id, job_clients) = job_info
            snapshot.set_component_snapshot(
                component_id=SnapshotKey.JOB_INFO,
                component_state={SnapshotKey.JOB_CLIENTS: job_clients, SnapshotKey.JOB_ID: job_id},
            )

            snapshot.completed = completed

            self.server.snapshot_location = self.snapshot_persistor.save(snapshot=snapshot)
            if not completed:
                self.logger.info(f"persist the snapshot to: {self.server.snapshot_location}")
            else:
                self.logger.info(f"The snapshot: {self.server.snapshot_location} has been removed.")
        except BaseException as e:
            self.logger.error(f"Failed to persist the components. {str(e)}")

    def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext):
        for component_id, component in self.run_manager.components.items():
            component.restore(snapshot.get_component_snapshot(component_id=component_id), fl_ctx)

        fl_ctx.props.update(snapshot.get_component_snapshot(component_id=SnapshotKey.FL_CONTEXT))

    def dispatch(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
        return self.run_manager.aux_runner.dispatch(topic=topic, request=request, fl_ctx=fl_ctx)

    def show_stats(self, job_id):
        stats = None
        try:
            with self.lock:
                command_conn = self.get_command_conn(job_id)
                if command_conn:
                    data = {ServerCommandKey.COMMAND: ServerCommandNames.SHOW_STATS, ServerCommandKey.DATA: {}}
                    command_conn.send(data)
                    stats = command_conn.recv()
        except BaseException:
            self.logger.error(f"Failed to get_stats from run_{job_id}")

        return stats

    def get_errors(self, job_id):
        stats = None
        try:
            with self.lock:
                command_conn = self.get_command_conn(job_id)
                if command_conn:
                    data = {ServerCommandKey.COMMAND: ServerCommandNames.GET_ERRORS, ServerCommandKey.DATA: {}}
                    command_conn.send(data)
                    stats = command_conn.recv()
        except BaseException:
            self.logger.error(f"Failed to get_errors from run_{job_id}")

        return stats

    def _send_admin_requests(self, requests, timeout_secs=10) -> List[ClientReply]:
        return self.server.admin_server.send_requests(requests, timeout_secs=timeout_secs)

    def check_client_resources(self, resource_reqs) -> Dict[str, Tuple[bool, str]]:
        requests = {}
        for site_name, resource_requirements in resource_reqs.items():
            # assume server resource is unlimited
            if site_name == "server":
                continue
            request = Message(topic=TrainingTopic.CHECK_RESOURCE, body=fobs.dumps(resource_requirements))
            client = self.get_client_from_name(site_name)
            if client:
                requests.update({client.token: request})
        replies = []
        if requests:
            replies = self._send_admin_requests(requests, 15)
        result = {}
        for r in replies:
            site_name = self.get_client_name_from_token(r.client_token)
            if r.reply:
                resp = fobs.loads(r.reply.body)
                result[site_name] = (
                    resp.get_header(ShareableHeader.CHECK_RESOURCE_RESULT, False),
                    resp.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, ""),
                )
            else:
                result[site_name] = (False, "")
        return result

    def cancel_client_resources(
        self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict]
    ):
        requests = {}
        for site_name, result in resource_check_results.items():
            check_result, token = result
            if check_result and token:
                resource_requirements = resource_reqs[site_name]
                request = Message(topic=TrainingTopic.CANCEL_RESOURCE, body=fobs.dumps(resource_requirements))
                request.set_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, token)
                client = self.get_client_from_name(site_name)
                if client:
                    requests.update({client.token: request})
        if requests:
            _ = self._send_admin_requests(requests)

    def start_client_job(self, job_id, client_sites):
        requests = {}
        for site, dispatch_info in client_sites.items():
            resource_requirement = dispatch_info.resource_requirements
            token = dispatch_info.token
            request = Message(topic=TrainingTopic.START_JOB, body=fobs.dumps(resource_requirement))
            request.set_header(RequestHeader.JOB_ID, job_id)
            request.set_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, token)
            client = self.get_client_from_name(site)
            if client:
                requests.update({client.token: request})
        replies = []
        if requests:
            replies = self._send_admin_requests(requests, timeout_secs=20)
        return replies

    def stop_all_jobs(self):
        fl_ctx = self.new_context()
        self.job_runner.stop_all_runs(fl_ctx)

    def close(self):
        self.executor.shutdown()


def server_shutdown(server, touch_file):
    with open(touch_file, "a"):
        os.utime(touch_file, None)

    try:
        server.fl_shutdown()
        server.admin_server.stop()
        time.sleep(3.0)
    finally:
        server.status = ServerStatus.SHUTDOWN
        sys.exit(2)


def copy_new_server_properties(server, new_server):
    # server.model_manager = new_server.model_manager
    # server.model_saver = new_server.model_saver
    server.builder = new_server.builder

    server.wait_after_min_clients = new_server.wait_after_min_clients

    server.outbound_filters = new_server.outbound_filters
    server.inbound_filters = new_server.inbound_filters
    server.cmd_modules = new_server.cmd_modules
    server.processors = new_server.processors

    # server.task_name = new_server.task_name
    server.min_num_clients = new_server.min_num_clients
    server.max_num_clients = new_server.max_num_clients
    server.current_round = new_server.current_round
    server.num_rounds = new_server.num_rounds
    server.start_round = new_server.start_round
    # server.heart_beat_timeout = new_server.heart_beat_timeout
    # server.handlers = new_server.handlers

    # clients = server.client_manager.clients
    # server.client_manager = new_server.client_manager
    # server.client_manager.clients = clients
    server.client_manager.min_num_clients = new_server.client_manager.min_num_clients
    server.client_manager.max_num_clients = new_server.client_manager.max_num_clients
    server.client_manager.logger = new_server.client_manager.logger
    server.client_manager.logger.disabled = False

    server.reset_tokens()
    server.contributed_clients.clear()
    # server.accumulator.clear()

    server.fl_ctx = new_server.fl_ctx

    server.controller = new_server.controller
    # server.model_aggregator = new_server.model_aggregator
    # server.model_saver = new_server.model_saver
    # server.shareable_generator = new_server.shareable_generator


def set_up_run_config(server, conf):
    server.heart_beat_timeout = conf.heartbeat_timeout
    server.runner_config = conf.runner_config
    server.handlers = conf.handlers