# Source code for nvflare.app_opt.xgboost.histogram_based_v2.executor

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.adaptor_executor import XGBExecutor
from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor
from nvflare.app_opt.xgboost.histogram_based_v2.runners.client_runner import XGBClientRunner
from nvflare.fuel.utils.validation_utils import (
    check_non_negative_int,
    check_object_type,
    check_positive_number,
    check_str,
)


class FedXGBHistogramExecutor(XGBExecutor):
    """Executor for histogram-based federated XGBoost training.

    The constructor validates and stores the training settings. ``get_adaptor``
    later turns those settings into an ``XGBClientRunner`` that is driven
    through a ``GrpcClientAdaptor``.
    """

    def __init__(
        self,
        early_stopping_rounds,
        xgb_params: dict,
        data_loader_id: str,
        verbose_eval=False,
        use_gpus=False,
        per_msg_timeout=10.0,
        tx_timeout=100.0,
        model_file_name="model.json",
        metrics_writer_id: str = None,
        in_process: bool = True,
    ):
        """Validate and record the executor configuration.

        Args:
            early_stopping_rounds: early stopping rounds. Validated as a
                non-negative int unless it is None.
            xgb_params: This dict is passed to `xgboost.train()` as the first
                argument `params`. It contains all the Booster parameters.
                Please refer to XGBoost documentation for details:
                https://xgboost.readthedocs.io/en/stable/parameter.html
                Validated as a dict unless it is None.
            data_loader_id: the ID points to XGBDataLoader.
            verbose_eval: verbose_eval in xgboost.train
            use_gpus (bool): A convenient flag to enable gpu training, if gpu
                device is specified in the `xgb_params` then this flag can be
                ignored.
            per_msg_timeout: timeout for sending one message. Must be a
                positive number.
            tx_timeout: transaction timeout. Validated as a positive number
                only when truthy; a falsy value (e.g. None or 0) skips the
                check and is stored as given.
            model_file_name (str): where to save the model.
            metrics_writer_id: the ID points to a LogWriter, if provided, a
                MetricsCallback will be added. Users can then use the receivers
                from nvflare.app_opt.tracking. Validated as a str only when
                truthy.
            in_process (bool): Specifies whether to start the `XGBRunner` in
                the same process or not.

        Raises:
            Whatever the ``check_*`` validation helpers raise for an invalid
            argument (presumably TypeError/ValueError - confirm against
            nvflare.fuel.utils.validation_utils).
        """
        # The adaptor is built on demand by get_adaptor(), so no adaptor
        # component ID is handed to the base class.
        XGBExecutor.__init__(
            self,
            adaptor_component_id="",
        )

        # Optional arguments are validated only when they carry a value.
        if early_stopping_rounds is not None:
            check_non_negative_int("early_stopping_rounds", early_stopping_rounds)

        if xgb_params is not None:
            check_object_type("xgb_params", xgb_params, dict)

        check_str("data_loader_id", data_loader_id)
        check_positive_number("per_msg_timeout", per_msg_timeout)

        # A falsy tx_timeout is accepted as-is and forwarded to the adaptor.
        if tx_timeout:
            check_positive_number("tx_timeout", tx_timeout)

        check_str("model_file_name", model_file_name)

        if metrics_writer_id:
            check_str("metrics_writer_id", metrics_writer_id)

        self.early_stopping_rounds = early_stopping_rounds
        self.xgb_params = xgb_params
        self.data_loader_id = data_loader_id
        self.verbose_eval = verbose_eval
        self.use_gpus = use_gpus
        self.per_msg_timeout = per_msg_timeout
        self.tx_timeout = tx_timeout
        self.model_file_name = model_file_name
        self.in_process = in_process
        self.metrics_writer_id = metrics_writer_id

        # do not let user specify int_server_grpc_options in this version - always use default
        self.int_server_grpc_options = None

    def get_adaptor(self, fl_ctx: FLContext):
        """Build the gRPC client adaptor together with its XGBoost runner.

        Args:
            fl_ctx: FL context, passed to the runner's ``initialize`` call.

        Returns:
            A ``GrpcClientAdaptor`` configured with this executor's timeout and
            process settings, with an initialized ``XGBClientRunner`` attached.
        """
        runner = XGBClientRunner(
            data_loader_id=self.data_loader_id,
            early_stopping_rounds=self.early_stopping_rounds,
            xgb_params=self.xgb_params,
            verbose_eval=self.verbose_eval,
            use_gpus=self.use_gpus,
            model_file_name=self.model_file_name,
            metrics_writer_id=self.metrics_writer_id,
        )
        # The runner is initialized before it is attached to the adaptor.
        runner.initialize(fl_ctx)

        adaptor = GrpcClientAdaptor(
            int_server_grpc_options=self.int_server_grpc_options,
            in_process=self.in_process,
            per_msg_timeout=self.per_msg_timeout,
            tx_timeout=self.tx_timeout,
        )
        adaptor.set_runner(runner)
        return adaptor