nvflare.app_opt.xgboost.histogram_based.executor module
- class FedXGBHistogramExecutor(num_rounds, early_stopping_rounds, xgb_params: dict, data_loader_id: str, verbose_eval=False, use_gpus=False, metrics_writer_id: str | None = None, model_file_name='test.model.json')
Bases: Executor
Federated XGBoost Executor for histogram-based collaboration.
This class sets up the training environment for Federated XGBoost: it is the executor that runs on each NVFlare client and starts XGBoost training. It implements basic xgb_train logic; override that method for custom behavior.
- Parameters:
num_rounds – number of boosting rounds
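early_stopping_rounds – number of consecutive rounds without improvement on the validation metric after which training stops early (see early_stopping_rounds in xgboost.train())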
xgb_params – This dict is passed to xgboost.train() as the first argument params. It contains all the Booster parameters. Please refer to XGBoost documentation for details: https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training
data_loader_id – ID of the XGBDataLoader component used to load the data.
verbose_eval – passed to xgboost.train() as verbose_eval.
use_gpus – whether to enable GPU training.
metrics_writer_id – ID of a LogWriter component. If provided, a MetricsCallback is added, and users can then use the receivers from nvflare.app_opt.tracking.
model_file_name (str) – where to save the model.
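As a sketch of how these constructor arguments fit together, the dict below shows one possible client-side job configuration for FedXGBHistogramExecutor, built in Python and dumped to JSON. The format_version/executors/components layout, the "train" task name, the component ID "dataloader", and the data loader path my_project.data_loader.MyXGBDataLoader are assumptions for illustration; check the configuration format of your NVFlare version. The Booster parameters inside xgb_params are standard XGBoost options.

```python
import json

# Hypothetical client-side job configuration expressed as a Python dict.
# The "executors"/"components" layout with "path" and "args" is assumed to
# follow the usual NVFlare client config shape; verify for your version.
client_config = {
    "format_version": 2,
    "executors": [
        {
            "tasks": ["train"],  # assumed task name
            "executor": {
                "path": "nvflare.app_opt.xgboost.histogram_based.executor.FedXGBHistogramExecutor",
                "args": {
                    "num_rounds": 100,
                    "early_stopping_rounds": 2,
                    # Standard XGBoost Booster parameters.
                    "xgb_params": {
                        "max_depth": 8,
                        "eta": 0.1,
                        "objective": "binary:logistic",
                        "eval_metric": "auc",
                        "tree_method": "hist",
                        "nthread": 16,
                    },
                    # Must match the "id" of the data loader component below.
                    "data_loader_id": "dataloader",
                    "verbose_eval": False,
                    "use_gpus": False,
                    "model_file_name": "test.model.json",
                },
            },
        }
    ],
    "components": [
        {
            "id": "dataloader",
            # Hypothetical user-defined XGBDataLoader subclass.
            "path": "my_project.data_loader.MyXGBDataLoader",
            "args": {},
        }
    ],
}

executor_args = client_config["executors"][0]["executor"]["args"]
component_ids = {c["id"] for c in client_config["components"]}
# The executor refers to its data loader by ID, so the two must agree.
assert executor_args["data_loader_id"] in component_ids

print(json.dumps(client_config, indent=2))
```

The assertion encodes the one cross-reference the parameter list implies: data_loader_id must name a component that exists in the same configuration.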
- execute(task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable
Executes a task.
- Parameters:
task_name (str) – task name.
shareable (Shareable) – input shareable.
fl_ctx (FLContext) – fl context.
abort_signal (Signal) – signal to check during execution to determine whether this task is aborted.
- Returns:
An output shareable.
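The execute contract (a task name in, a reply out, with abort_signal polled while the work runs) can be sketched without NVFlare. Everything below is hypothetical: FakeSignal is a minimal stand-in rather than NVFlare's Signal class, plain dicts replace Shareable, fl_ctx is omitted, the string return codes only mimic NVFlare's, and routing a "train" task to train() is an assumption about how such an executor dispatches work, not a reproduction of FedXGBHistogramExecutor's internals.

```python
from dataclasses import dataclass, field


class FakeSignal:
    """Hypothetical stand-in for an abort signal (not NVFlare's Signal class)."""

    def __init__(self) -> None:
        self.triggered = False

    def trigger(self) -> None:
        self.triggered = True


@dataclass
class SketchExecutor:
    """Illustrative executor: routes tasks by name and honors an abort signal."""

    train_task_name: str = "train"  # assumed task name
    rounds_done: list = field(default_factory=list)

    def execute(self, task_name: str, shareable: dict, abort_signal: FakeSignal) -> dict:
        # Unknown tasks get an error reply instead of raising.
        if task_name != self.train_task_name:
            return {"rc": "TASK_UNKNOWN"}
        return self.train(shareable, abort_signal)

    def train(self, shareable: dict, abort_signal: FakeSignal) -> dict:
        for rnd in range(shareable.get("num_rounds", 3)):
            # Poll the signal between units of work so an aborted task stops promptly.
            if abort_signal.triggered:
                return {"rc": "TASK_ABORTED", "rounds": len(self.rounds_done)}
            self.rounds_done.append(rnd)
        return {"rc": "OK", "rounds": len(self.rounds_done)}


# Usage
print(SketchExecutor().execute("train", {"num_rounds": 3}, FakeSignal()))
# {'rc': 'OK', 'rounds': 3}
aborted = FakeSignal()
aborted.trigger()
print(SketchExecutor().execute("train", {}, aborted))
# {'rc': 'TASK_ABORTED', 'rounds': 0}
```

Checking the signal at natural boundaries, rather than only once at the start, is what lets a long-running task respond to an abort issued mid-way.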
- handle_event(event_type: str, fl_ctx: FLContext)
Handles events.
- Parameters:
event_type (str) – event type fired by workflow.
fl_ctx (FLContext) – FLContext information.
- train(shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) → Shareable
XGBoost training task pipeline that handles NVFlare-specific tasks.
- xgb_train(params: XGBoostParams) → Booster
XGBoost training logic.
- Parameters:
params (XGBoostParams) – xgboost parameters.
- Returns:
An XGBoost Booster.
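Since xgb_train is the method meant to be overridden for custom behavior, the function below sketches what such a body might look like. It is a hypothetical standalone function: the documented method takes only params, whereas this sketch also receives the training data, validation data, and training function as arguments so that it runs without XGBoost installed. In real use train_fn would be xgboost.train and the data would be xgboost.DMatrix objects; forwarding num_rounds as num_boost_round is an assumption based on the parameter descriptions.

```python
from types import SimpleNamespace
from typing import Any, Callable


def custom_xgb_train(params: Any, dtrain: Any, dval: Any, train_fn: Callable[..., Any]) -> Any:
    """Hypothetical xgb_train-style body with its dependencies injected.

    In real use, train_fn would be xgboost.train and dtrain/dval would be
    xgboost.DMatrix objects.
    """
    # (data, name) pairs for evaluation; XGBoost early-stops on the last entry.
    watchlist = [(dtrain, "train"), (dval, "valid")]
    return train_fn(
        params.xgb_params,
        dtrain,
        num_boost_round=params.num_rounds,
        evals=watchlist,
        early_stopping_rounds=params.early_stopping_rounds,
        verbose_eval=params.verbose_eval,
    )


# Usage with a fake training function that records what it receives.
recorded = {}


def fake_train(booster_params, dtrain, **kwargs):
    recorded.update(kwargs, params=booster_params)
    return "booster"


params = SimpleNamespace(
    xgb_params={"objective": "binary:logistic", "eval_metric": "auc"},
    num_rounds=10,
    early_stopping_rounds=2,
    verbose_eval=False,
)
print(custom_xgb_train(params, "dtrain", "dval", fake_train))  # booster
print(recorded["num_boost_round"], recorded["evals"][-1][1])  # 10 valid
```

Putting the validation set last matters: when evals holds more than one entry, xgboost.train uses the last one to decide when to stop early.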
- class XGBoostParams(xgb_params: dict, num_rounds: int = 10, early_stopping_rounds: int = 2, verbose_eval: bool = False)
Bases: object
Container for all XGBoost parameters.
- Parameters:
xgb_params – The Booster parameters. This dict is passed to xgboost.train() as the argument params. Please refer to the XGBoost documentation for details: https://xgboost.readthedocs.io/en/stable/parameter.html
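XGBoostParams is a plain container, so its role can be shown with a small dataclass mirror. XGBoostParamsSketch and its train_kwargs() helper are hypothetical names, the defaults are copied from the signature above, and mapping num_rounds onto the num_boost_round argument of xgboost.train() is an assumption based on the parameter descriptions.

```python
from dataclasses import dataclass


@dataclass
class XGBoostParamsSketch:
    """Hypothetical mirror of XGBoostParams with the documented defaults."""

    xgb_params: dict  # Booster parameters, passed to xgboost.train() as params
    num_rounds: int = 10
    early_stopping_rounds: int = 2
    verbose_eval: bool = False

    def train_kwargs(self) -> dict:
        """Keyword arguments for xgboost.train(), other than params and dtrain."""
        return {
            "num_boost_round": self.num_rounds,
            "early_stopping_rounds": self.early_stopping_rounds,
            "verbose_eval": self.verbose_eval,
        }


# Usage: only the Booster parameters are required; the rest use the defaults.
p = XGBoostParamsSketch({"max_depth": 8, "eta": 0.1, "objective": "binary:logistic"})
print(p.train_kwargs())
# {'num_boost_round': 10, 'early_stopping_rounds': 2, 'verbose_eval': False}
```

Keeping the Booster parameters in one dict and the training-loop settings in separate fields mirrors the split in xgboost.train() itself, where params configures the model and the keyword arguments configure the boosting loop.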