# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
from typing import Dict, List, Tuple
from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import (
AggregationContext,
DataConverter,
FeatureAggregationResult,
FeatureContext,
)
SAMPLE_SIZE = 1000  # number of samples in the mock data set
NUM_FEATURES = 30  # total number of features across all ranks
WORLD_SIZE = 3  # number of ranks; matches len(RANK_FEATURES)

# Feature-id range owned by each rank, indexed by rank. Each entry is a
# (start, end) pair consumed as range(start, end), i.e. end is exclusive.
RANK_FEATURES = [
    (0, 10),  # rank 0
    (10, 20),  # rank 1
    (20, 30),  # rank 2
]
def decode_msg(msg: bytes) -> dict:
    """Parse a UTF-8 encoded JSON message into a Python object.

    Args:
        msg: the raw message bytes (UTF-8 encoded JSON text)

    Returns: the decoded JSON value (a dict for the messages used in this module)
    """
    text = msg.decode("utf-8")
    return json.loads(text)
class TupleAggregator(Aggregator):
    """Aggregator over (g, h) pairs: values are combined by element-wise addition."""

    def __init__(self):
        # initial_value is handed to the base Aggregator; presumably the starting
        # value of each accumulated bin (a zero (g, h) pair) - confirm in Aggregator.
        super().__init__(initial_value=(0, 0))

    def add(self, a, b):
        """Combine two (g, h) pairs.

        Args:
            a: first pair; only items 0 and 1 are read
            b: second pair; only items 0 and 1 are read

        Returns: a tuple ``(a[0] + b[0], a[1] + b[1])``
        """
        g_total = a[0] + b[0]
        h_total = a[1] + b[1]
        return g_total, h_total
class MockDataConverter(DataConverter):
    """Mock DataConverter for exercising the histogram-aggregation flow.

    Instead of parsing real XGBoost messages, it fabricates feature bin
    assignments and random (g, h) pairs. When aggregated histograms come back,
    it recomputes the expected histograms locally and prints whether each one
    matches.
    """

    def __init__(self):
        # Set once decode_aggregation_context has handed out this rank's features.
        self._features_done = False
        # (g, h) pairs produced by decode_gh_pairs; None until then.
        self.gh_pairs = None
        # feature_id => FeatureContext
        self.features = {}
        # rank => list of FeatureContext owned by that rank
        self.rank_features = {}
        self._setup()
        # group_id => list of sample ids; populated by decode_aggregation_context.
        self.groups = {}

    def _gen_feature(self, num_bins, fid) -> FeatureContext:
        """Build a feature whose bin for sample i is (i + fid) % num_bins.

        Args:
            num_bins: number of bins for the feature
            fid: feature id; also offsets the bin pattern so features differ

        Returns: a FeatureContext covering SAMPLE_SIZE samples
        """
        mask = [(i + fid) % num_bins for i in range(SAMPLE_SIZE)]
        return FeatureContext(fid, mask, num_bins)

    def _setup(self) -> None:
        """Create NUM_FEATURES 256-bin features and assign them to ranks per RANK_FEATURES."""
        self.features = {fid: self._gen_feature(256, fid) for fid in range(NUM_FEATURES)}
        for rank, fid_range in enumerate(RANK_FEATURES):
            # A rank whose range is None gets no entry in rank_features.
            if fid_range is None:
                continue
            start, end = fid_range
            self.rank_features[rank] = [self.features[fid] for fid in range(start, end)]

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode the buffer to extract (g, h) pairs.

        Only the message's "op" field is examined. For rank 0 and a "gh"
        message, SAMPLE_SIZE random integer (g, h) pairs are generated and
        remembered in self.gh_pairs, so that encode_aggregation_result can
        verify aggregated histograms later.

        Args:
            buffer: the JSON (UTF-8) buffer to be decoded
            fl_ctx: FLContext info; its Constant.PARAM_KEY_RANK property gives the rank

        Returns: a list of SAMPLE_SIZE (g, h) integer tuples if the rank is 0 and
            the message op is "gh"; otherwise None
        """
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        if rank != 0:
            # non-label client
            return None

        msg = decode_msg(buffer)
        if msg["op"] != "gh":
            return None

        min_value = -999999
        max_value = 999999
        # g is drawn before h for every sample.
        result = [
            (random.randint(min_value, max_value), random.randint(min_value, max_value))
            for _ in range(SAMPLE_SIZE)
        ]
        self.gh_pairs = result
        return result

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode the buffer to extract aggregation context info.

        The buffer is not inspected. On the first call, this rank's features
        are returned together with the current (still empty) groups. On every
        later call, no features are returned. Instead, a fixed set of sample
        groups is stored in self.groups and returned.

        Args:
            buffer: buffer to be decoded (unused by this mock)
            fl_ctx: FLContext info; its Constant.PARAM_KEY_RANK property gives the rank

        Returns: an AggregationContext of (features, groups). Features is None
            after the first call, or if the rank owns no features.
        """
        rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
        features = None
        if not self._features_done:
            self._features_done = True
            features = self.rank_features.get(rank)
        else:
            self.groups = {1: [1, 3, 4, 101], 4: [2, 7, 9, 23, 50]}
        return AggregationContext(features, self.groups)

    def _aggregate_feature(self, ctx: FeatureContext, sample_ids):
        """Recompute the expected (g, h) histogram for one feature.

        Uses TupleAggregator over self.gh_pairs and the feature's bin assignment.
        sample_ids restricts the samples; it may be None (see encode_aggregation_result).
        """
        aggr = TupleAggregator()
        return aggr.aggregate(self.gh_pairs, ctx.sample_bin_assignment, ctx.num_bins, sample_ids)

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Verify the received aggregation results and return a placeholder buffer.

        Each received histogram is compared with one recomputed locally via
        _aggregate_feature. A match or mismatch message is printed per feature.

        Args:
            aggr_results: aggregation result for all features and all groups from
                all clients; group_id => list of feature aggr results
            fl_ctx: FLContext info (unused by this mock)

        Returns: 4 random bytes. This mock does not encode the results.
        """
        for gid, fars in aggr_results.items():
            # None when gid is not a known group (e.g. the first round, when
            # self.groups is still empty). Presumably the aggregator treats None
            # as "all samples" - confirm in Aggregator.aggregate.
            sample_ids = self.groups.get(gid)
            for far in fars:
                ctx = self.features[far.feature_id]
                expected = self._aggregate_feature(ctx, sample_ids)
                if expected != far.aggregated_hist:
                    print(f"group {gid}: feature {far.feature_id}: expected aggr != received")
                    print(f"{expected=}")
                    print(f"{far.aggregated_hist=}")
                else:
                    print(f"group {gid}: feature {far.feature_id}: Result OK!")
        return os.urandom(4)