diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 93ac6fe739..10d36f0c1b 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -87,7 +87,7 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         apply_dynamic_quant(model)
         return model
 
-class M(torch.nn.Module):
+class ToyLinearModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
@@ -103,7 +103,7 @@ def forward(self, x):
 
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
-        m = M().eval()
+        m = ToyLinearModel().eval()
         m = _apply_dynamic_quant(m)
         quantized = m(*m.example_inputs())
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
@@ -116,7 +116,7 @@ def test_dynamic_quant_gpu_singleline(self):
     @unittest.skip("skipping for now due to torch.compile error")
     def test_dynamic_quant_gpu_unified_api_unified_impl(self):
         quantizer = XNNPackDynamicQuantizer()
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.prepare(m)
         m = quantizer.convert(m)
@@ -131,7 +131,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
     @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!")
     def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         quantizer = TorchCompileDynamicQuantizer()
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         quantized = m(*example_inputs)
@@ -141,7 +141,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8_wo_quant_save_load(self):
-        m = M().eval().cpu()
+        m = ToyLinearModel().eval().cpu()
         apply_weight_only_int8_quant(m)
         example_inputs = m.example_inputs()
         ref = m(*example_inputs)
@@ -150,7 +150,7 @@ def test_int8_wo_quant_save_load(self):
         state_dict = torch.load(_TMP_FN)
         os.remove(_TMP_FN)
 
-        m2 = M().eval()
+        m2 = ToyLinearModel().eval()
         apply_weight_only_int8_quant(m2)
         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
@@ -165,7 +165,7 @@ def test_8da4w_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
@@ -392,5 +392,58 @@ def test_eval_wrapper(self):
             f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
         )
 
+    # TODO: move to a separate test file
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_8da4w(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.SYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        quant_min = -8
+        quant_max = 7
+
+        # TODO: make a general helper function?
+        def get_per_token_block_size(x):
+            block_size = []
+            for i in range(len(x.shape)-1):
+                block_size.append(1)
+            block_size.append(x.shape[-1])
+            return block_size
+
+        # input settings
+        input_mapping_type = MappingType.ASYMMETRIC
+        input_target_dtype = torch.int8
+        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+
+        m = ToyLinearModel().eval()
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs()
+        m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        m_copy = quantizer.quantize(m_copy)
+        assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+        self.assertTrue(torch.equal(res, ref))
+
+
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 90316e1557..f59144becd 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -136,7 +136,7 @@ def _get_reduction_params(block_size, input_size):
 
 def quantize_affine(
     input: torch.Tensor,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,
@@ -146,7 +146,7 @@ def quantize_affine(
     """
     Args:
       input (torch.Tensor): original float32 or bfloat16 Tensor
-      block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
@@ -191,7 +191,7 @@ def quantize_affine(
 
 def dequantize_affine(
     input: torch.Tensor,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     input_dtype: torch.dtype,
@@ -244,7 +244,7 @@ class MappingType(Enum):
 def choose_qparams_affine(
     input: torch.Tensor,
     mapping_type: MappingType,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     target_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
@@ -256,12 +256,14 @@ def choose_qparams_affine(
     Args:
       input (torch.Tensor): fp32, bf16, fp16 input Tensor
       mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
+      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+       e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
      target_dtype (torch.dtype): dtype for target quantized Tensor
      quant_min (Optional[int]): minimum quantized value for target quantized Tensor
      quant_max (Optioanl[int]): maximum quantized value for target quantized Tensor
-      eps (Optional[float]: minimum scale
-      scale_dtype (torch.dtype): dtype for scales
-      zero_point_dtype (torch.dtype): dtype for zero_points
+      eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
+      scale_dtype (torch.dtype): dtype for scale Tensor
+      zero_point_dtype (torch.dtype): dtype for zero_point Tensor
 
     Output:
       Tuple of scales and zero_points Tensor with requested dtype
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 7de4a6169f..148228a030 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -15,14 +15,19 @@
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
+    choose_qparams_affine,
+    quantize_affine,
+    dequantize_affine,
 )
 from .utils import find_multiple
+from typing import Tuple, Optional, Callable
 
 __all__ = [
     "Int8DynamicallyQuantizedLinearWeight",
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
+    "AffineQuantizedTensor",
 ]
 
@@ -134,14 +139,21 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
+        # Note: we only added the cpu path here for 8da4w, this is for executorch; in the future
+        # 1. we'll add cpu/cuda versions (int4mm etc.)
+        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impls for the cpu kernel, as Michael mentioned), so it will be something like
+        #    cpu device + et layout --> gives the current 8da4w executorch representation
+        #    cpu device + avx layout --> gives an optimized kernel for 8da4w on avx cpus etc.
+        #    cuda device + some layout --> gives a cuda kernel
+
         # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
-        #     for consistency and to allow people to test
+        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
+        #     kernels in CPU as well, see the note above
         # 2 - we're given non-floats - quantizing long to int8 is crazy
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].is_cuda
+            and args[0].device == torch.device("cpu")
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
                     f"need mat1 shape: {args[1].shape} final"
                     f"dim to match mat2 shape: {args[2].shape} first dim "
                 )
@@ -592,3 +604,263 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         )
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
+
+
+class AffineQuantizedTensor(torch.Tensor):
+    """
+    Base affine quantized tensor subclass. Use the from_float method to create an instance of
+    AffineQuantizedTensor from a floating point Tensor.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+    regardless of the internal representation's type or orientation.
+
+    Affine quantization means we quantize the floating point tensor with an affine transformation:
+      quantized_tensor = float_tensor / scale + zero_point
+
+    fields:
+      int_data (torch.Tensor): the quantized integer data Tensor
+      scale (torch.Tensor): the scale Tensor used to map between the floating point tensor and the quantized tensor
+      zero_point (torch.Tensor): the zero_point Tensor used to map between the floating point tensor and the quantized tensor
+      block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+      shape (torch.Size): the shape for the Tensor
+      quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from the dtype of `int_data`
+      quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from the dtype of `int_data`
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor; it takes a float Tensor as input and outputs an AffineQuantizedTensor object
+      dtype: dtype for the external representation of the tensor, e.g. torch.float32
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        shape: torch.Size,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        input_quant_func: Optional[Callable] = None,
+        dtype=None,
+        *args,
+        **kwargs
+    ):
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        if dtype is None:
+            dtype = scale.dtype
+        kwargs["dtype"] = dtype
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        shape: torch.Size,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        input_quant_func: Optional[Callable] = None,
+        dtype=None,
+        *args,
+        **kwargs
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+        self.block_size = block_size
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.input_quant_func = input_quant_func
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
+        )
+
+    def dequantize(self, output_dtype=torch.float32):
+        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.input_quant_func, self.dtype]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        block_size, shape, quant_min, quant_max, input_quant_func, dtype = tensor_attributes
+        return cls(
+            int_data,
+            scale,
+            zero_point,
+            block_size,
+            shape if outer_size is None else outer_size,
+            quant_min,
+            quant_max,
+            input_quant_func=input_quant_func,
+            dtype=dtype,
+            strides=outer_stride,
+        )
+
+    @classmethod
+    def from_float(
+        cls,
+        input_float,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min = None,
+        quant_max = None,
+        eps = None,
+        scale_dtype = None,
+        zero_point_dtype = None,
+        input_quant_func = None,
+    ):
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
+        return cls(
+            int_data,
+            scale,
+            zero_point,
+            block_size,
+            input_float.shape,
+            quant_min,
+            quant_max,
+            input_quant_func=input_quant_func,
+            dtype=input_float.dtype
+        )
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            input_tensor, weight_qtensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            if weight_qtensor.input_quant_func is not None:
+                input_tensor = weight_qtensor.input_quant_func(input_tensor)
+                input_tensor = input_tensor.dequantize()
+            weight_tensor = weight_qtensor.dequantize()
+            return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except:
+            print(f"ERR: subclass doesn't implement {func}")
+
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+            self.block_size,
+            self.shape,
+            self.quant_min,
+            self.quant_max,
+            self.input_quant_func,
+            **kwargs,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+            self.block_size,
+            self.shape,
+            self.quant_min,
+            self.quant_max,
+            self.input_quant_func,
+            dtype=self.dtype,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # two scenarios where we currently fall back to vanilla mm:
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        #     for consistency and to allow people to test
+        # 2 - we're given non-floats - quantizing long to int8 is crazy
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+            and args[0].is_cuda
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final"
+                    f"dim to match mat2 shape: {args[2].shape} first dim "
+                )
+                input_tensor, weight_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim"
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                input_tensor, weight_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+            if weight_qtensor.input_quant_func is not None:
+                input_tensor = weight_qtensor.input_quant_func(input_tensor)
+                input_tensor = input_tensor.dequantize()
+            weight_tensor = weight_qtensor.dequantize()
+            return func(input_tensor, weight_tensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            # TODO: need to implement this
+            # args[0].transposed = not args[0].transposed
+            # new = args[0]._change_shape(args[0].shape[::-1])
+            # return return_and_correct_aliasing(func, args, kwargs, new)
+            raise Exception("transpose not implemented yet")
+
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
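
Usage sketch (not part of the patch): the snippet below mirrors the flow exercised by
test_quantized_tensor_subclass_8da4w above, quantizing a linear weight with
AffineQuantizedTensor.from_float and letting __torch_function__ intercept F.linear. Class,
enum, and argument names are taken from this diff; the layer size, group size, and the
2-D-only per-token lambda are illustrative assumptions.

    import torch
    from torchao.quantization.subclass import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    # Illustrative sizes; any Linear whose in_features is divisible by groupsize works.
    linear = torch.nn.Linear(64, 32, bias=False).eval()
    groupsize = 32

    linear.weight = torch.nn.Parameter(
        AffineQuantizedTensor.from_float(
            linear.weight,
            MappingType.SYMMETRIC,   # weight: symmetric quantization
            (1, groupsize),          # groupwise blocks along the last dimension
            torch.int8,              # int4 range below, stored in an int8 tensor
            quant_min=-8,
            quant_max=7,
            eps=torch.finfo(torch.float32).eps,
            # activation: asymmetric per-token int8; assumes 2-D activations here,
            # the test's get_per_token_block_size handles arbitrary rank
            input_quant_func=lambda x: AffineQuantizedTensor.from_float(
                x, MappingType.ASYMMETRIC, (1, x.shape[-1]), torch.int8
            ),
        ),
        requires_grad=False,
    )

    # F.linear is intercepted by __torch_function__: the activation is quantized via
    # input_quant_func, then both operands are dequantized and a vanilla linear runs.
    out = linear(torch.randn(1, 64))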