# Source code for nvflare.app_opt.xgboost.histogram_based_v2.sec.server_handler

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler

try:
    from nvflare.app_opt.he import decomposers

    tenseal_imported = True
except Exception:
    tenseal_imported = False


class ServerSecurityHandler(SecurityHandler):
    """Server-side security handler for secure XGBoost (histogram_based_v2).

    It hooks around the Broadcast and AllGatherV collectives relayed through the server.

    * Broadcast: when a request carries encrypted "gh" data (``HEADER_KEY_ENCRYPTED_DATA``),
      the encrypted payload is cached here and only a small random dummy buffer is forwarded
      to the XGB server. After the broadcast completes, the cached payload is returned to
      every rank except the source (label) rank.
    * AllGatherV: encrypted per-rank contributions are collected here, and a dummy buffer is
      forwarded instead. After the gather completes, horizontal mode returns the sum of all
      contributions and vertical mode returns the dict of per-rank contributions.

    The ``_process_*`` methods are presumably dispatched by the base ``SecurityHandler``
    (not visible here; confirm).
    """

    def __init__(self):
        FLComponent.__init__(self)

        # State of the most recent encrypted gh broadcast.
        self.encrypted_gh = None
        self.gh_source_rank = 0
        self.gh_seq = 0
        self.gh_original_buf_size = 0

        # State of the current AllGatherV aggregation. It is written by the handlers of
        # different ranks, hence it is guarded by aggr_result_lock.
        self.aggr_seq = 0
        self.aggr_result_dict = None
        self.aggr_result_to_send = None
        self.aggr_result_lock = threading.Lock()
        self.world_size = 0
        self.size_dict = None

        if tenseal_imported:
            # The optional nvflare.app_opt.he import succeeded: register its decomposers.
            decomposers.register()

    @staticmethod
    def _buf_length(buf, horizontal: bool) -> int:
        """Return the size of an aggregation buffer (used for logging).

        Horizontal buffers expose a ``size()`` method (presumably CKKS vectors; confirm).
        All other buffers are measured with ``len()``.
        """
        return buf.size() if horizontal else len(buf)

    def _process_before_broadcast(self, fl_ctx: FLContext):
        """Intercept a Broadcast request before it is forwarded to the XGB server.

        If the request carries encrypted gh data, cache the following:

        * the encrypted send buffer,
        * its source rank,
        * its sequence number,
        * its original buffer size.

        Then replace the send buffer with a small random dummy, so the real payload
        never reaches the XGB server. Requests without encrypted data are left untouched.
        """
        self.info(fl_ctx, "start")
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        seq = fl_ctx.get_prop(Constant.PARAM_KEY_SEQ)
        request = fl_ctx.get_prop(Constant.PARAM_KEY_REQUEST)
        assert isinstance(request, Shareable)
        has_encrypted_gh = request.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA)
        self.info(fl_ctx, f"{has_encrypted_gh=}")
        if not has_encrypted_gh:
            self.info(fl_ctx, "not for gh broadcast - ignore")
            return

        self.encrypted_gh = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        self.gh_source_rank = rank
        self.gh_seq = seq
        self.gh_original_buf_size = request.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE)
        self.info(fl_ctx, f"got gh bcst: encrypted_gh={len(self.encrypted_gh)} orig_buf={self.gh_original_buf_size}")

        # only need to send a small dummy buffer to the server
        dummy_buf = os.urandom(Constant.DUMMY_BUFFER_SIZE)
        fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=dummy_buf, private=True, sticky=False)

    def _process_after_broadcast(self, fl_ctx: FLContext):
        """Fill in a Broadcast reply after the server has received the calls from all ranks.

        This applies only to the sequence of the cached gh broadcast; other broadcasts are ignored.
        For that sequence, the reply is marked as encrypted and carries the original buffer size.
        The source rank receives None, since it already holds the data. Every other rank
        receives the cached encrypted gh.
        """
        # this is called after the Server already received broadcast calls from all clients of the same sequence
        self.info(fl_ctx, "start")
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        seq = fl_ctx.get_prop(Constant.PARAM_KEY_SEQ)
        # gh_seq starts at 0. Without the encrypted_gh check, a non-gh broadcast with seq 0
        # arriving before any gh broadcast would be mistaken for one, and would then fail
        # on len(None) below.
        if self.encrypted_gh is None or seq != self.gh_seq:
            # this is not a gh broadcast
            self.info(fl_ctx, "not gh bcast - ignore")
            return

        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        assert isinstance(reply, Shareable)
        reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True)
        reply.set_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE, self.gh_original_buf_size)
        if rank == self.gh_source_rank:
            # no need to send any data back to label client
            self.info(fl_ctx, f"return dummy to gh source {rank}")
            fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=None, private=True, sticky=False)
            return

        # send encrypted ghs
        self.info(fl_ctx, f"return {len(self.encrypted_gh)=} to non-label {rank}")
        fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.encrypted_gh, private=True, sticky=False)

    def _process_before_all_gather_v(self, fl_ctx: FLContext):
        """Intercept an AllGatherV request before it is forwarded to the XGB server.

        Requests without ``HEADER_KEY_ENCRYPTED_DATA`` are ignored. For all other requests:

        * the context is flagged ``in_aggr`` (together with the horizontal flag) for
          ``_process_after_all_gather_v``,
        * a non-empty send buffer from this rank is stored in ``aggr_result_dict``,
        * the rank's original buffer size is recorded in ``size_dict``,
        * the send buffer is replaced by a random dummy.
        """
        request = fl_ctx.get_prop(Constant.PARAM_KEY_REQUEST)
        assert isinstance(request, Shareable)
        has_encrypted_data = request.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA)
        self.info(fl_ctx, f"{has_encrypted_data=}")
        if not has_encrypted_data:
            self.info(fl_ctx, "start - non-secure data")
            return

        horizontal = request.get_header(Constant.HEADER_KEY_HORIZONTAL)
        training_mode = "horizontal" if horizontal else "vertical"
        self.info(fl_ctx, f"start - {training_mode}")

        fl_ctx.set_prop(key=Constant.HEADER_KEY_IN_AGGR, value=True, private=True, sticky=False)
        fl_ctx.set_prop(key=Constant.HEADER_KEY_HORIZONTAL, value=horizontal, private=True, sticky=False)

        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
        original_buf_size = request.get_header(Constant.HEADER_KEY_ORIGINAL_BUF_SIZE)

        if send_buf:
            # the send_buf contains encoded aggr result (str) or CKKS vector from this rank
            self.info(fl_ctx, f"got encrypted aggr data: {self._buf_length(send_buf, horizontal)} bytes")
        else:
            self.info(fl_ctx, f"no aggr data from {rank=}")

        # All shared aggregation state, including size_dict, is updated under the lock.
        # Otherwise two ranks could both see size_dict as None, each create a new dict,
        # and one rank's entry would be lost.
        with self.aggr_result_lock:
            if send_buf:
                # a new contribution invalidates any result cached from the previous gather
                self.aggr_result_to_send = None
                if not self.aggr_result_dict:
                    self.aggr_result_dict = {}
                self.aggr_result_dict[rank] = send_buf

            if self.size_dict is None:
                self.size_dict = {}
            self.size_dict[rank] = original_buf_size

        # only send a dummy to the Server
        fl_ctx.set_prop(
            key=Constant.PARAM_KEY_SEND_BUF, value=os.urandom(Constant.DUMMY_BUFFER_SIZE), private=True, sticky=False
        )
        self.info(fl_ctx, "send dummy buf to XGB server")

    def _process_after_all_gather_v(self, fl_ctx: FLContext):
        """Fill in an AllGatherV reply after the server has finished gathering.

        The first rank to get here builds the result from ``aggr_result_dict``:

        * horizontal mode: the sum of all contributions,
        * vertical mode: the per-rank dict.

        That rank then clears ``aggr_result_dict``. Later ranks of the same gather reuse
        the cached result. The run aborts if no contributions were collected.
        """
        # this is called after the Server has finished gathering
        # Note: this fl_ctx is the same as the one in _process_before_all_gather_v!
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        in_aggr = fl_ctx.get_prop(Constant.HEADER_KEY_IN_AGGR)
        self.info(fl_ctx, f"start {in_aggr=}")
        if not in_aggr:
            self.info(fl_ctx, "not in_aggr - ignore")
            return

        reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
        assert isinstance(reply, Shareable)
        horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL)
        reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True)
        reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal)

        with self.aggr_result_lock:
            if self.aggr_result_to_send is None:
                if not self.aggr_result_dict:
                    return self._abort(f"Rank {rank}: no aggr result after AllGatherV!", fl_ctx)

                if horizontal:
                    self.aggr_result_to_send = self._histogram_sum(fl_ctx)
                else:
                    self.aggr_result_to_send = self.aggr_result_dict

                # reset aggr_result_dict for next gather
                self.aggr_result_dict = None

            result = self.aggr_result_to_send
            # Snapshot under the lock, so that the reply is not affected by the next
            # gather mutating size_dict.
            size_dict = dict(self.size_dict or {})
            self.world_size = world_size = len(size_dict)

        reply.set_header(Constant.HEADER_KEY_WORLD_SIZE, world_size)
        reply.set_header(Constant.HEADER_KEY_SIZE_DICT, size_dict)

        self.info(fl_ctx, f"aggr_result_to_send {self._buf_length(result, horizontal)}")
        fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False)

    def _histogram_sum(self, fl_ctx: FLContext):
        """Return the ``+``-sum of all per-rank contributions in ``aggr_result_dict``.

        Returns None when there are no contributions. The caller is expected to hold
        ``aggr_result_lock``. ``fl_ctx`` is unused; it is kept for interface compatibility.
        """
        result = None
        for vector in (self.aggr_result_dict or {}).values():
            # compare to None rather than testing truthiness, so a falsy first element is still summed
            result = vector if result is None else result + vector
        return result