# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Union
import numpy as np
import torch
from bitsandbytes.functional import quantize_4bit, quantize_blockwise
from nvflare.apis.dxo import DXO, DataKind, MetaKey
from nvflare.apis.dxo_filter import DXOFilter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE
class ModelQuantizer(DXOFilter):
    """DXO filter that quantizes model parameters to reduce communication burden.

    Supported quantization types are validated against ``QUANTIZATION_TYPE``.
    ``float16`` is done directly with numpy/torch casts; ``blockwise8``,
    ``float4`` and ``normfloat4`` go through bitsandbytes and move the values
    to a CUDA device via ``.cuda()``, so they need a GPU to be available.
    """

    def __init__(
        self,
        quantization_type="float16",
    ):
        """Filter to quantize Shareable object to reduce communication burden.

        Args:
            quantization_type: method used for quantization. Matched
                case-insensitively against ``QUANTIZATION_TYPE`` and stored
                in lower case.

        Raises:
            ValueError: if ``quantization_type`` is not in ``QUANTIZATION_TYPE``.
        """
        # support weight and weight_diff data kinds
        data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF]
        super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds)

        # assign quantization type and check if it is valid
        self.logger.info("Using model quantizator.")
        quantization_type = quantization_type.lower()
        if quantization_type.upper() not in QUANTIZATION_TYPE:
            raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")
        self.quantization_type = quantization_type

        # float16 representable range, used to clamp values before down-casting
        self.NP_FP16_MIN = np.finfo(np.float16).min
        self.NP_FP16_MAX = np.finfo(np.float16).max
        self.TS_FP16_MIN = torch.finfo(torch.float16).min
        self.TS_FP16_MAX = torch.finfo(torch.float16).max

    @staticmethod
    def _to_source_format(tensor, source_data_format):
        """Move ``tensor`` to CPU and convert it to numpy when the source was numpy."""
        tensor = tensor.cpu()
        if source_data_format == "numpy":
            return tensor.numpy()
        return tensor

    def quantization(self, params: dict, fl_ctx: FLContext):
        """Quantize the entries of ``params`` whose bit width exceeds the target.

        Quantized entries replace the originals inside ``params`` itself (the
        dict is modified in place). Entries whose dtype already has no more
        bits than the quantization type are left untouched.

        Args:
            params: mapping of parameter name to a numpy array or torch tensor.
            fl_ctx: FLContext, used for logging.

        Returns:
            A tuple ``(params, quant_state, source_datatype)`` where
            ``quant_state`` maps each parameter name to the quantization
            metadata produced for it (an empty dict when none was produced)
            and ``source_datatype`` maps each parameter name to the name of
            its original dtype.

        Raises:
            ValueError: if a value is neither a numpy array nor a torch
                tensor, or if its dtype name is not in ``DATA_TYPE``.
        """
        n_params = len(params)
        self.log_info(fl_ctx, f"Running quantization on {n_params} variables")
        n_bytes_before = 0
        n_bytes_after = 0
        n_bytes_meta = 0
        n_quant_params = 0
        quant_state = {}
        source_datatype = {}

        # The target bit width is encoded in the type name (e.g. "float16" -> 16)
        # and does not depend on the parameter, so parse it once.
        quantization_bits = int(re.findall(r"\d+", self.quantization_type)[0])

        # Only values of existing keys are reassigned below, which is safe
        # while iterating over the dict.
        for param_name, values in params.items():
            quant_state[param_name] = {}

            # determine the container format (numpy or torch) and dtype name
            if isinstance(values, np.ndarray):
                source_data_format = "numpy"
                source_data_type = values.dtype.name
            elif isinstance(values, torch.Tensor):
                source_data_format = "torch"
                # str(dtype) looks like "torch.float32"; keep the part after the dot
                source_data_type = str(values.dtype).split(".")[1]
            else:
                raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch")
            source_datatype[param_name] = source_data_type

            # check if the data type is valid
            if source_data_type.upper() not in DATA_TYPE:
                raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}")

            # get the bits information
            source_data_bits = int(re.findall(r"\d+", source_data_type)[0])
            n_bytes_before += values.nbytes

            # only quantize if the quantization type is lower than the source data type
            if quantization_bits >= source_data_bits:
                self.log_info(
                    fl_ctx,
                    f"Skipping quantization for {param_name}, quantization bit {self.quantization_type} >= source data bit {source_data_type}",
                )
                # The value is kept as is, so it still contributes to the size
                # after quantization; otherwise the summary would overstate
                # the reduction.
                n_bytes_after += values.nbytes
                continue

            n_quant_params += 1
            if self.quantization_type == "float16":
                # clamp to the float16 range first, then down-cast
                if source_data_format == "numpy":
                    values = np.clip(values, self.NP_FP16_MIN, self.NP_FP16_MAX)
                    values = values.astype(np.float16)
                else:
                    values = torch.clamp(values, self.TS_FP16_MIN, self.TS_FP16_MAX)
                    values = values.to(torch.float16)
            elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]:
                # use bitsandbytes to quantize the values
                # input is a tensor, output is a tuple of (quantized tensor, quantized_state)
                # CPU has limited support for 8- and 4-bits quantization
                # For general purpose, here we use GPU
                if source_data_format == "numpy":
                    values_tensor = torch.as_tensor(values).cuda()
                else:
                    values_tensor = values.cuda()

                if self.quantization_type == "blockwise8":
                    quantized, quantized_state = quantize_blockwise(values_tensor)
                    # keep the quantization state in the source data format
                    param_state = quant_state[param_name]
                    param_state["absmax"] = self._to_source_format(quantized_state.absmax, source_data_format)
                    param_state["code"] = self._to_source_format(quantized_state.code, source_data_format)
                    n_bytes_meta += param_state["absmax"].nbytes
                    n_bytes_meta += param_state["code"].nbytes
                else:
                    quant_type = "fp4" if self.quantization_type == "float4" else "nf4"
                    quantized, quantized_state = quantize_4bit(values_tensor, quant_type=quant_type)
                    # tensor entries of the state follow the source data format,
                    # non-tensor entries are stored unchanged
                    for state_name, state in quantized_state.as_dict().items():
                        if isinstance(state, torch.Tensor):
                            quant_state[param_name][state_name] = self._to_source_format(
                                state, source_data_format
                            )
                            n_bytes_meta += state.nbytes
                        else:
                            quant_state[param_name][state_name] = state

                # quantized values also follow the source data format
                values = self._to_source_format(quantized, source_data_format)

            params[param_name] = values
            n_bytes_after += values.nbytes

        self.log_info(
            fl_ctx,
            f"Quantized {n_quant_params}/{n_params} params."
            f" Before quantization: {n_bytes_before / (1024 ** 2):.2f} MB."
            f" After quantization: {n_bytes_after / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB.",
        )
        return params, quant_state, source_datatype

    def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]:
        """Filter process apply to the Shareable object.

        Args:
            dxo: data to be processed
            shareable: that the dxo belongs to
            fl_ctx: FLContext

        Returns: DXO object with quantized weights. If the incoming DXO is
            already flagged as quantized, it is returned unchanged.
        """
        self.log_info(fl_ctx, "Running quantization...")

        # For an already quantized message, skip quantization.
        # The reason in this current example:
        # the server job is a 1-N communication with an identical quantization
        # operation. The first communication to a client applies quantization
        # and changes the data on the server, so the subsequent communications
        # to the rest of the clients no longer need to apply it.
        # This does not apply to the client job, which is 1-1, so quantization
        # applies to each client.
        # Potentially:
        # - If clients talk to each other, it is also 1-N and the same rule applies.
        # - If 1-N server-client filters can differ (Filter_1 applies to
        #   server-client_subset_1, etc.), then a deep copy of the server data
        #   should be made by the filter before applying a different filter.
        # quantized_flag is None if it does not exist in meta.
        if dxo.get_meta_prop("quantized_flag"):
            self.log_info(fl_ctx, "Already quantized, skip quantization")
            return dxo

        # apply quantization
        quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)

        # Compose new DXO with quantized data and record the quantization
        # details in its meta.
        new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
        new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
        new_dxo.set_meta_prop(key="quant_state", value=quant_state)
        new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
        new_dxo.set_meta_prop(key="quantized_flag", value=True)
        self.log_info(fl_ctx, f"Quantized from {source_datatype} to {self.quantization_type}")
        return new_dxo