# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder, DamEncoder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import (
    AggregationContext,
    DataConverter,
    FeatureAggregationResult,
    FeatureContext,
)
DATA_SET_GH_PAIRS = 1
DATA_SET_AGGREGATION = 2
DATA_SET_AGGREGATION_WITH_FEATURES = 3
DATA_SET_AGGREGATION_RESULT = 4
DATA_SET_HISTOGRAMS = 5
DATA_SET_HISTOGRAMS_RESULT = 6
SCALE_FACTOR = 1000000.0 # Preserve 6 decimal places
[docs]
class ProcessorDataConverter(DataConverter):
def __init__(self):
super().__init__()
self.features = []
self.feature_list = None
self.num_samples = 0
[docs]
def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
decoder = DamDecoder(buffer)
if not decoder.is_valid():
return None
if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
raise RuntimeError(f"Data is not for GH Pairs: {decoder.get_data_set_id()}")
float_array = decoder.decode_float_array()
result = []
self.num_samples = int(len(float_array) / 2)
for i in range(self.num_samples):
result.append((self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1])))
return result
[docs]
def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
decoder = DamDecoder(buffer)
if not decoder.is_valid():
return None
data_set_id = decoder.get_data_set_id()
cuts = decoder.decode_int_array()
if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
self.feature_list = decoder.decode_int_array()
num = len(self.feature_list)
slots = decoder.decode_int_array()
num_samples = int(len(slots) / num)
for i in range(num):
bin_assignment = []
for row_id in range(num_samples):
_, bin_num = self.slot_to_bin(cuts, slots[row_id * num + i])
bin_assignment.append(bin_num)
bin_size = self.get_bin_size(cuts, self.feature_list[i])
feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
self.features.append(feature_ctx)
elif data_set_id != DATA_SET_AGGREGATION:
raise RuntimeError(f"Invalid DataSet: {data_set_id}")
node_list = decoder.decode_int_array()
sample_groups = {}
for node in node_list:
row_ids = decoder.decode_int_array()
sample_groups[node] = row_ids
return AggregationContext(self.features, sample_groups)
[docs]
def encode_aggregation_result(
self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
) -> bytes:
encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
node_list = sorted(aggr_results.keys())
encoder.add_int_array(node_list)
for node in node_list:
result_list = aggr_results.get(node)
feature_list = [result.feature_id for result in result_list]
encoder.add_int_array(feature_list)
for result in result_list:
encoder.add_float_array(self.to_float_array(result))
return encoder.finish()
[docs]
def decode_histograms(self, buffer: bytes, fl_ctx: FLContext) -> List[float]:
decoder = DamDecoder(buffer)
if not decoder.is_valid():
return None
data_set_id = decoder.get_data_set_id()
if data_set_id != DATA_SET_HISTOGRAMS:
raise RuntimeError(f"Invalid DataSet: {data_set_id}")
return decoder.decode_float_array()
[docs]
def encode_histograms_result(self, histograms: List[float], fl_ctx: FLContext) -> bytes:
encoder = DamEncoder(DATA_SET_HISTOGRAMS_RESULT)
encoder.add_float_array(histograms)
return encoder.finish()
[docs]
@staticmethod
def get_bin_size(cuts: [int], feature_id: int) -> int:
return cuts[feature_id + 1] - cuts[feature_id]
[docs]
@staticmethod
def slot_to_bin(cuts: [int], slot: int) -> Tuple[int, int]:
if slot < 0 or slot >= cuts[-1]:
raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1] - 1}]")
for i in range(len(cuts) - 1):
if cuts[i] <= slot < cuts[i + 1]:
bin_num = slot - cuts[i]
return i, bin_num
raise RuntimeError(f"Logic error. Slot {slot}, out of range [0-{cuts[-1] - 1}]")
[docs]
@staticmethod
def float_to_int(value: float) -> int:
return int(value * SCALE_FACTOR)
[docs]
@staticmethod
def int_to_float(value: int) -> float:
return value / SCALE_FACTOR
[docs]
@staticmethod
def to_float_array(result: FeatureAggregationResult) -> List[float]:
float_array = []
for g, h in result.aggregated_hist:
float_array.append(ProcessorDataConverter.int_to_float(g))
float_array.append(ProcessorDataConverter.int_to_float(h))
return float_array