Source code for nvflare.app_opt.xgboost.tree_based.executor

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import xgboost as xgb

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
from nvflare.app_opt.xgboost.tree_based.utils import update_model
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.security.logging import secure_format_exception


class FedXGBTreeExecutor(Executor):
    def __init__(
        self,
        training_mode,
        lr_scale,
        data_loader_id: str,
        num_client_bagging: int = 1,
        lr_mode: str = "uniform",
        local_model_path: str = "model.json",
        global_model_path: str = "model_global.json",
        learning_rate: float = 0.1,
        objective: str = "binary:logistic",
        num_local_parallel_tree: int = 1,
        local_subsample: float = 1,
        max_depth: int = 8,
        eval_metric: str = "auc",
        nthread: int = 16,
        tree_method: str = "hist",
        train_task_name: str = AppConstants.TASK_TRAIN,
        use_gpus=False,
    ):
        super().__init__()
        self.client_id = None
        self.writer = None
        self.training_mode = training_mode
        self.num_client_bagging = num_client_bagging
        self.lr = None
        self.lr_scale = lr_scale
        self.base_lr = learning_rate
        self.lr_mode = lr_mode
        self.num_local_parallel_tree = num_local_parallel_tree
        self.local_subsample = local_subsample
        self.local_model_path = local_model_path
        self.global_model_path = global_model_path
        self.objective = objective
        self.max_depth = max_depth
        self.eval_metric = eval_metric
        self.nthread = nthread
        self.tree_method = tree_method
        self.train_task_name = train_task_name
        self.use_gpus = use_gpus

        self.num_local_round = 1

        self.bst = None
        self.global_model_as_dict = None
        self.config = None
        self.local_model = None

        self.data_loader_id = data_loader_id
        self.train_data = None
        self.val_data = None

        # use dynamic shrinkage - adjusted by personalized scaling factor
        if lr_mode not in ["uniform", "scaled"]:
            raise ValueError(f"Only support [uniform] or [scaled] mode, but got {lr_mode}")

    def initialize(self, fl_ctx: FLContext):
        # set the paths according to fl_ctx
        engine = fl_ctx.get_engine()
        ws = engine.get_workspace()
        app_dir = ws.get_app_dir(fl_ctx.get_job_id())
        self.local_model_path = os.path.join(app_dir, self.local_model_path)
        self.global_model_path = os.path.join(app_dir, self.global_model_path)

        # get and print the args
        fl_args = fl_ctx.get_prop(FLContextKey.ARGS)
        self.client_id = fl_ctx.get_identity_name()
        self.log_info(
            fl_ctx,
            f"Client {self.client_id} initialized with args: \n {fl_args}",
        )
        self.rank = int(self.client_id.split("-")[1]) - 1

        # set local tensorboard writer for local training info of current model
        tensorboard, flag = optional_import(module="torch.utils.tensorboard")
        if flag:
            self.writer = tensorboard.SummaryWriter(app_dir)

        if self.training_mode not in ["cyclic", "bagging"]:
            self.system_panic(f"Only support [cyclic] or [bagging] mode, but got {self.training_mode}", fl_ctx)
            return

        # load data and lr_scale, this is task/site-specific
        data_loader = engine.get_component(self.data_loader_id)
        if not isinstance(data_loader, XGBDataLoader):
            self.system_panic("data_loader should be type XGBDataLoader", fl_ctx)
        data_loader.initialize(
            client_id=self.client_id,
            rank=self.rank,
        )
        try:
            self.train_data, self.val_data = data_loader.load_data()
        except Exception as e:
            self.system_panic(f"load_data failed: {secure_format_exception(e)}", fl_ctx)

        self.lr = self._get_effective_learning_rate()

    def _get_effective_learning_rate(self):
        if self.training_mode == "bagging":
            # Bagging mode
            if self.lr_mode == "uniform":
                # uniform lr, global learning_rate scaled by num_client_bagging for bagging
                lr = self.base_lr / self.num_client_bagging
            else:
                # scaled lr, global learning_rate scaled by data size percentage
                lr = self.base_lr * self.lr_scale
        else:
            # Cyclic mode, directly use the base learning_rate
            lr = self.base_lr
        return lr

    def _get_xgb_train_params(self):
        params = {
            "objective": self.objective,
            "eta": self.lr,
            "max_depth": self.max_depth,
            "eval_metric": self.eval_metric,
            "nthread": self.nthread,
            "num_parallel_tree": self.num_local_parallel_tree,
            "subsample": self.local_subsample,
            "tree_method": self.tree_method,
        }
        return params

    def _local_boost_bagging(self, fl_ctx: FLContext):
        eval_results = self.bst.eval_set(
            evals=[(self.train_data, "train"), (self.val_data, "valid")], iteration=self.bst.num_boosted_rounds() - 1
        )
        self.log_info(fl_ctx, eval_results)
        auc = float(eval_results.split("\t")[2].split(":")[1])
        for i in range(self.num_local_round):
            self.bst.update(self.train_data, self.bst.num_boosted_rounds())

        # extract newly added self.num_local_round using xgboost slicing api
        bst = self.bst[self.bst.num_boosted_rounds() - self.num_local_round : self.bst.num_boosted_rounds()]

        self.log_info(
            fl_ctx,
            f"Global AUC {auc}",
        )
        if self.writer:
            # note: writing auc before current training step, for passed in global model
            self.writer.add_scalar(
                "AUC", auc, int((self.bst.num_boosted_rounds() - self.num_local_round - 1) / self.num_client_bagging)
            )
        return bst

    def _local_boost_cyclic(self, fl_ctx: FLContext):
        # Cyclic mode
        # starting from global model
        # return the whole boosting tree series
        self.bst.update(self.train_data, self.bst.num_boosted_rounds())
        eval_results = self.bst.eval_set(
            evals=[(self.train_data, "train"), (self.val_data, "valid")], iteration=self.bst.num_boosted_rounds() - 1
        )
        self.log_info(fl_ctx, eval_results)
        auc = float(eval_results.split("\t")[2].split(":")[1])
        self.log_info(
            fl_ctx,
            f"Client {self.client_id} AUC after training: {auc}",
        )
        if self.writer:
            self.writer.add_scalar("AUC", auc, self.bst.num_boosted_rounds() - 1)
        return self.bst

    def train(
        self,
        shareable: Shareable,
        fl_ctx: FLContext,
        abort_signal: Signal,
    ) -> Shareable:
        if abort_signal.triggered:
            self.finalize(fl_ctx)
            return make_reply(ReturnCode.TASK_ABORTED)

        # retrieve current global model download from server's shareable
        dxo = from_shareable(shareable)
        model_update = dxo.data

        # xgboost parameters
        params = self._get_xgb_train_params()

        if self.use_gpus:
            # mapping each rank to a GPU (can set to cuda:0 if simulating with only one gpu)
            self.log_info(fl_ctx, f"Training with GPU {self.rank}")
            params["device"] = f"cuda:{self.rank}"

        if not self.bst:
            # First round
            self.log_info(
                fl_ctx,
                f"Client {self.client_id} initial training from scratch",
            )
            if not model_update:
                bst = xgb.train(
                    params,
                    self.train_data,
                    num_boost_round=self.num_local_round,
                    evals=[(self.val_data, "validate"), (self.train_data, "train")],
                )
            else:
                loadable_model = bytearray(model_update["model_data"])
                bst = xgb.train(
                    params,
                    self.train_data,
                    num_boost_round=self.num_local_round,
                    xgb_model=loadable_model,
                    evals=[(self.val_data, "validate"), (self.train_data, "train")],
                )
            self.config = bst.save_config()
            self.bst = bst
        else:
            self.log_info(
                fl_ctx,
                f"Client {self.client_id} model updates received from server",
            )
            if self.training_mode == "bagging":
                model_updates = model_update["model_data"]
                for update in model_updates:
                    self.global_model_as_dict = update_model(self.global_model_as_dict, json.loads(update))
                loadable_model = bytearray(json.dumps(self.global_model_as_dict), "utf-8")
            else:
                loadable_model = bytearray(model_update["model_data"])

            self.log_info(
                fl_ctx,
                f"Client {self.client_id} converted global model to json ",
            )

            self.bst.load_model(loadable_model)
            self.bst.load_config(self.config)

            self.log_info(
                fl_ctx,
                f"Client {self.client_id} loaded global model into booster ",
            )

            # train local model starting with global model
            if self.training_mode == "bagging":
                bst = self._local_boost_bagging(fl_ctx)
            else:
                bst = self._local_boost_cyclic(fl_ctx)

        self.local_model = bst.save_raw("json")

        # report updated model in shareable
        dxo = DXO(data_kind=DataKind.WEIGHTS, data={"model_data": self.local_model})
        self.log_info(fl_ctx, "Local epochs finished. Returning shareable")
        new_shareable = dxo.to_shareable()

        if self.writer:
            self.writer.flush()

        return new_shareable

    def finalize(self, fl_ctx: FLContext):
        # freeing resources in finalize avoids seg fault during shutdown of gpu mode
        del self.bst
        del self.train_data
        del self.val_data
        self.log_info(fl_ctx, "Freed training resources")

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"Client trainer got task: {task_name}")

        try:
            if task_name == "train":
                return self.train(shareable, fl_ctx, abort_signal)
            else:
                self.log_error(fl_ctx, f"Could not handle task: {task_name}")
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception as e:
            # Task execution error, return EXECUTION_EXCEPTION Shareable
            self.log_exception(fl_ctx, f"execute exception: {secure_format_exception(e)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self.initialize(fl_ctx)
        elif event_type == EventType.END_RUN:
            self.finalize(fl_ctx)
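

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module above). FedXGBTreeExecutor looks
# up the component registered under ``data_loader_id`` and requires it to be an
# XGBDataLoader: judging from the call sites in ``initialize`` above, the loader
# receives ``initialize(client_id=..., rank=...)`` once and must then return a
# (train, validation) pair of xgb.DMatrix objects from ``load_data()``.
# The class name ``SimpleCSVDataLoader``, the CSV paths, and the
# label-in-first-column layout below are hypothetical, not part of NVFlare.
import pandas as pd


class SimpleCSVDataLoader(XGBDataLoader):
    def __init__(self, train_csv: str, valid_csv: str):
        self.train_csv = train_csv
        self.valid_csv = valid_csv
        self.client_id = None
        self.rank = None

    def initialize(self, client_id: str, rank: int):
        # the executor passes the site name and a zero-based rank; keep them
        # in case per-site file paths are needed
        self.client_id = client_id
        self.rank = rank

    def load_data(self):
        # assumed layout: label in the first column, features in the rest
        train_df = pd.read_csv(self.train_csv, header=None)
        valid_df = pd.read_csv(self.valid_csv, header=None)
        train_data = xgb.DMatrix(train_df.iloc[:, 1:], label=train_df.iloc[:, 0])
        val_data = xgb.DMatrix(valid_df.iloc[:, 1:], label=valid_df.iloc[:, 0])
        return train_data, val_data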