Source code for nvflare.app_opt.xgboost.histogram_based_v2.controller

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from typing import Optional

import xgboost

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.xgb_adaptor import XGBServerAdaptor
from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
from nvflare.security.logging import secure_format_exception

from .defs import Constant


[docs] class ClientStatus: """ Objects of this class keep processing status of each FL client during job execution. """ def __init__(self): # Set when the client's config reply is received and the reply return code is OK. # If the client failed to reply or the return code is not OK, this value is not set. self.configured_time = None # Set when the client's start reply is received and the reply return code is OK. # If the client failed to reply or the return code is not OK, this value is not set. self.started_time = None # operation of the last XGB request from this client self.last_op = None # time of the last XGB op request from this client self.last_op_time = time.time() # whether the XGB process is done on this client self.xgb_done = False
[docs] class XGBController(Controller): def __init__( self, adaptor_component_id: str, num_rounds: int, data_split_mode: int, secure_training: bool, xgb_params: dict, xgb_options: Optional[dict] = None, disable_version_check=False, configure_task_name=Constant.CONFIG_TASK_NAME, configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT, start_task_name=Constant.START_TASK_NAME, start_task_timeout=Constant.START_TASK_TIMEOUT, job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL, progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, client_ranks=None, ): """ Constructor For the meaning of XGBoost parameters, please refer to the documentation for train API, https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train Args: adaptor_component_id - the component ID of server target adaptor num_rounds - number of rounds data_split_mode - 0 for horizontal/row-split, 1 for vertical/column-split secure_training - If true, secure training is enabled xgb_params - The params argument for train method xgb_options - All other arguments for train method are passed through this dictionary disable_version_check - If true, XGBoost version check for secure training is skipped configure_task_name - name of the config task configure_task_timeout - time to wait for clients’ responses to the config task before timeout. start_task_name - name of the start task start_task_timeout - time to wait for clients’ responses to the start task before timeout. job_status_check_interval - how often to check client statuses of the job max_client_op_interval - max amount of time allowed between XGB ops from a client progress_timeout- the maximum amount of time allowed for the workflow to not make any progress. In other words, at least one participating client must have made progress during this time. Otherwise, the workflow will be considered to be in trouble and the job will be aborted. client_ranks: client rank assignments. If specified, must be a dict of client_name => rank. If not specified, client ranks will be randomly assigned. No matter how assigned, ranks must be consecutive integers, starting from 0. """ Controller.__init__(self) self.adaptor_component_id = adaptor_component_id self.num_rounds = num_rounds self.data_split_mode = data_split_mode self.secure_training = secure_training self.xgb_params = xgb_params self.xgb_options = xgb_options self.disable_version_check = disable_version_check self.configure_task_name = configure_task_name self.start_task_name = start_task_name self.start_task_timeout = start_task_timeout self.configure_task_timeout = configure_task_timeout self.max_client_op_interval = max_client_op_interval self.progress_timeout = progress_timeout self.job_status_check_interval = job_status_check_interval self.client_ranks = client_ranks # client rank assignments self.adaptor = None self.participating_clients = None self.status_lock = threading.Lock() self.client_statuses = {} # client name => ClientStatus self.abort_signal = None if data_split_mode not in {0, 1}: raise ValueError(f"Invalid data_split_mode: {data_split_mode}. It must be either 0 or 1") if not self.xgb_params: raise ValueError("xgb_params can't be empty") if not self.xgb_options: self.xgb_options = {} check_str("adaptor_component_id", adaptor_component_id) check_number_range("configure_task_timeout", configure_task_timeout, min_value=1) check_number_range("start_task_timeout", start_task_timeout, min_value=1) check_positive_number("job_status_check_interval", job_status_check_interval) check_positive_number("num_rounds", num_rounds) check_number_range("max_client_op_interval", max_client_op_interval, min_value=10.0) check_number_range("progress_timeout", progress_timeout, min_value=5.0) if client_ranks: check_object_type("client_ranks", client_ranks, dict) # set up operation handlers self.op_table = { Constant.OP_ALL_GATHER: self._process_all_gather, Constant.OP_ALL_GATHER_V: self._process_all_gather_v, Constant.OP_ALL_REDUCE: self._process_all_reduce, Constant.OP_BROADCAST: self._process_broadcast, }
[docs] def get_adaptor(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() return engine.get_component(self.adaptor_component_id)
[docs] def start_controller(self, fl_ctx: FLContext): all_clients = self._engine.get_clients() self.participating_clients = [t.name for t in all_clients] for c in self.participating_clients: self.client_statuses[c] = ClientStatus() adaptor = self.get_adaptor(fl_ctx) if not adaptor: self.system_panic(f"cannot get component for {self.adaptor_component_id}", fl_ctx) return None if not isinstance(adaptor, XGBServerAdaptor): self.system_panic( f"invalid component '{self.adaptor_component_id}': expect XGBServerBridge but got {type(adaptor)}", fl_ctx, ) return None adaptor.initialize(fl_ctx) self.adaptor = adaptor ReliableMessage.register_request_handler( topic=Constant.TOPIC_XGB_REQUEST, handler_f=self._process_xgb_request, fl_ctx=fl_ctx, ) ReliableMessage.register_request_handler( topic=Constant.TOPIC_CLIENT_DONE, handler_f=self._process_client_done, fl_ctx=fl_ctx, )
def _trigger_stop(self, fl_ctx: FLContext, error=None): # first trigger the abort_signal to tell all components (mainly the controller's control_flow and adaptor) # that check this signal to abort. if self.abort_signal: self.abort_signal.trigger(value=True) # if there is error, call system_panic to terminate the job with proper status. # if no error, the job will end normally. if error: self.system_panic(reason=error, fl_ctx=fl_ctx)
[docs] def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == Constant.EVENT_XGB_ABORTED: error = fl_ctx.get_prop(FLContextKey.FATAL_SYSTEM_ERROR) self.system_panic(f"XGB server stopped with error: {error}", fl_ctx) else: super().handle_event(event_type, fl_ctx)
def _is_stopped(self): # check whether the abort signal is triggered return self.abort_signal and self.abort_signal.triggered def _update_client_status(self, fl_ctx: FLContext, op=None, client_done=False): """Update the status of the requesting client. Args: fl_ctx: FL context op: the XGB operation requested client_done: whether the client is done Returns: None """ with self.status_lock: peer_ctx = fl_ctx.get_peer_context() if not peer_ctx: self.log_error(fl_ctx, "missing peer_ctx from fl_ctx") return if not isinstance(peer_ctx, FLContext): self.log_error(fl_ctx, f"expect peer_ctx to be FLContext but got {type(peer_ctx)}") return client_name = peer_ctx.get_identity_name() if not client_name: self.log_error(fl_ctx, "missing identity from peer_ctx") return status = self.client_statuses.get(client_name) if not status: self.log_error(fl_ctx, f"no status record for client {client_name}") assert isinstance(status, ClientStatus) if op: status.last_op = op if client_done: status.xgb_done = client_done status.last_op_time = time.time() def _process_client_done(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: """Process the ClientDone report for a client Args: topic: topic of the message request: request to be processed fl_ctx: the FL context Returns: reply to the client """ exit_code = request.get(Constant.MSG_KEY_EXIT_CODE) if exit_code == 0: self.log_info(fl_ctx, f"XGB client is done with exit code {exit_code}") elif exit_code == Constant.EXIT_CODE_CANT_START: self.log_error(fl_ctx, f"XGB client failed to start (exit code {exit_code})") self.system_panic("XGB client failed to start", fl_ctx) elif exit_code == Constant.EXIT_CODE_JOB_ABORT: self.log_error(fl_ctx, f"XGB client aborted (exit code {exit_code})") self.system_panic("XGB client aborted", fl_ctx) else: # Should we stop here? # Problem is that even if the exit_code is not 0, we can't say the job failed. self.log_warning(fl_ctx, f"XGB client is done with exit code {exit_code}") self._update_client_status(fl_ctx, client_done=True) return make_reply(ReturnCode.OK) def _process_all_gather(self, request: Shareable, fl_ctx: FLContext) -> Shareable: """This is the op handler for Allgather. Args: request: the request containing op params fl_ctx: FL context Returns: a Shareable containing operation result """ rank = request.get(Constant.PARAM_KEY_RANK) seq = request.get(Constant.PARAM_KEY_SEQ) send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) rcv_buf = self.adaptor.all_gather(rank, seq, send_buf, fl_ctx) reply = Shareable() reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf return reply def _process_all_gather_v(self, request: Shareable, fl_ctx: FLContext) -> Shareable: """This is the op handler for AllgatherV. Args: request: the request containing op params fl_ctx: FL context Returns: a Shareable containing operation result """ rank = request.get(Constant.PARAM_KEY_RANK) seq = request.get(Constant.PARAM_KEY_SEQ) send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) fl_ctx.set_prop(key=Constant.PARAM_KEY_RANK, value=rank, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_SEQ, value=seq, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=send_buf, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_REQUEST, value=request, private=True, sticky=False) self.fire_event(Constant.EVENT_BEFORE_ALL_GATHER_V, fl_ctx) send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) rcv_buf = self.adaptor.all_gather_v(rank, seq, send_buf, fl_ctx) reply = Shareable() fl_ctx.set_prop(key=Constant.PARAM_KEY_REPLY, value=reply, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=rcv_buf, private=True, sticky=False) self.fire_event(Constant.EVENT_AFTER_ALL_GATHER_V, fl_ctx) rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf return reply def _process_all_reduce(self, request: Shareable, fl_ctx: FLContext) -> Shareable: """This is the op handler for Allreduce. Args: request: the request containing op params fl_ctx: FL context Returns: a Shareable containing operation result """ rank = request.get(Constant.PARAM_KEY_RANK) seq = request.get(Constant.PARAM_KEY_SEQ) send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) data_type = request.get(Constant.PARAM_KEY_DATA_TYPE) reduce_op = request.get(Constant.PARAM_KEY_REDUCE_OP) assert isinstance(self.adaptor, XGBServerAdaptor) rcv_buf = self.adaptor.all_reduce(rank, seq, data_type, reduce_op, send_buf, fl_ctx) reply = Shareable() reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf return reply def _process_broadcast(self, request: Shareable, fl_ctx: FLContext) -> Shareable: """This is the op handler for Broadcast. Args: request: the request containing op params fl_ctx: FL context Returns: a Shareable containing operation result """ rank = request.get(Constant.PARAM_KEY_RANK) seq = request.get(Constant.PARAM_KEY_SEQ) send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) root = request.get(Constant.PARAM_KEY_ROOT) fl_ctx.set_prop(key=Constant.PARAM_KEY_RANK, value=rank, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_SEQ, value=seq, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_ROOT, value=root, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=send_buf, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_REQUEST, value=request, private=True, sticky=False) self.fire_event(Constant.EVENT_BEFORE_BROADCAST, fl_ctx) send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) assert isinstance(self.adaptor, XGBServerAdaptor) rcv_buf = self.adaptor.broadcast(rank, seq, root, send_buf, fl_ctx) reply = Shareable() fl_ctx.set_prop(key=Constant.PARAM_KEY_REPLY, value=reply, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=rcv_buf, private=True, sticky=False) self.fire_event(Constant.EVENT_AFTER_BROADCAST, fl_ctx) rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf return reply def _process_xgb_request(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: op = request.get_header(Constant.MSG_KEY_XGB_OP) if self._is_stopped(): self.log_error(fl_ctx, f"dropped XGB request '{op}' since server is already stopped") return make_reply(ReturnCode.SERVICE_UNAVAILABLE) # since XGB protocol is very strict, we'll stop the control flow when any error occurs bad_req_error = "bad XGB request" process_error = "XGB request process error" if not op: self.log_error(fl_ctx, "missing op from XGB request") self._trigger_stop(fl_ctx, bad_req_error) return make_reply(ReturnCode.BAD_REQUEST_DATA) # find and call the op handlers process_f = self.op_table.get(op) if process_f is None: self.log_error(fl_ctx, f"invalid op '{op}' from XGB request") self._trigger_stop(fl_ctx, bad_req_error) return make_reply(ReturnCode.BAD_REQUEST_DATA) self._update_client_status(fl_ctx, op=op) if not callable(process_f): # impossible but we must declare process_f to be callable; otherwise PyCharm will complain about # process_f(request, fl_ctx). raise RuntimeError(f"op handler for {op} is not callable") try: reply = process_f(request, fl_ctx) except Exception as ex: self.log_exception(fl_ctx, f"exception processing {op}: {secure_format_exception(ex)}") self._trigger_stop(fl_ctx, process_error) return make_reply(ReturnCode.EXECUTION_EXCEPTION) self.log_info(fl_ctx, f"received reply for '{op}'") reply.set_header(Constant.MSG_KEY_XGB_OP, op) return reply def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext): self.log_info(fl_ctx, f"Configuring clients {self.participating_clients}") shareable = Shareable() # compute client ranks if not self.client_ranks: # dynamically assign ranks, starting from 0 # Assumption: all clients are used clients = self.participating_clients # Sort by client name so rank is consistent clients.sort() self.client_ranks = {clients[i]: i for i in range(0, len(clients))} else: # validate ranks - ranks must be unique consecutive integers, starting from 0. num_clients = len(self.participating_clients) assigned_ranks = {} # rank => client if len(self.client_ranks) != num_clients: # either missing client or duplicate client self.system_panic( f"expecting rank assignments for {self.participating_clients} but got {self.client_ranks}", fl_ctx ) return False # all clients must have ranks for c in self.participating_clients: if c not in self.client_ranks: self.system_panic(f"missing rank assignment for client '{c}'", fl_ctx) return False # check each client's rank for c, r in self.client_ranks.items(): if not isinstance(r, int): self.system_panic(f"bad rank assignment {r} for client '{c}': expect int but got {type(r)}", fl_ctx) return False if r < 0 or r >= num_clients: self.system_panic( f"bad rank assignment {r} for client '{c}': must be 0 to {num_clients - 1}", fl_ctx ) return False assigned_client = assigned_ranks.get(r) if assigned_client: self.system_panic(f"rank {r} is assigned to both client '{c}' and '{assigned_client}'", fl_ctx) return False assigned_ranks[r] = c shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds shareable[Constant.CONF_KEY_DATA_SPLIT_MODE] = xgboost.core.DataSplitMode(self.data_split_mode) shareable[Constant.CONF_KEY_SECURE_TRAINING] = self.secure_training shareable[Constant.CONF_KEY_XGB_PARAMS] = self.xgb_params shareable[Constant.CONF_KEY_XGB_OPTIONS] = self.xgb_options shareable[Constant.CONF_KEY_DISABLE_VERSION_CHECK] = self.disable_version_check task = Task( name=self.configure_task_name, data=shareable, timeout=self.configure_task_timeout, result_received_cb=self._process_configure_reply, ) self.log_info(fl_ctx, f"sending task {self.configure_task_name} to clients {self.participating_clients}") start_time = time.time() self.broadcast_and_wait( task=task, targets=self.participating_clients, min_responses=len(self.participating_clients), fl_ctx=fl_ctx, abort_signal=abort_signal, ) time_taken = time.time() - start_time self.log_info(fl_ctx, f"client configuration took {time_taken} seconds") failed_clients = [] for c, cs in self.client_statuses.items(): assert isinstance(cs, ClientStatus) if not cs.configured_time: failed_clients.append(c) # if any client failed to configure, terminate the job if failed_clients: self.system_panic(f"failed to configure clients {failed_clients}", fl_ctx) return False self.log_info(fl_ctx, f"successfully configured clients {self.participating_clients}") return True def _start_clients(self, abort_signal: Signal, fl_ctx: FLContext): self.log_info(fl_ctx, f"Starting clients {self.participating_clients}") task = Task( name=self.start_task_name, data=Shareable(), timeout=self.start_task_timeout, result_received_cb=self._process_start_reply, ) self.log_info(fl_ctx, f"sending task {self.start_task_name} to clients {self.participating_clients}") start_time = time.time() self.broadcast_and_wait( task=task, targets=self.participating_clients, min_responses=len(self.participating_clients), fl_ctx=fl_ctx, abort_signal=abort_signal, ) time_taken = time.time() - start_time self.log_info(fl_ctx, f"client starting took {time_taken} seconds") failed_clients = [] for c, cs in self.client_statuses.items(): assert isinstance(cs, ClientStatus) if not cs.started_time: failed_clients.append(c) # if any client failed to start, terminate the job if failed_clients: self.system_panic(f"failed to start clients {failed_clients}", fl_ctx) return False self.log_info(fl_ctx, f"successfully started clients {self.participating_clients}") return True
[docs] def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): """ This is the control flow of the XGB Controller. To ensure smooth XGB execution: - ensure that all clients are online and ready to go before starting server - ensure that server is started and ready to take requests before asking clients to start operation - monitor the health of the clients - if anything goes wrong, terminate the job Args: abort_signal: abort signal that is used to notify components to abort fl_ctx: FL context Returns: None """ self.abort_signal = abort_signal # the adaptor uses the same abort signal! self.adaptor.set_abort_signal(abort_signal) # wait for every client to become online and properly configured self.log_info(fl_ctx, f"Waiting for clients to be ready: {self.participating_clients}") # configure all clients if not self._configure_clients(abort_signal, fl_ctx): self.system_panic("failed to configure all clients", fl_ctx) return # start the server adaptor try: self.adaptor.configure({Constant.CONF_KEY_WORLD_SIZE: len(self.participating_clients)}, fl_ctx) self.adaptor.start(fl_ctx) except Exception as ex: error = f"failed to start bridge: {secure_format_exception(ex)}" self.log_error(fl_ctx, error) self.system_panic(error, fl_ctx) return self.adaptor.monitor_target(fl_ctx, self._xgb_server_stopped) # start all clients if not self._start_clients(abort_signal, fl_ctx): self.system_panic("failed to start all clients", fl_ctx) return # monitor client health # we periodically check job status until all clients are done or the system is stopped self.log_info(fl_ctx, "Waiting for clients to finish ...") while not self._is_stopped(): done = self._check_job_status(fl_ctx) if done: break time.sleep(self.job_status_check_interval)
def _xgb_server_stopped(self, rc, fl_ctx: FLContext): # This CB is called when XGB server target is stopped error = None if rc != 0: self.log_error(fl_ctx, f"XGB Server stopped abnormally with code {rc}") error = "XGB server abnormal stop" # the XGB server could stop at any moment, we trigger the abort_signal in case it is checked by any # other components self._trigger_stop(fl_ctx, error) def _process_configure_reply(self, client_task: ClientTask, fl_ctx: FLContext): result = client_task.result client_name = client_task.client.name rc = result.get_return_code() if rc == ReturnCode.OK: self.log_info(fl_ctx, f"successfully configured client {client_name}") cs = self.client_statuses.get(client_name) if cs: assert isinstance(cs, ClientStatus) cs.configured_time = time.time() else: self.log_error(fl_ctx, f"client {client_task.client.name} failed to configure: {rc}") def _process_start_reply(self, client_task: ClientTask, fl_ctx: FLContext): result = client_task.result client_name = client_task.client.name rc = result.get_return_code() if rc == ReturnCode.OK: self.log_info(fl_ctx, f"successfully started client {client_name}") cs = self.client_statuses.get(client_name) if cs: assert isinstance(cs, ClientStatus) cs.started_time = time.time() else: self.log_error(fl_ctx, f"client {client_name} failed to start") def _check_job_status(self, fl_ctx: FLContext) -> bool: """Check job status and determine whether the job is done. Args: fl_ctx: FL context Returns: whether the job is considered done. """ now = time.time() # overall_last_progress_time is the latest time that any client made progress. overall_last_progress_time = 0.0 clients_done = 0 for client_name, cs in self.client_statuses.items(): assert isinstance(cs, ClientStatus) if cs.xgb_done: self.log_info(fl_ctx, f"client {client_name} is Done") clients_done += 1 elif now - cs.last_op_time > self.max_client_op_interval: self.system_panic( f"client {client_name} didn't have any activity for {self.max_client_op_interval} seconds", fl_ctx, ) return True if overall_last_progress_time < cs.last_op_time: overall_last_progress_time = cs.last_op_time if clients_done == len(self.client_statuses): # all clients are done - the job is considered done return True elif time.time() - overall_last_progress_time > self.progress_timeout: # there has been no progress from any client for too long. # this could be because the clients got stuck. # consider the job done and abort the job. self.system_panic(f"the job has no progress for {self.progress_timeout} seconds", fl_ctx) return True return False
[docs] def process_result_of_unknown_task( self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext ): self.log_warning(fl_ctx, f"ignored unknown task {task_name} from client {client.name}")
[docs] def stop_controller(self, fl_ctx: FLContext): if self.adaptor: self.log_info(fl_ctx, "Stopping server bridge") self.adaptor.stop(fl_ctx)