Source code for nvflare.app_opt.xgboost.histogram_based_v2.sec.dam

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
from io import BytesIO
from typing import List

# Magic string at the start of every DAM buffer: DAM (Direct Accessible Marshalling) V1.
SIGNATURE = "NVDADAM1"
# Header length in bytes: 8-byte signature + int64 total size + int64 data set id.
PREFIX_LEN = 8 + 8 + 8

# Data type codes. The array codes are written as an int64 in front of each encoded
# array entry; the scalar codes are not used by DamEncoder/DamDecoder in this module.
DATA_TYPE_INT = 1
DATA_TYPE_FLOAT = 2
DATA_TYPE_STRING = 3
DATA_TYPE_INT_ARRAY = 0x101  # 257
DATA_TYPE_FLOAT_ARRAY = 0x102  # 258


class DamEncoder:
    """Serialize int/float arrays into a DAM (Direct Accessible Marshalling) V1 buffer.

    Layout produced by :meth:`finish` (all numbers are 8 bytes, native byte order):

    * header (``PREFIX_LEN`` bytes): the ASCII ``SIGNATURE``, the int64 total
      buffer size in bytes, and the int64 data set id;
    * then, for each added array in insertion order: an int64 data type code,
      an int64 element count, and the elements as int64 (int arrays) or
      float64 (float arrays).
    """

    def __init__(self, data_set_id: int):
        """Create an empty encoder.

        Args:
            data_set_id: identifier written into the header of the encoded buffer.
        """
        self.data_set_id = data_set_id
        # (data_type, values) pairs, encoded in insertion order by finish().
        self.entries = []
        self.buffer = BytesIO()

    def add_int_array(self, value: List[int]) -> None:
        """Queue ``value`` to be encoded as an array of int64 values."""
        self.entries.append((DATA_TYPE_INT_ARRAY, value))

    def add_float_array(self, value: List[float]) -> None:
        """Queue ``value`` to be encoded as an array of float64 values."""
        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))

    def finish(self) -> bytes:
        """Encode the header and every queued array and return the result.

        The output buffer is rebuilt on each call. Calling ``finish`` again
        (e.g. after adding more arrays) therefore yields a fresh, self-consistent
        encoding instead of appending to the previous output.

        Returns:
            The encoded bytes.
        """
        # Start over so a repeated call cannot emit a second header after the first
        # encoding, which would leave the size field disagreeing with the buffer length.
        self.buffer = BytesIO()

        # Each entry costs 16 bytes (type code + count) plus 8 bytes per element.
        size = PREFIX_LEN + sum(16 + 8 * len(value) for _, value in self.entries)

        self.write_str(SIGNATURE)
        self.write_int64(size)
        self.write_int64(self.data_set_id)
        for data_type, value in self.entries:
            self.write_int64(data_type)
            self.write_int64(len(value))
            # Pack the whole array in one call. In native mode, repeated items of one
            # type are laid out back to back, byte-identical to packing them one by one.
            code = "q" if data_type == DATA_TYPE_INT_ARRAY else "d"
            self.buffer.write(struct.pack(f"{len(value)}{code}", *value))
        return self.buffer.getvalue()

    def write_int64(self, value: int):
        """Append ``value`` as a native-order int64."""
        self.buffer.write(struct.pack("q", value))

    def write_float(self, value: float):
        """Append ``value`` as a native-order float64 (C double)."""
        self.buffer.write(struct.pack("d", value))

    def write_str(self, value: str):
        """Append ``value`` UTF-8 encoded, with no length prefix or terminator."""
        self.buffer.write(value.encode("utf-8"))
class DamDecoder:
    """Read arrays back out of a buffer produced by :class:`DamEncoder`.

    The header is parsed on construction; check :meth:`is_valid` before decoding.
    Arrays are read sequentially from the current position, so they must be
    decoded in the same order they were added to the encoder.
    """

    def __init__(self, buffer: bytes):
        """Create a decoder and parse the header if the buffer is long enough.

        Args:
            buffer: the encoded data. ``bytes``, ``bytearray`` and byte
                ``memoryview`` objects are accepted.
        """
        self.buffer = buffer
        self.pos = 0
        if len(buffer) >= PREFIX_LEN:
            self.signature = self.read_string(8)
            self.size = self.read_int64()
            self.data_set_id = self.read_int64()
        else:
            # Too short to hold a header: is_valid() will report False.
            self.signature = None
            self.size = 0
            self.data_set_id = 0

    def is_valid(self):
        """Return True if the buffer starts with the expected DAM signature."""
        return self.signature == SIGNATURE

    def get_data_set_id(self):
        """Return the data set id from the header (0 if the buffer was too short)."""
        return self.data_set_id

    def decode_int_array(self) -> List[int]:
        """Decode the next entry as an int array.

        Raises:
            RuntimeError: if the next entry is not an int array or its count is negative.
            struct.error: if the buffer ends before the array does.
        """
        return self._decode_array(DATA_TYPE_INT_ARRAY, "q", "Invalid data type for int array")

    def decode_float_array(self) -> List[float]:
        """Decode the next entry as a float array.

        Raises:
            RuntimeError: if the next entry is not a float array or its count is negative.
            struct.error: if the buffer ends before the array does.
        """
        return self._decode_array(DATA_TYPE_FLOAT_ARRAY, "d", "Invalid data type for float array")

    def _decode_array(self, expected_type: int, code: str, type_error: str) -> list:
        """Read a type code, a count and ``count`` 8-byte items of struct format ``code``."""
        data_type = self.read_int64()
        if data_type != expected_type:
            raise RuntimeError(type_error)
        num = self.read_int64()
        # A negative count can only come from a corrupt buffer; don't silently return [].
        if num < 0:
            raise RuntimeError(f"Invalid array length {num}")
        # One bulk read; unpack_from checks the remaining length before building the result.
        result = list(struct.unpack_from(f"{num}{code}", self.buffer, self.pos))
        self.pos += num * 8
        return result

    def read_string(self, length: int) -> str:
        """Read ``length`` bytes as a Latin-1 string and advance the position."""
        # bytes(...) so a memoryview buffer (whose slices lack .decode) also works.
        result = bytes(self.buffer[self.pos : self.pos + length]).decode("latin1")
        self.pos += length
        return result

    def read_int64(self) -> int:
        """Read a native-order int64 and advance the position by 8 bytes."""
        (result,) = struct.unpack_from("q", self.buffer, self.pos)
        self.pos += 8
        return result

    def read_float(self) -> float:
        """Read a native-order float64 and advance the position by 8 bytes."""
        (result,) = struct.unpack_from("d", self.buffer, self.pos)
        self.pos += 8
        return result