# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bz2
import math
from typing import Any, Optional, Union
import numpy as np
import torch
class AdaQuantizer:
    """ADAQUANT quantizer: adaptive-precision uniform quantization of tensors.

    See https://arxiv.org/abs/2208.05174. The number of quantization levels is
    chosen per tensor from its value range and the ``weight`` hyperparameter.
    """

    def __init__(self, weight: float = 0.01, compression: bool = True) -> None:
        """Implements the ADAQUANT quantization scheme,
        for further details refer to the paper https://arxiv.org/abs/2208.05174

        Args:
            weight: a hyperparameter for the trade-off between quantization size and error
            compression: whether to compress the resulting integer quantized tensor
        """
        self.weight = weight
        self.compression = compression

    def quantize(self, values_tensor: torch.Tensor) -> tuple[Union[torch.Tensor, np.ndarray], dict]:
        """Quantize ``values_tensor`` onto unsigned integer levels.

        The tensor is flattened, shifted so that its minimum becomes 0, divided by
        its range ``norm`` and rounded onto ``quantization_level + 1`` levels.

        Args:
            values_tensor: tensor to quantize (any shape, dtype or device).

        Returns:
            ``(payload, quant_state)``. ``quant_state`` is ``{}`` when the input is
            returned unchanged (empty tensor, non-finite values, or no suitable
            narrower integer width). Otherwise ``payload`` is a numpy integer array,
            or a 1-element bool placeholder tensor when the input is constant or the
            data is stored bz2-compressed in ``quant_state["compressed_tensor"]``.
        """
        old_values_tensor = values_tensor
        if old_values_tensor.numel() == 0:
            # Nothing to quantize; min()/max() raise on empty tensors.
            return old_values_tensor, {}
        # detach + cpu: the data is converted to numpy below, which fails for tensors
        # that require grad or are not on the CPU. reshape (unlike view) also accepts
        # non-contiguous inputs such as transposed tensors.
        values_tensor = values_tensor.detach().to(device="cpu", dtype=torch.float64).reshape(-1)
        offset = self.get_offset(values_tensor)
        values_tensor = values_tensor + offset
        res = self.get_number_of_quantization_levels(
            element_size=old_values_tensor.element_size(), values_tensor=values_tensor
        )
        if res is None:
            return old_values_tensor, {}
        quant_state = {"offset": offset}
        old_tensor_shape = list(old_values_tensor.shape)
        norm, quantization_level, new_dtype = res
        if norm == 0.0:
            # Constant tensor: offset and shape alone reconstruct it.
            return torch.tensor([0], dtype=torch.bool), quant_state | {
                "tensor_shape": old_tensor_shape,
            }
        normalized_tensor = values_tensor / norm
        quantized_tensor = (
            (normalized_tensor * quantization_level)
            .round()
            .clamp(0, quantization_level)
            .numpy()
            .astype(dtype=new_dtype)
        )
        if self.compression:
            raw_bytes = quantized_tensor.tobytes()
            compressed_bytes = bz2.compress(raw_bytes, compresslevel=1)
            # Keep the compressed form only when it is actually smaller.
            if len(compressed_bytes) < len(raw_bytes):
                compressed_tensor = np.frombuffer(compressed_bytes, dtype=np.uint8)
                return torch.tensor([0], dtype=torch.bool), quant_state | {
                    "compressed_tensor": compressed_tensor,
                    "quantization_level": quantization_level,
                    "new_dtype": str(new_dtype),
                    "tensor_shape": old_tensor_shape,
                    "norm": norm,
                }
        return quantized_tensor, quant_state | {
            "quantization_level": quantization_level,
            "tensor_shape": old_tensor_shape,
            "norm": norm,
        }

    def dequantized(self, quantized_tensor: Union[torch.Tensor, np.ndarray], quant_state: dict) -> torch.Tensor:
        """Reconstruct a float64 tensor from the output of ``quantize``.

        Args:
            quantized_tensor: payload returned by ``quantize`` (numpy array or tensor).
                Ignored for constant inputs and for compressed payloads, whose data
                lives in ``quant_state``.
            quant_state: non-empty state dict returned by ``quantize``.

        Returns:
            A float64 tensor with the original shape.
        """
        offset = quant_state["offset"]
        if "norm" not in quant_state:
            # Constant input: every element equals the original minimum, i.e. -offset.
            return torch.zeros(quant_state["tensor_shape"], dtype=torch.float64) - offset
        norm = quant_state["norm"]
        if "compressed_tensor" in quant_state:
            decompressed_bytes = bz2.decompress(quant_state["compressed_tensor"].tobytes())
            quantized_tensor = np.frombuffer(decompressed_bytes, dtype=quant_state["new_dtype"])
        if isinstance(quantized_tensor, np.ndarray):
            # Convert to float64 in numpy: torch.from_numpy does not accept every unsigned
            # dtype on all torch versions, and astype() yields a writable copy.
            quantized_tensor = torch.from_numpy(quantized_tensor.astype(np.float64))
        else:
            quantized_tensor = quantized_tensor.to(dtype=torch.float64)
        quantized_tensor = quantized_tensor.reshape(quant_state["tensor_shape"])
        return (quantized_tensor * norm / quant_state["quantization_level"]) - offset

    def get_offset(self, tensor: torch.Tensor) -> float:
        """Return the shift that moves the minimum of ``tensor`` to zero (``-min``)."""
        return -tensor.min().item()

    def get_number_of_quantization_levels(
        self, element_size: int, values_tensor: torch.Tensor
    ) -> Optional[tuple[float, int, Any]]:
        """Choose the quantization level count and integer storage dtype.

        Args:
            element_size: size in bytes of one element of the original tensor.
            values_tensor: offset-shifted values to quantize.

        Returns:
            ``(norm, quantization_level, new_dtype)``: ``norm`` is the max of
            ``values_tensor``, ``quantization_level`` is ``2**bits - 1`` and
            ``new_dtype`` is a numpy dtype string. Returns ``None`` when the values
            are not finite, or when the required bit width is not smaller than the
            original element width or exceeds 32 bits.
        """
        norm = values_tensor.max().item()
        element_bits = element_size * 8
        raw_levels = math.sqrt(norm * element_bits * math.log(4) / self.weight)
        if not math.isfinite(raw_levels):
            # NaN/inf values cannot be mapped onto integer levels; leave unquantized.
            return None
        quantization_level = math.ceil(max(1, raw_levels))
        # Use at least one bit: zero bits would make quantization_level 0 and
        # dequantization would divide by zero (NaN) for small non-zero ranges.
        new_element_bits = max(1, math.ceil(math.log2(quantization_level)))
        quantization_level = int(2**new_element_bits) - 1
        if new_element_bits >= element_bits:
            return None
        if new_element_bits <= 8:
            new_dtype = "u1"
        elif new_element_bits <= 16:
            new_dtype = "<u2"
        elif new_element_bits <= 32:
            # For 32-bit inputs the raw size is unchanged; for 64-bit inputs it halves.
            new_dtype = "<u4"
        else:
            # No supported narrower integer type; skip quantization instead of raising.
            return None
        return norm, quantization_level, new_dtype