Source code for nvflare.app_opt.xgboost.histogram_based_v2.runners.client_runner

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Tuple

import xgboost as xgb
from xgboost import callback

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.tracking.log_writer import LogWriter
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner
from nvflare.app_opt.xgboost.metrics_cb import MetricsCallback
from nvflare.fuel.utils.obj_utils import get_logger


class XGBoostParams:
    def __init__(
        self,
        xgb_params: dict,
        num_rounds: int = 10,
        early_stopping_rounds: int = 2,
        verbose_eval: bool = False,
    ):
        """Bundle the settings needed for one XGBoost training run.

        Args:
            xgb_params: Booster parameters, handed to `xgboost.train()` as its first
                argument `params`. A falsy value (``None`` or an empty dict) is replaced
                by a new empty dict. See the XGBoost documentation for the available keys:
                https://xgboost.readthedocs.io/en/stable/parameter.html
            num_rounds: value stored as ``num_rounds`` (defaults to 10).
            early_stopping_rounds: value stored as ``early_stopping_rounds`` (defaults to 2).
            verbose_eval: value stored as ``verbose_eval`` (defaults to False).
        """
        # `or` keeps a non-empty dict as the very same object and swaps anything falsy for {}.
        self.xgb_params: dict = xgb_params or {}
        self.verbose_eval = verbose_eval
        self.early_stopping_rounds = early_stopping_rounds
        self.num_rounds = num_rounds
class XGBClientRunner(XGBRunner, FLComponent):
    def __init__(
        self,
        data_loader_id: str,
        early_stopping_rounds: int,
        xgb_params: dict,
        verbose_eval: bool,
        use_gpus: bool,
        model_file_name: str,
        metrics_writer_id: str = None,
    ):
        """Constructor.

        Args:
            data_loader_id: the ID points to XGBDataLoader.
            early_stopping_rounds: early stopping rounds
            xgb_params: This dict is passed to `xgboost.train()` as the first argument `params`.
                It contains all the Booster parameters.
                Please refer to XGBoost documentation for details:
                https://xgboost.readthedocs.io/en/stable/parameter.html
            verbose_eval: verbose_eval in xgboost.train
            use_gpus (bool): A convenient flag to enable gpu training, if gpu device is specified in
                the `xgb_params` then this flag can be ignored.
            model_file_name: file name under which the trained model is saved; it is joined
                with the model directory taken from the run context.
            metrics_writer_id: the ID points to a LogWriter, if provided, a MetricsCallback will be added.
                Users can then use the receivers from nvflare.app_opt.tracking.
        """
        FLComponent.__init__(self)
        self.early_stopping_rounds = early_stopping_rounds
        self.xgb_params = xgb_params
        self.verbose_eval = verbose_eval
        self.use_gpus = use_gpus
        self.model_file_name = model_file_name
        self.data_loader_id = data_loader_id
        self.logger = get_logger(self)

        # Populated from the runner context at the start of run().
        self._client_name = None
        self._rank = None
        self._world_size = None
        self._num_rounds = None
        self._server_addr = None
        self._model_dir = None

        # Resolved from the engine in initialize().
        self._data_loader = None
        self._metrics_writer_id = metrics_writer_id
        self._metrics_writer = None

        self._stopped = False

    def initialize(self, fl_ctx: FLContext):
        """Look up the data loader and the optional metrics writer from the engine.

        Args:
            fl_ctx: FL context used to obtain the engine and to report a panic.
        """
        engine = fl_ctx.get_engine()
        self._data_loader = engine.get_component(self.data_loader_id)
        if not isinstance(self._data_loader, XGBDataLoader):
            self.system_panic(f"data_loader should be type XGBDataLoader but got {type(self._data_loader)}", fl_ctx)
            # Do not rely on system_panic raising: stop here so that only the root
            # cause is reported and no further components are resolved.
            return

        if self._metrics_writer_id:
            self._metrics_writer = engine.get_component(self._metrics_writer_id)
            if not isinstance(self._metrics_writer, LogWriter):
                self.system_panic("writer should be type LogWriter", fl_ctx)
                # Never keep an object of the wrong type around: xgb_train() would
                # otherwise wrap it in a MetricsCallback.
                self._metrics_writer = None

    def xgb_train(
        self, params: XGBoostParams, train_data: xgb.core.DMatrix, val_data: xgb.core.DMatrix
    ) -> xgb.core.Booster:
        """XGBoost training logic.

        Args:
            params (XGBoostParams): xgboost parameters.
            train_data (xgb.core.DMatrix): training data.
            val_data (xgb.core.DMatrix): validation data.

        Returns:
            A xgboost booster.

        Note:
            Users can customize this method for their own training logic.
        """
        # Specify validations set to watch performance
        watchlist = [(val_data, "eval"), (train_data, "train")]

        callbacks = [callback.EvaluationMonitor(rank=self._rank)]
        if self._metrics_writer:
            callbacks.append(MetricsCallback(self._metrics_writer))

        # Run training, all the features in training API is available.
        bst = xgb.train(
            params.xgb_params,
            train_data,
            params.num_rounds,
            evals=watchlist,
            early_stopping_rounds=params.early_stopping_rounds,
            verbose_eval=params.verbose_eval,
            callbacks=callbacks,
        )
        return bst

    def run(self, ctx: dict):
        """Train inside a federated communicator context and save the resulting model.

        Args:
            ctx: runner context; client name, rank, world size, number of rounds,
                server address and model directory are read from it using the
                ``Constant.RUNNER_CTX_*`` keys.
        """
        self._client_name = ctx.get(Constant.RUNNER_CTX_CLIENT_NAME)
        self._rank = ctx.get(Constant.RUNNER_CTX_RANK)
        self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE)
        self._num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS)
        self._server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR)
        self._model_dir = ctx.get(Constant.RUNNER_CTX_MODEL_DIR)

        if self.use_gpus:
            # mapping each rank to a GPU (can set to cuda:0 if simulating with only one gpu)
            self.logger.info(f"Training with GPU {self._rank}")
            # Rebind to a fresh dict instead of writing into the dict object that was
            # passed to the constructor; this also tolerates xgb_params being None.
            self.xgb_params = {**(self.xgb_params or {}), "device": f"cuda:{self._rank}"}

        self.logger.info(f"Using xgb params: {self.xgb_params}")
        params = XGBoostParams(
            xgb_params=self.xgb_params,
            num_rounds=self._num_rounds,
            early_stopping_rounds=self.early_stopping_rounds,
            verbose_eval=self.verbose_eval,
        )

        self.logger.info(f"server address is {self._server_addr}")

        communicator_env = {
            "dmlc_communicator": "federated",
            "federated_server_address": f"{self._server_addr}",
            "federated_world_size": self._world_size,
            "federated_rank": self._rank,
        }

        self.logger.info(f"communicator_env is {communicator_env=}")
        with xgb.collective.CommunicatorContext(**communicator_env):
            train_data, val_data = self._data_loader.load_data(self._client_name)

            bst = self.xgb_train(params, train_data, val_data)

            # Save the model.
            bst.save_model(os.path.join(self._model_dir, self.model_file_name))
            xgb.collective.communicator_print("Finished training\n")

            # Only reached when training and saving completed without raising.
            self._stopped = True

    def stop(self):
        """Do nothing: there is currently no way to stop the runner."""
        # currently no way to stop the runner
        pass

    def is_stopped(self) -> Tuple[bool, int]:
        """Report whether run() has finished.

        Returns:
            A tuple of the stopped flag (set at the end of a successful run()) and
            an exit code, which is always 0.
        """
        return self._stopped, 0