Skip to content

Commit

Permalink
Add support for groupwise quantization for int8 weight only quantizat…
Browse files Browse the repository at this point in the history
…ion (#1121)

Summary:
This is to support deprecating torchchat int8 weight only quantization: https://github.com/pytorch/torchchat/blob/ecc628da7c32c486742d92a751ed045b2a2194be/torchchat/utils/quantize.py#L582

Test Plan:
python test/integration/test_integration.py -k test_weight_only_groupwise_quant

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Oct 19, 2024
1 parent 3296749 commit bc2aaaf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
14 changes: 14 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def _int8wo_api(mod):
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8wo_groupwise_api(mod):
group_size = 32
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
Expand Down Expand Up @@ -927,6 +931,16 @@ def test_weight_only_quant(self):
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 43.0)

def test_weight_only_groupwise_quant(self):
for x_shape in [[128, 512]]:
x = torch.randn(*x_shape)
m = nn.Sequential(nn.Linear(512, 32))
y_ref = m(x)
_int8wo_groupwise_api(m)
y_wo = m(x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 45.0)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
9 changes: 6 additions & 3 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,19 +563,22 @@ def apply_int4_weight_only_quant(weight):
return _get_linear_subclass_inserter(apply_int4_weight_only_quant)


def int8_weight_only():
def int8_weight_only(group_size=None):
"""
Applies int8 weight-only symmetric per-channel quantization to linear layers.
"""
def apply_int8wo_quant(weight):
def apply_int8wo_quant(weight, group_size=None):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
if group_size is None:
group_size = weight.shape[1]

block_size = (1, weight.shape[1])
return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

return _get_linear_subclass_inserter(apply_int8wo_quant)
return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
mapping_type = MappingType.SYMMETRIC
Expand Down

0 comments on commit bc2aaaf

Please sign in to comment.