nvflare.app_opt.xgboost.histogram_based_v2.runners.client_runner module¶
- class XGBClientRunner(data_loader_id: str, early_stopping_rounds: int, xgb_params: dict, verbose_eval: bool, use_gpus: bool, model_file_name: str, metrics_writer_id: str | None = None)[source]¶
Bases: XGBRunner, FLComponent
Constructor.
- Parameters:
data_loader_id – the ID of an XGBDataLoader component.
early_stopping_rounds – the number of rounds with no improvement in the validation metric after which training stops.
xgb_params – This dict is passed to xgboost.train() as its first argument, params. It contains all the Booster parameters. Please refer to the XGBoost documentation for details: https://xgboost.readthedocs.io/en/stable/parameter.html
verbose_eval – passed as verbose_eval to xgboost.train().
use_gpus (bool) – a convenience flag to enable GPU training; if a GPU device is already specified in xgb_params, this flag can be ignored.
model_file_name – the file name used to save the trained model.
metrics_writer_id – the ID of a LogWriter component; if provided, a MetricsCallback will be added. Users can then consume the metrics with the receivers from nvflare.app_opt.tracking.
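A runner like this is typically registered as a component in the client job configuration. The snippet below is a hypothetical sketch: the component IDs and argument values are assumptions for illustration; only the class path is taken from this page.

```json
{
  "components": [
    {
      "id": "xgb_runner",
      "path": "nvflare.app_opt.xgboost.histogram_based_v2.runners.client_runner.XGBClientRunner",
      "args": {
        "data_loader_id": "data_loader",
        "early_stopping_rounds": 2,
        "xgb_params": {
          "objective": "binary:logistic",
          "eta": 0.1,
          "max_depth": 6,
          "tree_method": "hist"
        },
        "verbose_eval": false,
        "use_gpus": false,
        "model_file_name": "model.json"
      }
    }
  ]
}
```

The referenced "data_loader" component would be an XGBDataLoader defined elsewhere in the same configuration.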
- initialize(fl_ctx: FLContext)[source]¶
Initializes the runner. This happens when the job is about to start.
- Parameters:
fl_ctx – FL context
- Returns:
None
- is_stopped() Tuple[bool, int] [source]¶
Checks whether the runner is already stopped.
- Returns:
A tuple of (whether the runner is stopped, exit code)
- run(ctx: dict)[source]¶
Runs XGB processing logic.
- Parameters:
ctx – the contextual info to help the runner execution
- Returns:
None
- xgb_train(params: XGBoostParams, train_data: DMatrix, val_data: DMatrix) Booster [source]¶
XGBoost training logic.
- Parameters:
params (XGBoostParams) – xgboost parameters.
train_data (xgb.core.DMatrix) – training data.
val_data (xgb.core.DMatrix) – validation data.
- Returns:
An xgboost.Booster object.
Note
Users can customize this method for their own training logic.
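As the note says, xgb_train can be overridden for custom training logic. The sketch below shows the general shape of such an override, assuming params carries the fields documented for XGBoostParams. A local stub stands in for xgboost.train() so the sketch is self-contained; in a real subclass you would call xgboost.train() directly.

```python
from dataclasses import dataclass, field

@dataclass
class XGBoostParams:
    # Mirrors the XGBoostParams container documented on this page.
    xgb_params: dict = field(default_factory=dict)
    num_rounds: int = 10
    early_stopping_rounds: int = 2
    verbose_eval: bool = False

def train_stub(params, dtrain, num_boost_round, evals,
               early_stopping_rounds, verbose_eval):
    # Stand-in for xgboost.train(); records the call for illustration.
    return {
        "params": params,
        "num_boost_round": num_boost_round,
        "eval_names": [name for _, name in evals],
        "early_stopping_rounds": early_stopping_rounds,
    }

def custom_xgb_train(params: XGBoostParams, train_data, val_data,
                     train_fn=train_stub):
    # Shape of an xgb_train override: watch both data sets and let
    # early stopping monitor the validation split.
    evals = [(train_data, "train"), (val_data, "valid")]
    return train_fn(
        params.xgb_params,
        train_data,
        num_boost_round=params.num_rounds,
        evals=evals,
        early_stopping_rounds=params.early_stopping_rounds,
        verbose_eval=params.verbose_eval,
    )
```

Swapping train_stub for xgboost.train (and passing real DMatrix objects) turns this into a working override.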
- class XGBoostParams(xgb_params: dict, num_rounds: int = 10, early_stopping_rounds: int = 2, verbose_eval: bool = False)[source]¶
Bases: object
Container for all XGBoost parameters.
- Parameters:
xgb_params – This dict is passed to xgboost.train() as its first argument, params. It contains all the Booster parameters. Please refer to the XGBoost documentation for details: https://xgboost.readthedocs.io/en/stable/parameter.html
num_rounds – the number of boosting rounds (default: 10).
early_stopping_rounds – the number of rounds with no improvement after which training stops (default: 2).
verbose_eval – whether to print evaluation results during training (default: False).
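A minimal example of what xgb_params might contain. The specific values are illustrative; the keys are standard XGBoost Booster parameters from the linked parameter reference.

```python
# Illustrative Booster parameters, passed unchanged to xgboost.train()
# as its first argument. Keys follow the XGBoost parameter reference.
xgb_params = {
    "objective": "binary:logistic",  # binary classification with log loss
    "eval_metric": "auc",            # metric monitored for early stopping
    "eta": 0.1,                      # learning rate (step size shrinkage)
    "max_depth": 6,                  # maximum tree depth
    "tree_method": "hist",           # histogram-based tree construction
}
```

Anything accepted by xgboost.train()'s params argument can go here; the runner forwards the dict as-is.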