# Source code for nvflare.app_opt.xgboost.tree_based.bagging_aggregator

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.app_constant import AppConstants


class XGBBaggingAggregator(Aggregator):
    """Collect per-round XGBoost tree updates from contributors for bagging.

    Each accepted contribution is stored both in its serialized form and as
    the object parsed from it with ``json.loads``. ``aggregate`` hands all
    stored contributions back in a single DXO and resets the per-round state.
    """

    def __init__(self):
        """Perform bagging aggregation for XGBoost trees.

        The trees are pre-weighted during training. Bagging aggregation simply
        adds the new trees to the existing global model.
        """
        super().__init__()
        self.logger.debug(f"expected data kind: {DataKind.WEIGHTS}")
        # One {"contributor_name", "round"} record per accepted contribution.
        self.history = []
        # Serialized model updates, in the order they were accepted.
        self.local_models = []
        # The same updates parsed with json.loads, index-aligned with local_models.
        self.local_models_as_dict = []
        self.global_model = None
        self.expected_data_kind = DataKind.WEIGHTS
        self.num_trees = 0

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Store shareable and update aggregator's internal state.

        A contribution is rejected when the shareable is not a valid DXO, the
        contributor reported a non-OK return code, the data kind is not the
        expected one, the contribution round differs from the current round,
        the contributor already has an accepted contribution, or the payload
        is missing or is not valid JSON.

        Args:
            shareable: information from contributor
            fl_ctx: context provided by workflow

        Returns:
            True if the shareable was accepted and stored, False otherwise.
        """
        try:
            dxo = from_shareable(shareable)
        except Exception:
            self.log_exception(fl_ctx, "shareable data is not a valid DXO")
            return False

        contributor_name = shareable.get_peer_prop(key=ReservedKey.IDENTITY_NAME, default="?")
        contribution_round = shareable.get_cookie(AppConstants.CONTRIBUTION_ROUND)

        rc = shareable.get_return_code()
        if rc and rc != ReturnCode.OK:
            self.log_warning(fl_ctx, f"Contributor {contributor_name} returned rc: {rc}. Disregarding contribution.")
            return False

        if dxo.data_kind != self.expected_data_kind:
            self.log_error(fl_ctx, "expected {} but got {}".format(self.expected_data_kind, dxo.data_kind))
            return False

        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.log_debug(fl_ctx, f"current_round: {current_round}")
        if contribution_round != current_round:
            self.log_warning(
                fl_ctx,
                f"discarding DXO from {contributor_name} at round: "
                f"{contribution_round}. Current round is: {current_round}",
            )
            return False

        # Only one contribution per contributor is kept until the next aggregate().
        for item in self.history:
            if contributor_name == item["contributor_name"]:
                prev_round = item["round"]
                self.log_warning(
                    fl_ctx,
                    f"discarding DXO from {contributor_name} at round: "
                    f"{contribution_round} as {prev_round} accepted already",
                )
                return False

        data = dxo.data
        if data is None:
            self.log_error(fl_ctx, "no data to aggregate")
            return False

        # Validate and parse the payload BEFORE touching any state, so that a
        # malformed contribution cannot leave local_models, local_models_as_dict
        # and history out of step with each other.
        try:
            model_data = data["model_data"]
            model_data_as_dict = json.loads(model_data)
        except (KeyError, TypeError, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            self.log_exception(
                fl_ctx,
                f"invalid model data from {contributor_name}. Disregarding contribution.",
            )
            return False

        self.local_models.append(model_data)
        self.local_models_as_dict.append(model_data_as_dict)
        self.history.append(
            {
                "contributor_name": contributor_name,
                "round": contribution_round,
            }
        )
        return True

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Called when workflow determines to generate shareable to send back to contributors.

        No numerical combination is performed here: the stored contributions
        are returned as they were accepted, and the per-round state (history
        and both model lists) is cleared.

        Args:
            fl_ctx (FLContext): context provided by workflow

        Returns:
            Shareable: a DXO of the expected data kind whose data holds the
            list of serialized updates under ``"model_data"`` and the list of
            parsed updates under ``"model_data_dict"``.
        """
        self.log_debug(fl_ctx, "Start aggregation")
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        site_num = len(self.history)
        self.log_info(fl_ctx, f"aggregating {site_num} update(s) at round {current_round}")

        # Take ownership of the collected updates and reset state for the next round.
        local_updates = self.local_models
        local_updates_as_dict = self.local_models_as_dict
        self.history = []
        self.local_models = []
        self.local_models_as_dict = []

        dxo = DXO(
            data_kind=self.expected_data_kind,
            data={"model_data": local_updates, "model_data_dict": local_updates_as_dict},
        )
        self.log_debug(fl_ctx, "End aggregation")
        return dxo.to_shareable()