From bc2aaaf4f16faa518539f44d3ade6b6ab648b1b4 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 18 Oct 2024 18:09:02 -0700 Subject: [PATCH] Add support for groupwise quantization for int8 weight only quantization (#1121) Summary: This is to support deprecating torchchat int8 weight only quantization: https://github.com/pytorch/torchchat/blob/ecc628da7c32c486742d92a751ed045b2a2194be/torchchat/utils/quantize.py#L582 Test Plan: python test/integration/test_integration.py -k test_weight_only_groupwise_quant Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 14 ++++++++++++++ torchao/quantization/quant_api.py | 9 ++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index a451605c1d..837b1de7b2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -109,6 +109,10 @@ def _int8wo_api(mod): else: change_linear_weights_to_int8_woqtensors(mod) +def _int8wo_groupwise_api(mod): + group_size = 32 + quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) + def _int8da_int8w_api(mod): if TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) @@ -927,6 +931,16 @@ def test_weight_only_quant(self): sqnr = compute_error(y_ref, y_wo) self.assertGreater(sqnr, 43.0) + def test_weight_only_groupwise_quant(self): + for x_shape in [[128, 512]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(512, 32)) + y_ref = m(x) + _int8wo_groupwise_api(m) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 45.0) + @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 91803fe3f7..484baa2865 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -563,19 +563,22 @@ def apply_int4_weight_only_quant(weight): return _get_linear_subclass_inserter(apply_int4_weight_only_quant) -def int8_weight_only(): +def int8_weight_only(group_size=None): """ Applies int8 weight-only symmetric per-channel quantization to linear layers. """ - def apply_int8wo_quant(weight): + def apply_int8wo_quant(weight, group_size=None): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 + if group_size is None: + group_size = weight.shape[1] + block_size = (1, weight.shape[1]) return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - return _get_linear_subclass_inserter(apply_int8wo_quant) + return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size) def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: mapping_type = MappingType.SYMMETRIC