# Source code for nvflare.app_opt.xgboost.histogram_based_v2.sec.client_handler

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

import xgboost
from packaging import version

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.adder import Adder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.decrypter import Decrypter
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.encryptor import Encryptor
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.util import (
    combine,
    decode_encrypted_data,
    decode_feature_aggregations,
    encode_encrypted_data,
    encode_feature_aggregations,
    generate_keys,
    ipcl_imported,
    split,
)
from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import (
    DATA_SET_HISTOGRAMS,
    ProcessorDataConverter,
)
from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler

try:
    import tenseal as ts
    from tenseal.tensors.ckksvector import CKKSVector

    from nvflare.app_opt.he import decomposers
    from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace

    tenseal_imported = True
    tenseal_error = None
except Exception as ex:
    tenseal_imported = False
    tenseal_error = f"Import error: {ex}"

XGBOOST_MIN_VERSION = "2.2.0-dev"


class ClientSecurityHandler(SecurityHandler):
    """Client-side security handler for secure (encrypted) XGBoost histogram-based training.

    Supports two modes:
      - Vertical (column-split) secure training using partial homomorphic encryption
        (Paillier via ipcl) on (g, h) gradient pairs.
      - Horizontal (row-split) secure training using CKKS (TenSEAL) on histograms.

    The handler intercepts XGB broadcast and AllGatherV operations (before/after hooks)
    and replaces clear-text payloads with encrypted ones, feeding dummy buffers back to
    the local XGBoost process where it only expects correctly-sized data.
    """

    def __init__(self, key_length=1024, num_workers=10, tenseal_context_file="client_context.tenseal"):
        """Constructor.

        Args:
            key_length: length of the Paillier key used for vertical secure training.
            num_workers: number of parallel workers for encrypt/decrypt/add operations.
            tenseal_context_file: workspace file name of the TenSEAL context
                (used only for horizontal secure training).
        """
        FLComponent.__init__(self)
        self.num_workers = num_workers
        self.key_length = key_length
        self.public_key = None
        self.private_key = None
        self.encryptor = None
        self.adder = None
        self.decrypter = None
        self.data_converter = ProcessorDataConverter()
        self.encrypted_ghs = None
        self.clear_ghs = None  # for label client: list of tuples (g, h)
        self.original_gh_buffer = None
        self.feature_masks = None
        self.aggregator = Aggregator()
        self.aggr_result = None  # for label client: computed aggr result based on clear-text clear_ghs
        self.tenseal_context_file = tenseal_context_file
        self.tenseal_context = None
        if tenseal_imported:
            decomposers.register()

    def _process_before_broadcast(self, fl_ctx: FLContext):
        """On the label client (broadcast root), encrypt clear-text (g, h) pairs before sending."""
        root = fl_ctx.get_prop(Constant.PARAM_KEY_ROOT)
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        self.info(fl_ctx, "start")
        if root != rank:
            # I am not the source of the broadcast
            self.info(fl_ctx, "not root - ignore")
            return

        buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        clear_ghs = self.data_converter.decode_gh_pairs(buffer, fl_ctx)
        if clear_ghs is None:
            # the buffer does not contain (g, h) pairs
            self.info(fl_ctx, "no clear gh pairs - ignore")
            return

        if self.encryptor is None:
            return self._abort("Encryptor is not created due to missing packages", fl_ctx)

        self.info(fl_ctx, f"got gh {len(clear_ghs)} pairs; original buf len: {len(buffer)}")
        self.original_gh_buffer = buffer

        # encrypt clear-text gh pairs and send to server
        self.clear_ghs = [combine(clear_ghs[i][0], clear_ghs[i][1]) for i in range(len(clear_ghs))]
        t = time.time()
        encrypted_values = self.encryptor.encrypt(self.clear_ghs)
        self.info(fl_ctx, f"encrypted gh pairs: {len(encrypted_values)}, took {time.time() - t} secs")

        t = time.time()
        encoded = encode_encrypted_data(self.public_key, encrypted_values)
        self.info(fl_ctx, f"encoded msg: size={len(encoded)}, type={type(encoded)} time={time.time() - t} secs")

        # Remember the original buffer size, so we could send a dummy buffer of this size to other clients
        # This is important since all XGB clients already prepared a buffer of this size and expect the data
        # to be the same size.
        headers = {Constant.HEADER_KEY_ENCRYPTED_DATA: True, Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer)}
        fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=encoded, private=True, sticky=False)
        fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)

    def _process_after_broadcast(self, fl_ctx: FLContext):
        """Handle the broadcast result: label client gets its original buffer back;
        non-label clients decode the encrypted (g, h) values and feed XGB a dummy buffer."""
        # this is called when the bcst result is received from the server
        self.info(fl_ctx, "start")
        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        assert isinstance(reply, Shareable)
        has_encrypted_data = reply.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA)
        if not has_encrypted_data:
            self.info(fl_ctx, f"{has_encrypted_data=}")
            return

        if self.clear_ghs:
            # this is the root rank
            # TBD: assume MPI requires the original buffer to be sent back to it.
            fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.original_gh_buffer, private=True, sticky=False)
            self.info(fl_ctx, "has_encrypted_data: label client - send original buffer back to XGB")
            return

        # this is a receiving non-label client
        # the rcv_buf contains encrypted gh values
        encoded_gh_str = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF)
        self.info(fl_ctx, f"{len(encoded_gh_str)=} {type(encoded_gh_str)=}")
        self.public_key, self.encrypted_ghs = decode_encrypted_data(encoded_gh_str)

        original_buf_size = reply.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE)
        self.info(fl_ctx, f"{original_buf_size=}; encrypted gh pairs: {len(self.encrypted_ghs)}")

        # send a dummy buffer of original size to the XGB client since it is expecting data to be this size
        dummy_buf = os.urandom(original_buf_size)
        fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy_buf, private=True, sticky=False)

    def _process_before_all_gather_v(self, fl_ctx: FLContext):
        """Dispatch the pre-AllGatherV hook to horizontal or vertical processing
        based on the data-set id in the DAM-encoded buffer."""
        self.info(fl_ctx, "start")
        buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            self.info(fl_ctx, "Not secure content - ignore")
            return

        if decoder.get_data_set_id() == DATA_SET_HISTOGRAMS:
            self._process_before_all_gather_v_horizontal(fl_ctx)
        else:
            self._process_before_all_gather_v_vertical(fl_ctx)

    def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext):
        """Vertical mode: non-label clients aggregate encrypted (g, h) values per feature/group;
        the label client aggregates in clear text and sends nothing."""
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        aggr_ctx = self.data_converter.decode_aggregation_context(buffer, fl_ctx)
        if not aggr_ctx:
            # this AllGatherV is irrelevant to secure processing
            self.info(fl_ctx, "no aggr ctx - ignore")
            return

        if not self.feature_masks:
            # the feature contexts only need to be set once
            if not aggr_ctx.features and not self.clear_ghs:
                return self._abort("missing features in aggregation context from non-label client", fl_ctx)
            m = []
            if aggr_ctx.features:
                for f in aggr_ctx.features:
                    m.append((f.feature_id, f.sample_bin_assignment, f.num_bins))
            self.feature_masks = m
            self.info(fl_ctx, f"got feature ctx: {len(m)}")

        # compute aggregation
        groups = []
        if aggr_ctx.sample_groups:
            for gid, sample_ids in aggr_ctx.sample_groups.items():
                groups.append((gid, sample_ids))

        if not self.encrypted_ghs:
            if not self.clear_ghs:
                # this is non-label client
                return self._abort(f"no encrypted (g, h) values for aggregation in rank {rank}", fl_ctx)
            else:
                # label client - send a dummy of 4 bytes
                self.info(fl_ctx, "label client: _do_aggregation in clear text")
                self._do_aggregation(groups, fl_ctx)
                headers = {Constant.HEADER_KEY_ENCRYPTED_DATA: True, Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer)}
                fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)
                fl_ctx.set_prop(
                    key=Constant.PARAM_KEY_SEND_BUF,
                    value=None,
                    private=True,
                    sticky=False,
                )
            return

        self.info(
            fl_ctx, f"_process_before_all_gather_v: non-label client - do encrypted aggr for {len(groups)} groups"
        )
        start = time.time()
        aggr_result = self.adder.add(self.encrypted_ghs, self.feature_masks, groups, encode_sum=True)
        self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features in {time.time() - start} secs")
        start = time.time()
        encoded_str = encode_feature_aggregations(aggr_result)
        self.info(fl_ctx, f"encoded aggr result len {len(encoded_str)} in {time.time() - start} secs")
        headers = {Constant.HEADER_KEY_ENCRYPTED_DATA: True, Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer)}
        fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=encoded_str, private=True, sticky=False)
        fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)

    def _process_before_all_gather_v_horizontal(self, fl_ctx: FLContext):
        """Horizontal mode: encrypt local histograms into a CKKS vector before sending."""
        if not self.tenseal_context:
            return self._abort(
                "Horizontal secure XGBoost not supported due to missing context or missing module", fl_ctx
            )

        buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        histograms = self.data_converter.decode_histograms(buffer, fl_ctx)

        start = time.time()
        vector = ts.ckks_vector(self.tenseal_context, histograms)
        self.info(
            fl_ctx,
            f"_process_before_all_gather_v: Histograms with {len(histograms)} entries "
            f"encrypted in {time.time() - start} secs",
        )
        headers = {
            Constant.HEADER_KEY_ENCRYPTED_DATA: True,
            Constant.HEADER_KEY_HORIZONTAL: True,
            Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer),
        }
        fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=vector, private=True, sticky=False)
        fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)

    def _do_aggregation(self, groups, fl_ctx: FLContext):
        """Compute clear-text (g, h) aggregation per feature (and per sample group, if any).

        Result is stored in self.aggr_result as a list of (fid, gid, GH_list) tuples.
        """
        # this is only for the label-client to compute aggregation in clear-text!
        if not self.feature_masks:
            return

        t = time.time()
        aggr_result = []  # list of (fid, gid, GH_list)
        for fm in self.feature_masks:
            fid, masks, num_bins = fm
            if not groups:
                gid = 0
                gh_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, None)
                aggr_result.append((fid, gid, gh_list))
            else:
                for grp in groups:
                    gid, sample_ids = grp
                    gh_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, sample_ids)
                    aggr_result.append((fid, gid, gh_list))
        self.info(fl_ctx, f"aggregated clear-text in {time.time() - t} secs")
        self.aggr_result = aggr_result

    def _decrypt_aggr_result(self, encoded, fl_ctx: FLContext):
        """Decrypt one rank's encoded aggregation result into (fid, gid, clear_GH_list) tuples.

        Non-string input is passed through unchanged (dummy result of the label client).
        """
        # decrypt aggr result from a client
        if not isinstance(encoded, str):
            # this is dummy result of the label-client
            return encoded

        encoded_str = encoded
        t = time.time()
        decoded_aggrs = decode_feature_aggregations(self.public_key, encoded_str)
        self.info(fl_ctx, f"decode_feature_aggregations took {time.time() - t} secs")

        t = time.time()
        aggrs_to_decrypt = [decoded_aggrs[i][2] for i in range(len(decoded_aggrs))]
        decrypted_aggrs = self.decrypter.decrypt(aggrs_to_decrypt)  # this is a list of clear-text GH numbers
        self.info(fl_ctx, f"decrypted {len(aggrs_to_decrypt)} numbers in {time.time() - t} secs")

        aggr_result = []
        for i in range(len(decoded_aggrs)):
            fid, gid, _ = decoded_aggrs[i]
            clear_aggr = decrypted_aggrs[i]  # list of combined clear-text ints
            aggr_result.append((fid, gid, clear_aggr))
        return aggr_result

    def _process_after_all_gather_v(self, fl_ctx: FLContext):
        """Dispatch the post-AllGatherV hook to horizontal or vertical processing."""
        # called after AllGatherV result is received from the server
        self.info(fl_ctx, "start")
        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        assert isinstance(reply, Shareable)
        encrypted_data = reply.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA)
        if not encrypted_data:
            self.info(fl_ctx, "no encrypted result - ignore")
            return

        horizontal = reply.get_header(Constant.HEADER_KEY_HORIZONTAL)
        if horizontal:
            self._process_after_all_gather_v_horizontal(fl_ctx)
        else:
            self._process_after_all_gather_v_vertical(fl_ctx)

    def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext):
        """Vertical mode: label client decrypts all ranks' aggregation results, combines them
        into the final histogram buffer; non-label clients receive a dummy buffer."""
        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        size_dict = reply.get_header(Constant.HEADER_KEY_SIZE_DICT)
        total_size = sum(size_dict.values())
        self.info(fl_ctx, f"{total_size=} {size_dict=}")
        rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF)

        # this rcv_buf is a list of replies from ALL clients!
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        if not isinstance(rcv_buf, dict):
            return self._abort(f"rank {rank}: expect a dict of aggr result but got {type(rcv_buf)}", fl_ctx)
        rank_replies = rcv_buf
        self.info(fl_ctx, f"received rank replies: {len(rank_replies)}")

        if not self.clear_ghs:
            # this is non-label client - don't care about the results
            dummy = os.urandom(total_size)
            fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy, private=True, sticky=False)
            self.info(fl_ctx, "non-label client: return dummy buffer back to XGB")
            return

        # this is label client: rank_replies contain encrypted aggr result!
        for r, rr in rank_replies.items():
            if r != rank:
                # this is aggr result of a non-label client
                rank_replies[r] = self._decrypt_aggr_result(rr, fl_ctx)

        # add label client's result
        rank_replies[rank] = self.aggr_result

        combined_result = {}  # gid => dict[fid=>GH_list]
        for r, rr in rank_replies.items():
            # rr is a list of tuples: fid, gid, GHList
            if not rr:
                # label client may not have any features.
                continue

            for a in rr:
                fid, gid, combined_numbers = a
                gh_list = []
                for n in combined_numbers:
                    gh_list.append(split(n))
                grp_result = combined_result.get(gid)
                if not grp_result:
                    grp_result = {}
                    combined_result[gid] = grp_result
                grp_result[fid] = FeatureAggregationResult(fid, gh_list)
                self.info(fl_ctx, f"aggr from rank {r}: {fid=} {gid=} bins={len(gh_list)}")

        final_result = {}
        for gid, far in combined_result.items():
            sorted_far = sorted(far.items())
            # r is a tuple of (fid, FeatureAggregationResult)
            final_result[gid] = [r[1] for r in sorted_far]
            fid_list = [x.feature_id for x in final_result[gid]]
            self.info(fl_ctx, f"final aggr: {gid=} features={fid_list}")

        result = self.data_converter.encode_aggregation_result(final_result, fl_ctx)

        # XGBoost expects every work has a set of histograms. They are already combined here so
        # just add zeros
        zero_result = final_result
        for result_list in zero_result.values():
            for item in result_list:
                size = len(item.aggregated_hist)
                item.aggregated_hist = [(0, 0)] * size
        zero_buf = self.data_converter.encode_aggregation_result(zero_result, fl_ctx)
        world_size = len(size_dict)
        for _ in range(world_size - 1):
            result += zero_buf

        # XGBoost checks that the size of allgatherv is not changed
        padding_size = total_size - len(result)
        if padding_size > 0:
            result += b"\x00" * padding_size
        elif padding_size < 0:
            self.error(fl_ctx, f"The original size {total_size} is not big enough for data size {len(result)}")

        fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False)

    def _process_after_all_gather_v_horizontal(self, fl_ctx: FLContext):
        """Horizontal mode: decrypt the aggregated CKKS vector into clear histograms and
        pad with zero histograms for the other workers."""
        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        world_size = reply.get_header(Constant.HEADER_KEY_WORLD_SIZE)
        encrypted_histograms = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF)
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        if not isinstance(encrypted_histograms, CKKSVector):
            return self._abort(f"rank {rank}: expect a CKKSVector but got {type(encrypted_histograms)}", fl_ctx)

        histograms = encrypted_histograms.decrypt(secret_key=self.tenseal_context.secret_key())
        result = self.data_converter.encode_histograms_result(histograms, fl_ctx)

        # XGBoost expect every worker returns a histogram, all zeros are returned for other workers
        zeros = [0.0] * len(histograms)
        zero_buf = self.data_converter.encode_histograms_result(zeros, fl_ctx)
        for _ in range(world_size - 1):
            result += zero_buf
        fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False)

    def _check_xgboost_version(self, disable_version_check: bool) -> bool:
        """Check XGBoost version. Returns true if it supports secure training"""
        if disable_version_check:
            self.logger.info("XGBoost version check is disabled")
            return True

        try:
            min_version = version.parse(XGBOOST_MIN_VERSION)
            current_version = version.parse(xgboost.__version__)
            if current_version < min_version:
                self.logger.error(f"XGBoost version {xgboost.__version__} doesn't support secure training")
                return False
            else:
                return True
        except Exception as error:
            self.logger.error(f"Unknown XGBoost version {xgboost.__version__}. Error: {error}")
            return False

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Set up encryption machinery when the XGB job is configured.

        Vertical (COL) split uses Paillier keys/encryptor/decrypter/adder; horizontal (ROW)
        split loads the TenSEAL context. Config errors are reported via PARAM_KEY_CONFIG_ERROR.
        """
        global tenseal_error
        if event_type == Constant.EVENT_XGB_JOB_CONFIGURED:
            task_data = fl_ctx.get_prop(FLContextKey.TASK_DATA)
            data_split_mode = task_data.get(Constant.CONF_KEY_DATA_SPLIT_MODE)
            secure_training = task_data.get(Constant.CONF_KEY_SECURE_TRAINING)
            disable_version_check = task_data.get(Constant.CONF_KEY_DISABLE_VERSION_CHECK)
            if secure_training and not self._check_xgboost_version(disable_version_check):
                fl_ctx.set_prop(
                    Constant.PARAM_KEY_CONFIG_ERROR,
                    f"XGBoost version {xgboost.__version__} doesn't support secure training",
                    private=True,
                    sticky=False,
                )
                return

            if secure_training and data_split_mode == xgboost.core.DataSplitMode.COL and ipcl_imported:
                self.public_key, self.private_key = generate_keys(self.key_length)
                self.encryptor = Encryptor(self.public_key, self.num_workers)
                self.decrypter = Decrypter(self.private_key, self.num_workers)
                self.adder = Adder(self.num_workers)
            elif secure_training and data_split_mode == xgboost.core.DataSplitMode.ROW:
                if not tenseal_imported:
                    fl_ctx.set_prop(Constant.PARAM_KEY_CONFIG_ERROR, tenseal_error, private=True, sticky=False)
                    return
                try:
                    self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx)
                except Exception as err:
                    tenseal_error = f"Can't load tenseal context: {err}"
                    self.tenseal_context = None
                    fl_ctx.set_prop(Constant.PARAM_KEY_CONFIG_ERROR, tenseal_error, private=True, sticky=False)
        elif event_type == EventType.END_RUN:
            self.tenseal_context = None
        else:
            super().handle_event(event_type, fl_ctx)