Skip to content

Commit

Permalink
Unified AffineQuantizedTensor subclass (#214)
Browse files Browse the repository at this point in the history
Summary:
Creatd a `AffineQuantizedTensor` subclass that works for both weight and input (for dynamic quantization), for all granularities (levering the recently added choose_qparams_affine, quantize_affine
and dequantize_affine ops)

only verified for 8da4w right now, we can make it work for other types of quantization (mostly the operator dispatching part) later

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Mark Saroufim <[email protected]>
  • Loading branch information
jerryzh168 and msaroufim authored May 7, 2024
1 parent c2657e4 commit f0bdc8f
Show file tree
Hide file tree
Showing 3 changed files with 344 additions and 17 deletions.
67 changes: 60 additions & 7 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
apply_dynamic_quant(model)
return model

class M(torch.nn.Module):
class ToyLinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
Expand All @@ -103,7 +103,7 @@ def forward(self, x):

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = M().eval()
m = ToyLinearModel().eval()
m = _apply_dynamic_quant(m)
quantized = m(*m.example_inputs())
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
Expand All @@ -116,7 +116,7 @@ def test_dynamic_quant_gpu_singleline(self):
@unittest.skip("skipping for now due to torch.compile error")
def test_dynamic_quant_gpu_unified_api_unified_impl(self):
quantizer = XNNPackDynamicQuantizer()
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.prepare(m)
m = quantizer.convert(m)
Expand All @@ -131,7 +131,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
@unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!")
def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
quantizer = TorchCompileDynamicQuantizer()
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
quantized = m(*example_inputs)
Expand All @@ -141,7 +141,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int8_wo_quant_save_load(self):
m = M().eval().cpu()
m = ToyLinearModel().eval().cpu()
apply_weight_only_int8_quant(m)
example_inputs = m.example_inputs()
ref = m(*example_inputs)
Expand All @@ -150,7 +150,7 @@ def test_int8_wo_quant_save_load(self):

state_dict = torch.load(_TMP_FN)
os.remove(_TMP_FN)
m2 = M().eval()
m2 = ToyLinearModel().eval()
apply_weight_only_int8_quant(m2)
m2.load_state_dict(state_dict)
m2 = m2.to(device="cuda")
Expand All @@ -165,7 +165,7 @@ def test_8da4w_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
Expand Down Expand Up @@ -392,5 +392,58 @@ def test_eval_wrapper(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
m_copy = quantizer.quantize(m_copy)
assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))




if __name__ == "__main__":
unittest.main()
16 changes: 9 additions & 7 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _get_reduction_params(block_size, input_size):

def quantize_affine(
input: torch.Tensor,
block_size: List[int],
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
Expand All @@ -146,7 +146,7 @@ def quantize_affine(
"""
Args:
input (torch.Tensor): original float32 or bfloat16 Tensor
block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
scale (float): quantization parameter for affine quantization
zero_point (int): quantization parameter for affine quantization
Expand Down Expand Up @@ -191,7 +191,7 @@ def quantize_affine(

def dequantize_affine(
input: torch.Tensor,
block_size: List[int],
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
Expand Down Expand Up @@ -244,7 +244,7 @@ class MappingType(Enum):
def choose_qparams_affine(
input: torch.Tensor,
mapping_type: MappingType,
block_size: List[int],
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
Expand All @@ -256,12 +256,14 @@ def choose_qparams_affine(
Args:
input (torch.Tensor): fp32, bf16, fp16 input Tensor
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
target_dtype (torch.dtype): dtype for target quantized Tensor
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
quant_max (Optioanl[int]): maximum quantized value for target quantized Tensor
eps (Optional[float]: minimum scale
scale_dtype (torch.dtype): dtype for scales
zero_point_dtype (torch.dtype): dtype for zero_points
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
Output:
Tuple of scales and zero_points Tensor with requested dtype
Expand Down
Loading

0 comments on commit f0bdc8f

Please sign in to comment.