From cc8bf8595dfbc6e5e2ca3f18bbd6e9384e794c04 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 8 Oct 2024 13:49:27 -0700 Subject: [PATCH 1/6] Rename Layout -> TensorImpl (#1028) --- .github/workflows/regression_test.yml | 2 +- benchmarks/benchmark_fp6.py | 2 +- test/dtypes/test_affine_quantized_float.py | 14 +- test/dtypes/test_floatx.py | 14 +- test/hqq/test_hqq_affine.py | 4 +- test/integration/test_integration.py | 2 +- torchao/dtypes/__init__.py | 4 +- torchao/dtypes/affine_quantized_tensor.py | 210 +++++++++--------- torchao/dtypes/floatx/__init__.py | 2 +- torchao/dtypes/floatx/floatx.py | 18 +- torchao/dtypes/uintx/uintx.py | 6 +- torchao/dtypes/utils.py | 8 +- torchao/prototype/hqq/example.py | 4 +- torchao/quantization/autoquant.py | 8 +- torchao/quantization/quant_api.py | 2 +- torchao/sparsity/marlin/utils.py | 2 +- torchao/utils.py | 46 ++-- .../my_dtype_tensor_subclass.py | 64 +++--- .../my_trainable_tensor_subclass.py | 18 +- .../developer_api_guide/tensor_parallel.py | 16 +- 20 files changed, 223 insertions(+), 223 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 3aee8dbfb5..13cd4e2e76 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -35,7 +35,7 @@ jobs: gpu-arch-version: "12.1" - name: CUDA 2.4 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.4.0' + torch-spec: 'torch==2.4.1' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CUDA Nightly (Oct 1) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 425507bd95..509ea6e86c 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -2,7 +2,7 @@ import pandas as pd import torch.nn.functional as F from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType +from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 621e3596e0..761b233fcd 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -210,18 +210,18 @@ def test_serialization(self, mode: str): # Compare weights if mode == "weight-only": - original_weight = original_layer.weight.layout_tensor.float8_data.to( - torch.float32 - ) - new_weight = new_layer.weight.layout_tensor.float8_data.to( + original_weight = original_layer.weight.tensor_impl.float8_data.to( torch.float32 ) + new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32) else: - original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( + original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( torch.float32 ) - new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( - torch.float32 + new_weight = ( + new_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( + torch.float32 + ) ) assert torch.allclose( diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index b4776f95e1..f228c4c0c7 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -9,7 +9,7 @@ run_tests, ) from torchao.dtypes.floatx import ( - FloatxTensorCoreAQTLayout, + FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType, to_scaled_tc_floatx, from_scaled_tc_floatx, @@ 
-28,7 +28,7 @@ _Floatx_DTYPES = [(3, 2), (2, 2)] -class TestFloatxTensorCoreAQTLayout(TestCase): +class TestFloatxTensorCoreAQTTensorImpl(TestCase): @parametrize("device", _DEVICES) def test_pack_tc_fp6_correctness(self, device): x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) @@ -82,10 +82,10 @@ def test_to_copy_device(self, ebits, mbits): scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) layout_type = FloatxTensorCoreLayoutType(ebits, mbits) - floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() - assert floatx_layout_tensor.device.type == "cuda" - floatx_layout_tensor = floatx_layout_tensor.cpu() - assert floatx_layout_tensor.device.type == "cpu" + floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda() + assert floatx_tensor_impl.device.type == "cuda" + floatx_tensor_impl = floatx_tensor_impl.cpu() + assert floatx_tensor_impl.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias): torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout) +instantiate_parametrized_tests(TestFloatxTensorCoreAQTTensorImpl) if __name__ == "__main__": diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index f3fa41c643..c1177d2d4a 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTLayout, + PlainAQTTensorImpl, PlainLayoutType, - TensorCoreTiledAQTLayout, + TensorCoreTiledAQTTensorImpl, TensorCoreTiledLayoutType, MappingType, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index be8f2f954e..46799b4916 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1051,7 +1051,7 @@ def forward(self, x): self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'") + @unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'") @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index e27bf6497a..8d4be52dc0 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,7 @@ SemiSparseLayoutType, TensorCoreTiledLayoutType, Float8LayoutType, - Float8AQTLayout, + Float8AQTTensorImpl, MarlinSparseLayoutType, ) @@ -33,6 +33,6 @@ "SemiSparseLayoutType", "TensorCoreTiledLayoutType", "Float8LayoutType", - "Float8AQTLayout", + "Float8AQTTensorImpl", "MarlinSparseLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c2c8e3c0b0..0fa864f8b0 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -51,17 +51,17 @@ aten = torch.ops.aten ############################### -# Base Layout Tensor Subclass # +# Base Tensor Impl Subclass # ############################### -class AQTLayout(TorchAOBaseTensor): +class AQTTensorImpl(TorchAOBaseTensor): """ - Base class for the layout tensor 
for `AffineQuantizedTensor` + Base class for the tensor impl for `AffineQuantizedTensor` """ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the layout Tensor + """Get the plain (unpacked) Tensor for the tensor impl Returns data, scale and zero_point - Can be overwritten if other types of AQTLayout Tensor has different numbers of plain tensors + Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors """ pass @@ -76,7 +76,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - """ Construct a Layout from data, scale, zero_point and the layout_type""" + """ Construct a TensorImpl from data, scale, zero_point and the layout_type""" pass def __repr__(self): @@ -131,7 +131,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): regardless of the internal representation's type or orientation. fields: - layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, + tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam @@ -151,7 +151,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - layout_tensor: AQTLayout, + tensor_impl: AQTTensorImpl, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -161,9 +161,9 @@ def __new__( strides=None, ): kwargs = {} - kwargs["device"] = layout_tensor.device + kwargs["device"] = tensor_impl.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout + kwargs.get("layout") if kwargs.get("layout", False) else tensor_impl.layout ) kwargs["dtype"] = dtype if strides is not None: @@ -173,7 +173,7 @@ def __new__( def __init__( self, - layout_tensor: AQTLayout, + tensor_impl: AQTTensorImpl, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -182,7 +182,7 @@ def __init__( dtype=None, strides=None, ): - self.layout_tensor = layout_tensor + self.tensor_impl = tensor_impl self.block_size = block_size self.quant_min = quant_min self.quant_max = quant_max @@ -190,12 +190,12 @@ def __init__( def __repr__(self): return ( - f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, " + f"{self.__class__.__name__}(tensor_impl={self.tensor_impl}, block_size={self.block_size}, " f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): - return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, layout_tensor_dtype={self.layout_tensor.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, tensor_impl_dtype={self.tensor_impl.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -203,10 +203,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor from torchao.dtypes.floatx import FloatxTensorCoreLayoutType if isinstance(self.layout_type, 
FloatxTensorCoreLayoutType): - int_data, scale = self.layout_tensor.get_plain() + int_data, scale = self.tensor_impl.get_plain() return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: - data, scale, zero_point = self.layout_tensor.get_plain() + data, scale, zero_point = self.tensor_impl.get_plain() dq = dequantize_affine( data, self.block_size, @@ -232,16 +232,16 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") def __tensor_flatten__(self): - return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - layout_tensor = tensor_data_dict["layout_tensor"] + tensor_impl = tensor_data_dict["tensor_impl"] block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes return cls( - layout_tensor, + tensor_impl, block_size, shape if outer_size is None else outer_size, quant_min, @@ -289,10 +289,10 @@ def from_hp_to_intx( # Note: output will be uint8 tensor for sub byte tensors for now data = layout_type.post_process(data) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, quant_min, @@ -324,10 +324,10 @@ def from_hp_to_intx_static( int_data = layout_type.post_process(int_data) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, quant_min, @@ -410,10 +410,10 @@ def from_hp_to_fpx( floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) floatx_packed = layout_type.post_process(floatx_unpacked) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(floatx_packed, scale, None, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, dtype=input_float.dtype @@ -421,13 +421,13 @@ def from_hp_to_fpx( @property def layout_type(self) -> LayoutType: - return self.layout_tensor.layout_type + return self.tensor_impl.layout_type def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.layout_tensor.to(device), + self.tensor_impl.to(device), self.block_size, self.shape, self.quant_min, @@ -438,7 +438,7 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.layout_tensor), + fn(self.tensor_impl), self.block_size, self.shape, self.quant_min, @@ -464,10 +464,10 @@ def _apply_fn_to_data(self, fn): ###################################################### -# 
LayoutType and Layout Tensor Subclass Registration # +# LayoutType and TensorImpl Subclass Registration # ###################################################### -register_layout_cls = AffineQuantizedTensor.register_layout_cls -get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor +register_layout = AffineQuantizedTensor.register_layout +get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor @dataclass(frozen=True) class SemiSparseLayoutType(LayoutType): @@ -548,10 +548,10 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return w_24.t() -@register_layout_cls(PlainLayoutType) -class PlainAQTLayout(AQTLayout): +@register_layout(PlainLayoutType) +class PlainAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + TensorImpl storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point tensors directly as plain tensors. fields: @@ -645,12 +645,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) else: - raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") raise NotImplementedError( - f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" + f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -672,10 +672,10 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) -@register_layout_cls(SemiSparseLayoutType) -class SemiSparseAQTLayout(PlainAQTLayout): +@register_layout(SemiSparseLayoutType) +class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ - Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor + TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor """ @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -687,7 +687,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" + f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) def get_plain(self): @@ -712,8 +712,8 @@ def from_plain( int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, layout_type) -@register_layout_cls(BlockSparseLayoutType) -class BlockSparseAQTLayout(PlainAQTLayout): +@register_layout(BlockSparseLayoutType) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] @@ -849,13 +849,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return args[0].bsr_values.shape[0] raise NotImplementedError( - f"BlockSparseAQTLayout dispatch: attempting to run 
{func}, this is not supported" + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) -@register_layout_cls(MarlinSparseLayoutType) -class MarlinSparseAQTLayout(AQTLayout): +@register_layout(MarlinSparseLayoutType) +class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + TensorImpl storage class for sparse_marlin_24 layout for affine quantized tensor. Can be used with 4 bits and 8 bits quantization. @@ -922,7 +922,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported" + f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) def __tensor_flatten__(self): @@ -1022,10 +1022,10 @@ def _apply_fn_to_data(self, fn): return self -@register_layout_cls(Float8LayoutType) -class Float8AQTLayout(AQTLayout): +@register_layout(Float8LayoutType) +class Float8AQTTensorImpl(AQTTensorImpl): """ - Layout storage class for float8 layout for affine quantized tensor + TensorImpl storage class for float8 tensor impl for affine quantized tensor """ float8_data: torch.Tensor scale: torch.Tensor @@ -1112,12 +1112,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) + return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) else: - raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") else: raise NotImplementedError( - f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported" + f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1136,9 +1136,9 @@ def from_plain( zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): - """ Main entrypoint for constructing Float8Layout Tensor""" - assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" + """ Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 TensorImpl must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): @@ -1151,10 +1151,10 @@ def __repr__(self): f"layout_type={layout_type})") -@register_layout_cls(TensorCoreTiledLayoutType) -class TensorCoreTiledAQTLayout(AQTLayout): +@register_layout(TensorCoreTiledLayoutType) +class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + TensorImpl storage class for tensor_core_tiled tensor impl for affine quantized tensor, this is for int4 only, it stores the original tensor of dimension [n][k] 
(int32 dtype) as packed weight of 4-d tensor of dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] @@ -1230,7 +1230,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] if not is_device("cuda", device): - raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") + raise ValueError(f"TensorCoreTiledAQTTensorImpl is only available for cuda device, can't convert to {device}") return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1265,7 +1265,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]) raise NotImplementedError( - f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1311,14 +1311,14 @@ def get_layout_type(self) -> LayoutType: def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.layout_tensor.dtype == torch.int8 and + aqt.tensor_impl.dtype == torch.int8 and (aqt.quant_min is None or aqt.quant_min == -128) and (aqt.quant_max is None or aqt.quant_max == 127) ) def _aqt_is_int8_reduced_range(aqt): return ( - aqt.layout_tensor.dtype == torch.int8 and + aqt.tensor_impl.dtype == torch.int8 and aqt.quant_min == -127 and (aqt.quant_max is None or aqt.quant_max == 127) ) @@ -1327,7 +1327,7 @@ def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.layout_tensor.dtype == torch.int32 and + aqt.tensor_impl.dtype == torch.int32 and aqt.quant_min == 0 and aqt.quant_max == 15 ) @@ -1364,10 +1364,10 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # value of a float 16, (which results in a value of inf even if multiplying # by the other scale would bring it within the expected range) - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals_int8_t = weight_tensor.layout_tensor.int_data.contiguous().t() - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) @@ -1395,10 +1395,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weig ) def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals_int8 = weight_tensor.layout_tensor.int_data - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8 = weight_tensor.tensor_impl.int_data + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -1427,10 +1427,10 @@ def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, def 
_linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals = weight_tensor.layout_tensor - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() @@ -1456,7 +1456,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): # input is native bfloat16 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.dtype == torch.bfloat16 and - # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor + # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and _aqt_is_tensor_core_tile_uint4(weight_tensor) and weight_tensor.dtype == torch.bfloat16 and @@ -1478,8 +1478,8 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = input_tensor # weight is packed from padded (out_features, in_features) weight tensor # (same dimension requirement as F.linear weight) - packed_weight = weight_tensor.layout_tensor.packed_weight - scale_and_zero = weight_tensor.layout_tensor.scale_and_zero + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero orig_act_size = act_mat.size() orig_dtype = act_mat.dtype @@ -1522,11 +1522,11 @@ def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.layout_tensor.scale) + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() - scale = weight_tensor.layout_tensor.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), @@ -1580,8 +1580,8 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): weight.layout_type.ebits, weight.layout_type.mbits, act_reshaped, - weight.layout_tensor.packed_floatx_data, - weight.layout_tensor.scale, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, splitK=splitK, ) @@ -1599,7 +1599,7 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and isinstance(aqt.layout_type, Float8LayoutType) - and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) return check_aqt(input_tensor) and check_aqt(weight_tensor) @@ -1624,14 +1624,14 @@ def _linear_fp8_act_fp8_weight_impl( out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing - w_layout = weight_tensor.layout_tensor - assert not w_layout.transposed, "Weight tensor must be contiguous" - w_data = w_layout.float8_data - 
w_scale = w_layout.scale + w_tensor_impl = weight_tensor.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale # Input tensor preprocessing - inpt_data = input_tensor.layout_tensor.float8_data - input_scale = input_tensor.layout_tensor.scale + inpt_data = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) @@ -1667,7 +1667,7 @@ def _linear_fp_act_fp8_weight_check( # weight is float8 quantized affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and isinstance(weight_tensor.layout_type, Float8LayoutType) - and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) ) @@ -1694,11 +1694,11 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b assert isinstance(weight_tensor, AffineQuantizedTensor) - sparse_w_int4 = weight_tensor.layout_tensor.int_data - scale = weight_tensor.layout_tensor.scale - meta = weight_tensor.layout_tensor.meta - original_shape = weight_tensor.layout_tensor.original_shape - num_bits = weight_tensor.layout_tensor.num_bits + sparse_w_int4 = weight_tensor.tensor_impl.int_data + scale = weight_tensor.tensor_impl.scale + meta = weight_tensor.tensor_impl.meta + original_shape = weight_tensor.tensor_impl.original_shape + num_bits = weight_tensor.tensor_impl.num_bits # Folds batch dimension into the first dimension input_2d = input_tensor.view(-1, input_tensor.shape[-1]) @@ -1845,7 +1845,7 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -1863,7 +1863,7 @@ def _(func, types, args, kwargs): # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor @@ -1872,12 +1872,12 @@ def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) if len(shape) == 
1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 0eb1e70529..39461d8869 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1 @@ -from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index dcbfd5f69c..5a9aab0357 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -10,7 +10,7 @@ ) from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass -from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls +from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout aten = torch.ops.aten @@ -354,14 +354,14 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> @dataclass(frozen=True) class FloatxTensorCoreLayoutType(LayoutType): - """Layout type for FloatxTensorCoreAQTLayout + """Layout type for FloatxTensorCoreAQTTensorImpl """ ebits: int mbits: int -@register_layout_cls(FloatxTensorCoreLayoutType) -class FloatxTensorCoreAQTLayout(AQTLayout): - """FloatxTensorCoreAQTLayout represents a Tensor with dtype floatx(ebits=a, mbits=b), +@register_layout(FloatxTensorCoreLayoutType) +class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): + """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the uint8 unpacked data (the output of `quantize_affine_floatx` operator) @@ -377,10 +377,10 @@ class FloatxTensorCoreAQTLayout(AQTLayout): If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be (M, N // 8 * nbit) - FloatxTensorCoreAQTLayout.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of + FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 
00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FloatxTensorCoreAQTLayout tensor - FloatxTensorCoreAQTLayout.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) + it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor + FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ def __new__( cls, @@ -483,7 +483,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"FloatxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" + f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index a0cd687f53..eb63fc6191 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -8,7 +8,7 @@ LayoutType, ) from torchao.utils import TorchAOBaseTensor -from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls +from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 aten = torch.ops.aten @@ -194,8 +194,8 @@ class UintxLayoutType(LayoutType): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) -@register_layout_cls(UintxLayoutType) -class UintxAQTLayout(PlainAQTLayout): +@register_layout(UintxLayoutType) +class UintxAQTTensorImpl(PlainAQTTensorImpl): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 52a9c57191..4a6b3a0bb8 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -4,11 +4,11 @@ """ Base class for different LayoutType, should not be instantiated directly -used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles -for int4 tensor core tiled layout +used to allow users to pass around configurations for the tensor impl, e.g. inner_k_tiles +for int4 tensor core tiled tensor impl -Note: layout is an abstraction not only for custom data representation, it is also used for how the -layout interacts with different operators, e.g. the same data representation can have different +Note: TensorImpl is an abstraction not only for custom data representation, it is also used for how the +tensorImpl interacts with different operators, e.g. the same data representation can have different behaviors when running the same operator, e.g. transpose, quantized_linear. 
""" @dataclass(frozen=True) diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index cd5a93b561..f410a11cd8 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTLayout, + PlainAQTTensorImpl, PlainLayoutType, - TensorCoreTiledAQTLayout, + TensorCoreTiledAQTTensorImpl, TensorCoreTiledLayoutType, MappingType, ) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index a5568c4e17..7439c982b4 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -348,7 +348,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): ) q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): - w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t() + w_vals_int8 = w_qtensor.original_weight_tensor.tensor_impl.int_data.contiguous().t() res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8) print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") @@ -399,8 +399,8 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): orig_dtype = act_mat.dtype orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) - y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2) - y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale + y = (act_mat*w_qtensor.tensor_impl.int_data.t().unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.tensor_impl.scale if bias is not None: y += bias return y.to(orig_dtype) @@ -420,7 +420,7 @@ class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeig @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape - y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale) + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.tensor_impl.int_data.t()*w_qtensor.tensor_impl.scale) y=y.reshape(*orig_shape[:-1], y.shape[-1]) if bias is not None: y += bias diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c41425062..aef873e2fd 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -886,7 +886,7 @@ def fpx_weight_only(ebits: int, mbits: int): e.g. fp6_e3_m2, fp6_e2_m3, ... 
The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTLayout` + For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTTensorImpl` This is experimental, will be merged with `to_affine_quantized_floatx` in the future diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py index 4ebdf432e3..4c55725539 100644 --- a/torchao/sparsity/marlin/utils.py +++ b/torchao/sparsity/marlin/utils.py @@ -9,7 +9,7 @@ class Marlin24Constants: MIN_THREAD_N: int = 128 MAX_PARALLEL: int = 64 - # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTLayout + # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTTensorImpl SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4, 8]) SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 32, 64, 128]) const = Marlin24Constants() diff --git a/torchao/utils.py b/torchao/utils.py index a0302cabe6..36bc1be36b 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -392,34 +392,34 @@ class MyTensor(torch.Tensor): kwarg_types = {k: type(arg) for k, arg in kwargs} raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}") -def _register_layout_cls(cls: Callable, layout_type_class: Callable): +def _register_layout(cls: Callable, layout_type_class: Callable): """Helper function for layout registrations, this is used to implement - register_layout_cls decorator for each tensor subclass, see aqt.py for example usage + register_layout decorator for each tensor subclass, see aqt.py for example usage Args: cls: Tensor subclass type layout_type_class: the class type of subclass of `LayoutType`, e.g. 
`PlainLayoutType` Returns: - a decorator that registers the layout tensor constructor in the table + a decorator that registers the tensor impl constructor in the table """ # cls._LAYOUT_CONSTRUCTOR_TABLE is a map from layout_type_class like TensorCoreTiledLayout - # to layout class constructor like TensorCoreTiledAQTLayout.from_plain that can construct a layout_tensor + # to tensor_impl class constructor like TensorCoreTiledAQTTensorImpl.from_plain that can construct a tensor_impl # from plain data like (quantized, unpacked) `data`, `scale`, `zero_point` if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): cls._LAYOUT_CONSTRUCTOR_TABLE = {} - def decorator(layout_class): - cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = layout_class.from_plain + def decorator(tensor_impl_class): + cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = tensor_impl_class.from_plain if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this layout tensor subclass - torch.serialization.add_safe_globals([layout_type_class, layout_class]) - return layout_class + # Allow serialization to work for models uses this tensor impl subclass + torch.serialization.add_safe_globals([layout_type_class, tensor_impl_class]) + return tensor_impl_class return decorator -def _get_layout_tensor_constructor(cls: Callable, layout_type_class: Callable) -> Callable: - """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class` +def _get_tensor_impl_constructor(cls: Callable, layout_type_class: Callable) -> Callable: + """Get TensorImpl class constructor (TensorImplClass.from_plain) for `cls` based on `layout_type_class` `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Args: @@ -427,10 +427,10 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: Callable) - layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Returns: - layout tensor subclass constructor for the layout_type_class + tensor impl subclass constructor for the layout_type_class """ if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): - raise ValueError(f"no registered layout class constructor for: {cls}") + raise ValueError(f"no registered tensor_impl class constructor for: {cls}") if layout_type_class not in cls._LAYOUT_CONSTRUCTOR_TABLE: raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") @@ -457,25 +457,25 @@ def to(self, *args, **kwargs): def _(func, types, args, kwargs): ... - `register_layout_cls`: - register_layout_cls = MyTensor.register_layout_cls + `register_layout`: + register_layout = MyTensor.register_layout - @register_layout_cls(PlainLayoutType) - class PlainAQTLayout(...): + @register_layout(PlainLayoutType) + class PlainAQTTensorImpl(...): ... 
- `get_layout_tensor_constructor`: - get_layout_tensor_constructor = MyTensor.get_layout_tensor_constructor + `get_tensor_impl_constructor`: + get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # in constructor of MyTensor: - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) """ implements = classmethod(_implements) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) __torch_function__ = classmethod(_dispatch__torch_function__) - register_layout_cls = classmethod(_register_layout_cls) - get_layout_tensor_constructor = classmethod(_get_layout_tensor_constructor) + register_layout = classmethod(_register_layout) + get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor) def _get_to_kwargs(self, *args, **kwargs): # `torch._C._nn._parse_to` can't handle `layout` argument diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index bc85d26f5d..c714df2a7b 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -33,11 +33,11 @@ aten = torch.ops.aten ############################### -# Base Layout Tensor Subclass # +# Base Tensor Impl Subclass # ############################### -class MyDTypeLayout(torch.Tensor): +class MyDTypeTensorImpl(torch.Tensor): """ - Base class for the layout tensor for `MyDTypeTensor` + Base class for the tensor impl for `MyDTypeTensor` """ # get the original unpacked Tensors def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: @@ -53,7 +53,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a layout tensor from plain tensors and a layout_type, which main contain + """Construct a tensor impl from plain tensors and a layout_type, which main contain extra metadata for packing etc. 
""" pass @@ -82,17 +82,17 @@ class MyDTypeTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - layout_tensor: MyDTypeLayout, + tensor_impl: MyDTypeTensorImpl, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): kwargs = {} - kwargs["device"] = layout_tensor.device + kwargs["device"] = tensor_impl.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) - else layout_tensor.layout + else tensor_impl.layout ) kwargs["dtype"] = dtype kwargs["requires_grad"] = requires_grad @@ -100,12 +100,12 @@ def __new__( def __init__( self, - layout_tensor: MyDTypeLayout, + tensor_impl: MyDTypeTensorImpl, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): - self.layout_tensor = layout_tensor + self.tensor_impl = tensor_impl """__tensor_flatten__ and __tensor_unflatten__ are used to desugar the tensor into native Tensors/attributes and reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define @@ -118,7 +118,7 @@ def __tensor_flatten__(self): The first one contains any tensor fields such as int_data and scale as keys to a dictionary The second one contains all other non tensor type fields as values of a list """ - return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad] + return ["tensor_impl"], [self.shape, self.dtype, self.requires_grad] @classmethod def __tensor_unflatten__( @@ -129,10 +129,10 @@ def __tensor_unflatten__( tensor_data_dict contains the tensor fields of the class as a dictionary tensor_attributes contains all other non tensor type fields """ - layout_tensor = tensor_data_dict["layout_tensor"] + tensor_impl = tensor_data_dict["tensor_impl"] shape, dtype, requires_grad = tensor_attributes return cls( - layout_tensor, + tensor_impl, shape if outer_size is None else outer_size, dtype=dtype, requires_grad=requires_grad, @@ -152,25 +152,25 @@ def from_float( dtype = torch.int16 scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(int_data, scale, layout_type) - return cls(layout_tensor, input_float.shape) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(int_data, scale, layout_type) + return cls(tensor_impl, input_float.shape) """[Optional] We can overwrite layout property of the Tensor to represent different packing formats """ @property def layout_type(self) -> LayoutType: - return self.layout_tensor.layout_type + return self.tensor_impl.layout_type def dequantize(self, output_dtype=None): """We can define a dequantize method to convert the quantized tensor to a floating point tensor""" if output_dtype is None: output_dtype = torch.get_default_dtype() - int_data, scale = self.layout_tensor.get_plain() + int_data, scale = self.tensor_impl.get_plain() transposed = False block_size = (1, int_data.shape[-1]) - if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed: + if hasattr(self.tensor_impl, "transposed") and self.tensor_impl.transposed: transposed = True res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype) if transposed: @@ -186,10 +186,10 @@ def __repr__(self): def _apply_fn_to_data(self, fn): """ Used for implementing aten ops by applying them only to the relevant tensor atributes - In 
this case we only want to call things like to() or view() on the layout tensor + In this case we only want to call things like to() or view() on the tensor impl """ return self.__class__( - fn(self.layout_tensor), + fn(self.tensor_impl), self.shape, self.dtype, ) @@ -206,14 +206,14 @@ def _apply_fn_to_data(self, fn): """ ###################################################### -# LayoutType and Layout Tensor Subclass Registration # +# LayoutType and TensorImpl Subclass Registration # ###################################################### -register_layout_cls = MyDTypeTensor.register_layout_cls -get_layout_tensor_constructor = MyDTypeTensor.get_layout_tensor_constructor +register_layout = MyDTypeTensor.register_layout +get_tensor_impl_constructor = MyDTypeTensor.get_tensor_impl_constructor -@register_layout_cls(PlainLayoutType) -class PlainMyDTypeLayout(MyDTypeLayout): +@register_layout(PlainLayoutType) +class PlainMyDTypeTensorImpl(MyDTypeTensorImpl): def __new__( cls, int_data: torch.Tensor, @@ -261,7 +261,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a layout tensor from plain tensors and a layout_type, which main contain + """Construct a tensor impl from plain tensors and a layout_type, which main contain extra metadata for packing etc. """ assert isinstance(layout_type, PlainLayoutType) @@ -292,11 +292,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.split.Tensor: int_data_list = func(args[0].int_data, *args[1:], **kwargs) scale_list = func(args[0].scale, *args[1:], **kwargs) - out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] + out = [PlainMyDTypeTensorImpl(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] return out elif func is aten.empty_like.default: int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs) - return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) + return PlainMyDTypeTensorImpl(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: @@ -304,16 +304,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) + return PlainMyDTypeTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) else: - raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: - return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) + return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeTensorImpl(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) # Tensor parallel support END raise NotImplementedError( - f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported" + 
f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, this is not supported" ) ##################################################### diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py index b702ac4f91..59e72efb6b 100644 --- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -43,8 +43,8 @@ def _quantize( dtype = torch.int16 scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = (input_float / scale).to(torch.int8) - layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type)) - return layout_tensor_ctr(int_data, scale, layout_type) + tensor_impl_ctr = cls.get_tensor_impl_constructor(type(layout_type)) + return tensor_impl_ctr(int_data, scale, layout_type) @classmethod def from_float( @@ -71,9 +71,9 @@ def forward( input_float: torch.Tensor, layout_type: LayoutType, ) -> "MyTrainableDTypeTensor": - layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type) + tensor_impl = MyTrainableDTypeTensor._quantize(input_float, layout_type) return MyTrainableDTypeTensor( - layout_tensor, + tensor_impl, input_float.shape, requires_grad=True, ) @@ -137,15 +137,15 @@ def _(func, types, args, kwargs): """ assert len(args) == 2 assert isinstance(args[0], MyTrainableDTypeTensor) - assert args[0].layout_tensor.int_data.dtype == torch.int8 + assert args[0].tensor_impl.int_data.dtype == torch.int8 float0 = args[0].dequantize() float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1] new_value = torch.add(float0, float1, **kwargs) - new_layout_tensor = MyTrainableDTypeTensor._quantize( + new_tensor_impl = MyTrainableDTypeTensor._quantize( new_value, - args[0].layout_tensor.get_layout_type(), + args[0].tensor_impl.get_layout_type(), ) - args[0].layout_tensor = new_layout_tensor + args[0].tensor_impl = new_tensor_impl return return_and_correct_aliasing(func, args, kwargs, args[0]) @implements(aten.add.Tensor) @@ -190,7 +190,7 @@ def main(): loss = loss_fn(output, target) loss.backward() if VERBOSE: - weight = m.linear.weight.layout_tensor.int_data.flatten()[:3] + weight = m.linear.weight.tensor_impl.int_data.flatten()[:3] weight_grad = m.linear.weight.grad.flatten()[:3] print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight)) optimizer.step() diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index 0ed3bc9a29..84de815a36 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -24,14 +24,14 @@ def _(func, types, args, kwargs): @implements([aten.split.Tensor]) def _(func, types, args, kwargs): - layout_tensor_list = func(args[0].layout_tensor, *args[1:], **kwargs) - out = [MyDTypeTensorTP(layout_tensor, layout_tensor.shape) for layout_tensor in layout_tensor_list] + tensor_impl_list = func(args[0].tensor_impl, *args[1:], **kwargs) + out = [MyDTypeTensorTP(tensor_impl, tensor_impl.shape) for tensor_impl in tensor_impl_list] return out @implements([aten.empty_like.default]) def _(func, types, args, kwargs): - empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs) - return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape) + empty_like_tensor_impl = func(args[0].tensor_impl, *args[1:], **kwargs) + return MyDTypeTensorTP(empty_like_tensor_impl, 
empty_like_tensor_impl.shape) @implements(aten.slice.Tensor) def _(func, types, args, kwargs): @@ -41,7 +41,7 @@ def _(func, types, args, kwargs): end = self.shape[dim] shape = list(self.shape) shape[dim] = end - start - return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), shape, self.dtype) + return self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), shape, self.dtype) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) @@ -49,10 +49,10 @@ def _(func, types, args, kwargs): x, shape = args if tuple(x.shape) == tuple(shape): - return x.__class__(x.layout_tensor, x.shape, x.dtype) + return x.__class__(x.tensor_impl, x.shape, x.dtype) if len(shape) == 1 and shape[0] == -1: - return x.__class__(x.layout_tensor, (x.numel(),), x.dtype) + return x.__class__(x.tensor_impl, (x.numel(),), x.dtype) raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") @@ -60,7 +60,7 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] - new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype) + new = tensor.__class__(tensor.tensor_impl.t(), shape, tensor.dtype) return return_and_correct_aliasing(func, args, kwargs, new) @implements(aten.addmm.default) From 745085fb77f8ad2100d85799e7d1f5a66cc69061 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 8 Oct 2024 15:34:13 -0700 Subject: [PATCH 2/6] Revert "Rename Layout -> TensorImpl" (#1040) Revert "Rename Layout -> TensorImpl (#1028)" This reverts commit cc8bf8595dfbc6e5e2ca3f18bbd6e9384e794c04. --- .github/workflows/regression_test.yml | 2 +- benchmarks/benchmark_fp6.py | 2 +- test/dtypes/test_affine_quantized_float.py | 14 +- test/dtypes/test_floatx.py | 14 +- test/hqq/test_hqq_affine.py | 4 +- test/integration/test_integration.py | 2 +- torchao/dtypes/__init__.py | 4 +- torchao/dtypes/affine_quantized_tensor.py | 210 +++++++++--------- torchao/dtypes/floatx/__init__.py | 2 +- torchao/dtypes/floatx/floatx.py | 18 +- torchao/dtypes/uintx/uintx.py | 6 +- torchao/dtypes/utils.py | 8 +- torchao/prototype/hqq/example.py | 4 +- torchao/quantization/autoquant.py | 8 +- torchao/quantization/quant_api.py | 2 +- torchao/sparsity/marlin/utils.py | 2 +- torchao/utils.py | 46 ++-- .../my_dtype_tensor_subclass.py | 64 +++--- .../my_trainable_tensor_subclass.py | 18 +- .../developer_api_guide/tensor_parallel.py | 16 +- 20 files changed, 223 insertions(+), 223 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 13cd4e2e76..3aee8dbfb5 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -35,7 +35,7 @@ jobs: gpu-arch-version: "12.1" - name: CUDA 2.4 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.4.1' + torch-spec: 'torch==2.4.0' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CUDA Nightly (Oct 1) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 509ea6e86c..425507bd95 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -2,7 +2,7 @@ import pandas as pd import torch.nn.functional as F from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType +from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import 
tqdm diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 761b233fcd..621e3596e0 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -210,18 +210,18 @@ def test_serialization(self, mode: str): # Compare weights if mode == "weight-only": - original_weight = original_layer.weight.tensor_impl.float8_data.to( + original_weight = original_layer.weight.layout_tensor.float8_data.to( + torch.float32 + ) + new_weight = new_layer.weight.layout_tensor.float8_data.to( torch.float32 ) - new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32) else: - original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( + original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( torch.float32 ) - new_weight = ( - new_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( - torch.float32 - ) + new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( + torch.float32 ) assert torch.allclose( diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index f228c4c0c7..b4776f95e1 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -9,7 +9,7 @@ run_tests, ) from torchao.dtypes.floatx import ( - FloatxTensorCoreAQTTensorImpl, + FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType, to_scaled_tc_floatx, from_scaled_tc_floatx, @@ -28,7 +28,7 @@ _Floatx_DTYPES = [(3, 2), (2, 2)] -class TestFloatxTensorCoreAQTTensorImpl(TestCase): +class TestFloatxTensorCoreAQTLayout(TestCase): @parametrize("device", _DEVICES) def test_pack_tc_fp6_correctness(self, device): x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) @@ -82,10 +82,10 @@ def test_to_copy_device(self, ebits, mbits): scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) layout_type = FloatxTensorCoreLayoutType(ebits, mbits) - floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda() - assert floatx_tensor_impl.device.type == "cuda" - floatx_tensor_impl = floatx_tensor_impl.cpu() - assert floatx_tensor_impl.device.type == "cpu" + floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() + assert floatx_layout_tensor.device.type == "cuda" + floatx_layout_tensor = floatx_layout_tensor.cpu() + assert floatx_layout_tensor.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias): torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestFloatxTensorCoreAQTTensorImpl) +instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout) if __name__ == "__main__": diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index c1177d2d4a..f3fa41c643 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTTensorImpl, + PlainAQTLayout, PlainLayoutType, - TensorCoreTiledAQTTensorImpl, + TensorCoreTiledAQTLayout, TensorCoreTiledLayoutType, MappingType, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 46799b4916..be8f2f954e 100644 --- 
a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1051,7 +1051,7 @@ def forward(self, x): self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'") + @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'") @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8d4be52dc0..e27bf6497a 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,7 @@ SemiSparseLayoutType, TensorCoreTiledLayoutType, Float8LayoutType, - Float8AQTTensorImpl, + Float8AQTLayout, MarlinSparseLayoutType, ) @@ -33,6 +33,6 @@ "SemiSparseLayoutType", "TensorCoreTiledLayoutType", "Float8LayoutType", - "Float8AQTTensorImpl", + "Float8AQTLayout", "MarlinSparseLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 0fa864f8b0..c2c8e3c0b0 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -51,17 +51,17 @@ aten = torch.ops.aten ############################### -# Base Tensor Impl Subclass # +# Base Layout Tensor Subclass # ############################### -class AQTTensorImpl(TorchAOBaseTensor): +class AQTLayout(TorchAOBaseTensor): """ - Base class for the tensor impl for `AffineQuantizedTensor` + Base class for the layout tensor for `AffineQuantizedTensor` """ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the tensor impl + """Get the plain (unpacked) Tensor for the layout Tensor Returns data, scale and zero_point - Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors + Can be overwritten if other types of AQTLayout Tensor has different numbers of plain tensors """ pass @@ -76,7 +76,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - """ Construct a TensorImpl from data, scale, zero_point and the layout_type""" + """ Construct a Layout from data, scale, zero_point and the layout_type""" pass def __repr__(self): @@ -131,7 +131,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): regardless of the internal representation's type or orientation. fields: - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, + layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, e.g. 
storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam @@ -151,7 +151,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - tensor_impl: AQTTensorImpl, + layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -161,9 +161,9 @@ def __new__( strides=None, ): kwargs = {} - kwargs["device"] = tensor_impl.device + kwargs["device"] = layout_tensor.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else tensor_impl.layout + kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout ) kwargs["dtype"] = dtype if strides is not None: @@ -173,7 +173,7 @@ def __new__( def __init__( self, - tensor_impl: AQTTensorImpl, + layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -182,7 +182,7 @@ def __init__( dtype=None, strides=None, ): - self.tensor_impl = tensor_impl + self.layout_tensor = layout_tensor self.block_size = block_size self.quant_min = quant_min self.quant_max = quant_max @@ -190,12 +190,12 @@ def __init__( def __repr__(self): return ( - f"{self.__class__.__name__}(tensor_impl={self.tensor_impl}, block_size={self.block_size}, " + f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, " f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): - return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, tensor_impl_dtype={self.tensor_impl.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, layout_tensor_dtype={self.layout_tensor.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -203,10 +203,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor from torchao.dtypes.floatx import FloatxTensorCoreLayoutType if isinstance(self.layout_type, FloatxTensorCoreLayoutType): - int_data, scale = self.tensor_impl.get_plain() + int_data, scale = self.layout_tensor.get_plain() return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: - data, scale, zero_point = self.tensor_impl.get_plain() + data, scale, zero_point = self.layout_tensor.get_plain() dq = dequantize_affine( data, self.block_size, @@ -232,16 +232,16 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") def __tensor_flatten__(self): - return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - tensor_impl = tensor_data_dict["tensor_impl"] + layout_tensor = tensor_data_dict["layout_tensor"] block_size, shape, quant_min, quant_max, 
zero_point_domain, dtype = tensor_attributes return cls( - tensor_impl, + layout_tensor, block_size, shape if outer_size is None else outer_size, quant_min, @@ -289,10 +289,10 @@ def from_hp_to_intx( # Note: output will be uint8 tensor for sub byte tensors for now data = layout_type.post_process(data) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) return cls( - tensor_impl, + layout_tensor, block_size, original_shape, quant_min, @@ -324,10 +324,10 @@ def from_hp_to_intx_static( int_data = layout_type.post_process(int_data) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, layout_type) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type) return cls( - tensor_impl, + layout_tensor, block_size, original_shape, quant_min, @@ -410,10 +410,10 @@ def from_hp_to_fpx( floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) floatx_packed = layout_type.post_process(floatx_unpacked) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, layout_type) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(floatx_packed, scale, None, layout_type) return cls( - tensor_impl, + layout_tensor, block_size, original_shape, dtype=input_float.dtype @@ -421,13 +421,13 @@ def from_hp_to_fpx( @property def layout_type(self) -> LayoutType: - return self.tensor_impl.layout_type + return self.layout_tensor.layout_type def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.tensor_impl.to(device), + self.layout_tensor.to(device), self.block_size, self.shape, self.quant_min, @@ -438,7 +438,7 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.tensor_impl), + fn(self.layout_tensor), self.block_size, self.shape, self.quant_min, @@ -464,10 +464,10 @@ def _apply_fn_to_data(self, fn): ###################################################### -# LayoutType and TensorImpl Subclass Registration # +# LayoutType and Layout Tensor Subclass Registration # ###################################################### -register_layout = AffineQuantizedTensor.register_layout -get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor +register_layout_cls = AffineQuantizedTensor.register_layout_cls +get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor @dataclass(frozen=True) class SemiSparseLayoutType(LayoutType): @@ -548,10 +548,10 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return w_24.t() -@register_layout(PlainLayoutType) -class PlainAQTTensorImpl(AQTTensorImpl): +@register_layout_cls(PlainLayoutType) +class PlainAQTLayout(AQTLayout): """ - TensorImpl storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point tensors directly as plain tensors. 
fields: @@ -645,12 +645,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) else: - raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -672,10 +672,10 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) -@register_layout(SemiSparseLayoutType) -class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): +@register_layout_cls(SemiSparseLayoutType) +class SemiSparseAQTLayout(PlainAQTLayout): """ - TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor + Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor """ @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -687,7 +687,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" ) def get_plain(self): @@ -712,8 +712,8 @@ def from_plain( int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, layout_type) -@register_layout(BlockSparseLayoutType) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): +@register_layout_cls(BlockSparseLayoutType) +class BlockSparseAQTLayout(PlainAQTLayout): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] @@ -849,13 +849,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return args[0].bsr_values.shape[0] raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported" ) -@register_layout(MarlinSparseLayoutType) -class MarlinSparseAQTTensorImpl(AQTTensorImpl): +@register_layout_cls(MarlinSparseLayoutType) +class MarlinSparseAQTLayout(AQTLayout): """ - TensorImpl storage class for sparse_marlin_24 layout for affine quantized tensor. + Layout storage class for sparse_marlin_24 layout for affine quantized tensor. Can be used with 4 bits and 8 bits quantization. 
@@ -922,7 +922,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported" ) def __tensor_flatten__(self): @@ -1022,10 +1022,10 @@ def _apply_fn_to_data(self, fn): return self -@register_layout(Float8LayoutType) -class Float8AQTTensorImpl(AQTTensorImpl): +@register_layout_cls(Float8LayoutType) +class Float8AQTLayout(AQTLayout): """ - TensorImpl storage class for float8 tensor impl for affine quantized tensor + Layout storage class for float8 layout for affine quantized tensor """ float8_data: torch.Tensor scale: torch.Tensor @@ -1112,12 +1112,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) + return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") else: raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1136,9 +1136,9 @@ def from_plain( zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): - """ Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(layout_type, Float8LayoutType), f"Float8 TensorImpl must be constructed from Float8LayoutType but got {layout_type}" + """ Main entrypoint for constructing Float8Layout Tensor""" + assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): @@ -1151,10 +1151,10 @@ def __repr__(self): f"layout_type={layout_type})") -@register_layout(TensorCoreTiledLayoutType) -class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): +@register_layout_cls(TensorCoreTiledLayoutType) +class TensorCoreTiledAQTLayout(AQTLayout): """ - TensorImpl storage class for tensor_core_tiled tensor impl for affine quantized tensor, this is for int4 only, + Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, it stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] @@ -1230,7 +1230,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] if not is_device("cuda", device): - raise ValueError(f"TensorCoreTiledAQTTensorImpl is only available for cuda device, can't convert to {device}") + raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") return 
self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1265,7 +1265,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]) raise NotImplementedError( - f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1311,14 +1311,14 @@ def get_layout_type(self) -> LayoutType: def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.tensor_impl.dtype == torch.int8 and + aqt.layout_tensor.dtype == torch.int8 and (aqt.quant_min is None or aqt.quant_min == -128) and (aqt.quant_max is None or aqt.quant_max == 127) ) def _aqt_is_int8_reduced_range(aqt): return ( - aqt.tensor_impl.dtype == torch.int8 and + aqt.layout_tensor.dtype == torch.int8 and aqt.quant_min == -127 and (aqt.quant_max is None or aqt.quant_max == 127) ) @@ -1327,7 +1327,7 @@ def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.tensor_impl.dtype == torch.int32 and + aqt.layout_tensor.dtype == torch.int32 and aqt.quant_min == 0 and aqt.quant_max == 15 ) @@ -1364,10 +1364,10 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # value of a float 16, (which results in a value of inf even if multiplying # by the other scale would bring it within the expected range) - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals_int8_t = weight_tensor.layout_tensor.int_data.contiguous().t() + w_scales = weight_tensor.layout_tensor.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) @@ -1395,10 +1395,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weig ) def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8 = weight_tensor.tensor_impl.int_data - w_scales = weight_tensor.tensor_impl.scale + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals_int8 = weight_tensor.layout_tensor.int_data + w_scales = weight_tensor.layout_tensor.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -1427,10 +1427,10 @@ def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals = weight_tensor.layout_tensor + w_scales = weight_tensor.layout_tensor.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() @@ -1456,7 +1456,7 
@@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): # input is native bfloat16 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.dtype == torch.bfloat16 and - # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor + # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and _aqt_is_tensor_core_tile_uint4(weight_tensor) and weight_tensor.dtype == torch.bfloat16 and @@ -1478,8 +1478,8 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = input_tensor # weight is packed from padded (out_features, in_features) weight tensor # (same dimension requirement as F.linear weight) - packed_weight = weight_tensor.tensor_impl.packed_weight - scale_and_zero = weight_tensor.tensor_impl.scale_and_zero + packed_weight = weight_tensor.layout_tensor.packed_weight + scale_and_zero = weight_tensor.layout_tensor.scale_and_zero orig_act_size = act_mat.size() orig_dtype = act_mat.dtype @@ -1522,11 +1522,11 @@ def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.layout_tensor.scale) # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() + scale = weight_tensor.layout_tensor.scale orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), @@ -1580,8 +1580,8 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): weight.layout_type.ebits, weight.layout_type.mbits, act_reshaped, - weight.tensor_impl.packed_floatx_data, - weight.tensor_impl.scale, + weight.layout_tensor.packed_floatx_data, + weight.layout_tensor.scale, splitK=splitK, ) @@ -1599,7 +1599,7 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and isinstance(aqt.layout_type, Float8LayoutType) - and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) return check_aqt(input_tensor) and check_aqt(weight_tensor) @@ -1624,14 +1624,14 @@ def _linear_fp8_act_fp8_weight_impl( out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing - w_tensor_impl = weight_tensor.tensor_impl - assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" - w_data = w_tensor_impl.float8_data - w_scale = w_tensor_impl.scale + w_layout = weight_tensor.layout_tensor + assert not w_layout.transposed, "Weight tensor must be contiguous" + w_data = w_layout.float8_data + w_scale = w_layout.scale # Input tensor preprocessing - inpt_data = input_tensor.tensor_impl.float8_data - input_scale = input_tensor.tensor_impl.scale + inpt_data = input_tensor.layout_tensor.float8_data + input_scale = input_tensor.layout_tensor.scale # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, 
inpt_data.shape[-1]) @@ -1667,7 +1667,7 @@ def _linear_fp_act_fp8_weight_check( # weight is float8 quantized affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and isinstance(weight_tensor.layout_type, Float8LayoutType) - and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) ) @@ -1694,11 +1694,11 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b assert isinstance(weight_tensor, AffineQuantizedTensor) - sparse_w_int4 = weight_tensor.tensor_impl.int_data - scale = weight_tensor.tensor_impl.scale - meta = weight_tensor.tensor_impl.meta - original_shape = weight_tensor.tensor_impl.original_shape - num_bits = weight_tensor.tensor_impl.num_bits + sparse_w_int4 = weight_tensor.layout_tensor.int_data + scale = weight_tensor.layout_tensor.scale + meta = weight_tensor.layout_tensor.meta + original_shape = weight_tensor.layout_tensor.original_shape + num_bits = weight_tensor.layout_tensor.num_bits # Folds batch dimension into the first dimension input_2d = input_tensor.view(-1, input_tensor.shape[-1]) @@ -1845,7 +1845,7 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -1863,7 +1863,7 @@ def _(func, types, args, kwargs): # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor @@ -1872,12 +1872,12 @@ def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) raise ValueError(f"{self.__class__.__name__} only supports .view() 
with same shape or shape=[-1]") diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 39461d8869..0eb1e70529 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1 @@ -from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index 5a9aab0357..dcbfd5f69c 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -10,7 +10,7 @@ ) from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass -from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout +from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls aten = torch.ops.aten @@ -354,14 +354,14 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> @dataclass(frozen=True) class FloatxTensorCoreLayoutType(LayoutType): - """Layout type for FloatxTensorCoreAQTTensorImpl + """Layout type for FloatxTensorCoreAQTLayout """ ebits: int mbits: int -@register_layout(FloatxTensorCoreLayoutType) -class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): - """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), +@register_layout_cls(FloatxTensorCoreLayoutType) +class FloatxTensorCoreAQTLayout(AQTLayout): + """FloatxTensorCoreAQTLayout represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the uint8 unpacked data (the output of `quantize_affine_floatx` operator) @@ -377,10 +377,10 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be (M, N // 8 * nbit) - FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of + FloatxTensorCoreAQTLayout.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 
00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor - FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) + it will then pack the weight and instantiate the FloatxTensorCoreAQTLayout tensor + FloatxTensorCoreAQTLayout.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ def __new__( cls, @@ -483,7 +483,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + f"FloatxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index eb63fc6191..a0cd687f53 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -8,7 +8,7 @@ LayoutType, ) from torchao.utils import TorchAOBaseTensor -from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout +from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 aten = torch.ops.aten @@ -194,8 +194,8 @@ class UintxLayoutType(LayoutType): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) -@register_layout(UintxLayoutType) -class UintxAQTTensorImpl(PlainAQTTensorImpl): +@register_layout_cls(UintxLayoutType) +class UintxAQTLayout(PlainAQTLayout): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 4a6b3a0bb8..52a9c57191 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -4,11 +4,11 @@ """ Base class for different LayoutType, should not be instantiated directly -used to allow users to pass around configurations for the tensor impl, e.g. inner_k_tiles -for int4 tensor core tiled tensor impl +used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles +for int4 tensor core tiled layout -Note: TensorImpl is an abstraction not only for custom data representation, it is also used for how the -tensorImpl interacts with different operators, e.g. the same data representation can have different +Note: layout is an abstraction not only for custom data representation, it is also used for how the +layout interacts with different operators, e.g. the same data representation can have different behaviors when running the same operator, e.g. transpose, quantized_linear. 
""" @dataclass(frozen=True) diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index f410a11cd8..cd5a93b561 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTTensorImpl, + PlainAQTLayout, PlainLayoutType, - TensorCoreTiledAQTTensorImpl, + TensorCoreTiledAQTLayout, TensorCoreTiledLayoutType, MappingType, ) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 7439c982b4..a5568c4e17 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -348,7 +348,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): ) q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): - w_vals_int8 = w_qtensor.original_weight_tensor.tensor_impl.int_data.contiguous().t() + w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t() res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8) print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") @@ -399,8 +399,8 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): orig_dtype = act_mat.dtype orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) - y = (act_mat*w_qtensor.tensor_impl.int_data.t().unsqueeze(0)).sum(dim=-2) - y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.tensor_impl.scale + y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale if bias is not None: y += bias return y.to(orig_dtype) @@ -420,7 +420,7 @@ class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeig @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape - y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.tensor_impl.int_data.t()*w_qtensor.tensor_impl.scale) + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale) y=y.reshape(*orig_shape[:-1], y.shape[-1]) if bias is not None: y += bias diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index aef873e2fd..6c41425062 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -886,7 +886,7 @@ def fpx_weight_only(ebits: int, mbits: int): e.g. fp6_e3_m2, fp6_e2_m3, ... 
The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTTensorImpl` + For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTLayout` This is experimental, will be merged with `to_affine_quantized_floatx` in the future diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py index 4c55725539..4ebdf432e3 100644 --- a/torchao/sparsity/marlin/utils.py +++ b/torchao/sparsity/marlin/utils.py @@ -9,7 +9,7 @@ class Marlin24Constants: MIN_THREAD_N: int = 128 MAX_PARALLEL: int = 64 - # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTTensorImpl + # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTLayout SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4, 8]) SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 32, 64, 128]) const = Marlin24Constants() diff --git a/torchao/utils.py b/torchao/utils.py index 36bc1be36b..a0302cabe6 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -392,34 +392,34 @@ class MyTensor(torch.Tensor): kwarg_types = {k: type(arg) for k, arg in kwargs} raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}") -def _register_layout(cls: Callable, layout_type_class: Callable): +def _register_layout_cls(cls: Callable, layout_type_class: Callable): """Helper function for layout registrations, this is used to implement - register_layout decorator for each tensor subclass, see aqt.py for example usage + register_layout_cls decorator for each tensor subclass, see aqt.py for example usage Args: cls: Tensor subclass type layout_type_class: the class type of subclass of `LayoutType`, e.g. 
`PlainLayoutType` Returns: - a decorator that registers the tensor impl constructor in the table + a decorator that registers the layout tensor constructor in the table """ # cls._LAYOUT_CONSTRUCTOR_TABLE is a map from layout_type_class like TensorCoreTiledLayout - # to tensor_impl class constructor like TensorCoreTiledAQTTensorImpl.from_plain that can construct a tensor_impl + # to layout class constructor like TensorCoreTiledAQTLayout.from_plain that can construct a layout_tensor # from plain data like (quantized, unpacked) `data`, `scale`, `zero_point` if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): cls._LAYOUT_CONSTRUCTOR_TABLE = {} - def decorator(tensor_impl_class): - cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = tensor_impl_class.from_plain + def decorator(layout_class): + cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = layout_class.from_plain if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this tensor impl subclass - torch.serialization.add_safe_globals([layout_type_class, tensor_impl_class]) - return tensor_impl_class + # Allow serialization to work for models uses this layout tensor subclass + torch.serialization.add_safe_globals([layout_type_class, layout_class]) + return layout_class return decorator -def _get_tensor_impl_constructor(cls: Callable, layout_type_class: Callable) -> Callable: - """Get TensorImpl class constructor (TensorImplClass.from_plain) for `cls` based on `layout_type_class` +def _get_layout_tensor_constructor(cls: Callable, layout_type_class: Callable) -> Callable: + """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class` `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Args: @@ -427,10 +427,10 @@ def _get_tensor_impl_constructor(cls: Callable, layout_type_class: Callable) -> layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Returns: - tensor impl subclass constructor for the layout_type_class + layout tensor subclass constructor for the layout_type_class """ if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): - raise ValueError(f"no registered tensor_impl class constructor for: {cls}") + raise ValueError(f"no registered layout class constructor for: {cls}") if layout_type_class not in cls._LAYOUT_CONSTRUCTOR_TABLE: raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") @@ -457,25 +457,25 @@ def to(self, *args, **kwargs): def _(func, types, args, kwargs): ... - `register_layout`: - register_layout = MyTensor.register_layout + `register_layout_cls`: + register_layout_cls = MyTensor.register_layout_cls - @register_layout(PlainLayoutType) - class PlainAQTTensorImpl(...): + @register_layout_cls(PlainLayoutType) + class PlainAQTLayout(...): ... 
- `get_tensor_impl_constructor`: - get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor + `get_layout_tensor_constructor`: + get_layout_tensor_constructor = MyTensor.get_layout_tensor_constructor # in constructor of MyTensor: - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) """ implements = classmethod(_implements) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) __torch_function__ = classmethod(_dispatch__torch_function__) - register_layout = classmethod(_register_layout) - get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor) + register_layout_cls = classmethod(_register_layout_cls) + get_layout_tensor_constructor = classmethod(_get_layout_tensor_constructor) def _get_to_kwargs(self, *args, **kwargs): # `torch._C._nn._parse_to` can't handle `layout` argument diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index c714df2a7b..bc85d26f5d 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -33,11 +33,11 @@ aten = torch.ops.aten ############################### -# Base Tensor Impl Subclass # +# Base Layout Tensor Subclass # ############################### -class MyDTypeTensorImpl(torch.Tensor): +class MyDTypeLayout(torch.Tensor): """ - Base class for the tensor impl for `MyDTypeTensor` + Base class for the layout tensor for `MyDTypeTensor` """ # get the original unpacked Tensors def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: @@ -53,7 +53,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a tensor impl from plain tensors and a layout_type, which main contain + """Construct a layout tensor from plain tensors and a layout_type, which main contain extra metadata for packing etc. 
""" pass @@ -82,17 +82,17 @@ class MyDTypeTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - tensor_impl: MyDTypeTensorImpl, + layout_tensor: MyDTypeLayout, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): kwargs = {} - kwargs["device"] = tensor_impl.device + kwargs["device"] = layout_tensor.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) - else tensor_impl.layout + else layout_tensor.layout ) kwargs["dtype"] = dtype kwargs["requires_grad"] = requires_grad @@ -100,12 +100,12 @@ def __new__( def __init__( self, - tensor_impl: MyDTypeTensorImpl, + layout_tensor: MyDTypeLayout, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): - self.tensor_impl = tensor_impl + self.layout_tensor = layout_tensor """__tensor_flatten__ and __tensor_unflatten__ are used to desugar the tensor into native Tensors/attributes and reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define @@ -118,7 +118,7 @@ def __tensor_flatten__(self): The first one contains any tensor fields such as int_data and scale as keys to a dictionary The second one contains all other non tensor type fields as values of a list """ - return ["tensor_impl"], [self.shape, self.dtype, self.requires_grad] + return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad] @classmethod def __tensor_unflatten__( @@ -129,10 +129,10 @@ def __tensor_unflatten__( tensor_data_dict contains the tensor fields of the class as a dictionary tensor_attributes contains all other non tensor type fields """ - tensor_impl = tensor_data_dict["tensor_impl"] + layout_tensor = tensor_data_dict["layout_tensor"] shape, dtype, requires_grad = tensor_attributes return cls( - tensor_impl, + layout_tensor, shape if outer_size is None else outer_size, dtype=dtype, requires_grad=requires_grad, @@ -152,25 +152,25 @@ def from_float( dtype = torch.int16 scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype) - tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) - tensor_impl = tensor_impl_ctr(int_data, scale, layout_type) - return cls(tensor_impl, input_float.shape) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(int_data, scale, layout_type) + return cls(layout_tensor, input_float.shape) """[Optional] We can overwrite layout property of the Tensor to represent different packing formats """ @property def layout_type(self) -> LayoutType: - return self.tensor_impl.layout_type + return self.layout_tensor.layout_type def dequantize(self, output_dtype=None): """We can define a dequantize method to convert the quantized tensor to a floating point tensor""" if output_dtype is None: output_dtype = torch.get_default_dtype() - int_data, scale = self.tensor_impl.get_plain() + int_data, scale = self.layout_tensor.get_plain() transposed = False block_size = (1, int_data.shape[-1]) - if hasattr(self.tensor_impl, "transposed") and self.tensor_impl.transposed: + if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed: transposed = True res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype) if transposed: @@ -186,10 +186,10 @@ def __repr__(self): def _apply_fn_to_data(self, fn): """ Used for implementing aten ops by applying them only to the relevant tensor atributes - In 
this case we only want to call things like to() or view() on the tensor impl + In this case we only want to call things like to() or view() on the layout tensor """ return self.__class__( - fn(self.tensor_impl), + fn(self.layout_tensor), self.shape, self.dtype, ) @@ -206,14 +206,14 @@ def _apply_fn_to_data(self, fn): """ ###################################################### -# LayoutType and TensorImpl Subclass Registration # +# LayoutType and Layout Tensor Subclass Registration # ###################################################### -register_layout = MyDTypeTensor.register_layout -get_tensor_impl_constructor = MyDTypeTensor.get_tensor_impl_constructor +register_layout_cls = MyDTypeTensor.register_layout_cls +get_layout_tensor_constructor = MyDTypeTensor.get_layout_tensor_constructor -@register_layout(PlainLayoutType) -class PlainMyDTypeTensorImpl(MyDTypeTensorImpl): +@register_layout_cls(PlainLayoutType) +class PlainMyDTypeLayout(MyDTypeLayout): def __new__( cls, int_data: torch.Tensor, @@ -261,7 +261,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a tensor impl from plain tensors and a layout_type, which main contain + """Construct a layout tensor from plain tensors and a layout_type, which main contain extra metadata for packing etc. """ assert isinstance(layout_type, PlainLayoutType) @@ -292,11 +292,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.split.Tensor: int_data_list = func(args[0].int_data, *args[1:], **kwargs) scale_list = func(args[0].scale, *args[1:], **kwargs) - out = [PlainMyDTypeTensorImpl(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] + out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] return out elif func is aten.empty_like.default: int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs) - return PlainMyDTypeTensorImpl(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) + return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: @@ -304,16 +304,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) + return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) else: - raise NotImplementedError(f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: - return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeTensorImpl(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) + return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) # Tensor parallel support END raise NotImplementedError( - f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, this is not supported" + 
f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported" ) ##################################################### diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py index 59e72efb6b..b702ac4f91 100644 --- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -43,8 +43,8 @@ def _quantize( dtype = torch.int16 scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = (input_float / scale).to(torch.int8) - tensor_impl_ctr = cls.get_tensor_impl_constructor(type(layout_type)) - return tensor_impl_ctr(int_data, scale, layout_type) + layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type)) + return layout_tensor_ctr(int_data, scale, layout_type) @classmethod def from_float( @@ -71,9 +71,9 @@ def forward( input_float: torch.Tensor, layout_type: LayoutType, ) -> "MyTrainableDTypeTensor": - tensor_impl = MyTrainableDTypeTensor._quantize(input_float, layout_type) + layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type) return MyTrainableDTypeTensor( - tensor_impl, + layout_tensor, input_float.shape, requires_grad=True, ) @@ -137,15 +137,15 @@ def _(func, types, args, kwargs): """ assert len(args) == 2 assert isinstance(args[0], MyTrainableDTypeTensor) - assert args[0].tensor_impl.int_data.dtype == torch.int8 + assert args[0].layout_tensor.int_data.dtype == torch.int8 float0 = args[0].dequantize() float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1] new_value = torch.add(float0, float1, **kwargs) - new_tensor_impl = MyTrainableDTypeTensor._quantize( + new_layout_tensor = MyTrainableDTypeTensor._quantize( new_value, - args[0].tensor_impl.get_layout_type(), + args[0].layout_tensor.get_layout_type(), ) - args[0].tensor_impl = new_tensor_impl + args[0].layout_tensor = new_layout_tensor return return_and_correct_aliasing(func, args, kwargs, args[0]) @implements(aten.add.Tensor) @@ -190,7 +190,7 @@ def main(): loss = loss_fn(output, target) loss.backward() if VERBOSE: - weight = m.linear.weight.tensor_impl.int_data.flatten()[:3] + weight = m.linear.weight.layout_tensor.int_data.flatten()[:3] weight_grad = m.linear.weight.grad.flatten()[:3] print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight)) optimizer.step() diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index 84de815a36..0ed3bc9a29 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -24,14 +24,14 @@ def _(func, types, args, kwargs): @implements([aten.split.Tensor]) def _(func, types, args, kwargs): - tensor_impl_list = func(args[0].tensor_impl, *args[1:], **kwargs) - out = [MyDTypeTensorTP(tensor_impl, tensor_impl.shape) for tensor_impl in tensor_impl_list] + layout_tensor_list = func(args[0].layout_tensor, *args[1:], **kwargs) + out = [MyDTypeTensorTP(layout_tensor, layout_tensor.shape) for layout_tensor in layout_tensor_list] return out @implements([aten.empty_like.default]) def _(func, types, args, kwargs): - empty_like_tensor_impl = func(args[0].tensor_impl, *args[1:], **kwargs) - return MyDTypeTensorTP(empty_like_tensor_impl, empty_like_tensor_impl.shape) + empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs) + return MyDTypeTensorTP(empty_like_layout_tensor, 
empty_like_layout_tensor.shape) @implements(aten.slice.Tensor) def _(func, types, args, kwargs): @@ -41,7 +41,7 @@ def _(func, types, args, kwargs): end = self.shape[dim] shape = list(self.shape) shape[dim] = end - start - return self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), shape, self.dtype) + return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), shape, self.dtype) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) @@ -49,10 +49,10 @@ def _(func, types, args, kwargs): x, shape = args if tuple(x.shape) == tuple(shape): - return x.__class__(x.tensor_impl, x.shape, x.dtype) + return x.__class__(x.layout_tensor, x.shape, x.dtype) if len(shape) == 1 and shape[0] == -1: - return x.__class__(x.tensor_impl, (x.numel(),), x.dtype) + return x.__class__(x.layout_tensor, (x.numel(),), x.dtype) raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") @@ -60,7 +60,7 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] - new = tensor.__class__(tensor.tensor_impl.t(), shape, tensor.dtype) + new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype) return return_and_correct_aliasing(func, args, kwargs, new) @implements(aten.addmm.default) From b67407afea74ede621a0b6574d397ca56e26fa34 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 8 Oct 2024 16:45:15 -0700 Subject: [PATCH 3/6] Basic SAM2 AutomaticMaskGeneration example server (#1039) --- examples/sam2_amg_server/README.md | 5 + examples/sam2_amg_server/amg_example.py | 133 +++++++++++++++++++ examples/sam2_amg_server/example.html | 57 +++++++++ examples/sam2_amg_server/sam2_hiera_l.yaml | 117 +++++++++++++++++ examples/sam2_amg_server/server.py | 142 +++++++++++++++++++++ 5 files changed, 454 insertions(+) create mode 100644 examples/sam2_amg_server/README.md create mode 100644 examples/sam2_amg_server/amg_example.py create mode 100644 examples/sam2_amg_server/example.html create mode 100644 examples/sam2_amg_server/sam2_hiera_l.yaml create mode 100644 examples/sam2_amg_server/server.py diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md new file mode 100644 index 0000000000..776c85390a --- /dev/null +++ b/examples/sam2_amg_server/README.md @@ -0,0 +1,5 @@ +To run this example you need to download the vit_h checkpoint and put it into a local folder named checkpoints + +You can find the checkpoint for vit_h here: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth + +To read the image you also need to install opencv-python: https://pypi.org/project/opencv-python/ diff --git a/examples/sam2_amg_server/amg_example.py b/examples/sam2_amg_server/amg_example.py new file mode 100644 index 0000000000..8305ea5340 --- /dev/null +++ b/examples/sam2_amg_server/amg_example.py @@ -0,0 +1,133 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import cv2 +import torch.utils.benchmark as benchmark + +from torch._inductor import config as inductorconfig +inductorconfig.triton.unique_kernel_names = True +inductorconfig.coordinate_descent_tuning = True +inductorconfig.coordinate_descent_check_all_directions = True + +def profiler_runner(path, fn, *args, **kwargs): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof: + result = fn(*args, **kwargs) + print(f"Saving trace under {path}") + 
prof.export_chrome_trace(path) + return result + +def show_anns(anns): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(False) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) + img[:,:,3] = 0 + ms = [] + for ann in sorted_anns: + m = ann['segmentation'] + ms.append(torch.as_tensor(m)) + color_mask = np.concatenate([np.random.random(3), [0.35]]) + img[m] = color_mask + ax.imshow(img) + return torch.stack(ms) + +image = cv2.imread('dog.jpg') +image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +# from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator +# +# sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth" +# model_type = "vit_h" +device = "cuda" +# +# sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint) +# sam.to(device=device) + +from sam2.build_sam import build_sam2 +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + +sam2_checkpoint = "checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" + +sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) +sam2.to(device=device) + +# mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=256) +mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=None) + +## NOTE: Causes numerical differences +## TODO: Implement mIoU to allow approximations. +# torch.set_float32_matmul_precision('high') +# torch.autocast("cuda", dtype=torch.bfloat16).__enter__() +## + +## TODO: Using CUDA graphs can cause numerical differences? +mask_generator.predictor.model.image_encoder = torch.compile( + mask_generator.predictor.model.image_encoder, + # mode="max-autotune-no-cudagraphs", + mode="max-autotune", + fullgraph=True, + dynamic=False, +) + +# mask_generator.predictor._predict = torch.compile( +# mask_generator.predictor._predict, +# mode="max-autotune-no-cudagraphs", +# fullgraph=True, +# dynamic=False, +# ) + +torch._dynamo.config.capture_dynamic_output_shape_ops = True +mask_generator._process_batch = torch.compile( + mask_generator._process_batch, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=True, +) + +# with torch.backends.cuda.sdp_kernel(enable_cudnn=False): #, enable_math=False, enable_mem_efficient=False): +with torch.backends.cuda.sdp_kernel(enable_cudnn=True): #, enable_math=False, enable_mem_efficient=False): + # Run thrice for warmup + masks = mask_generator.generate(image) + masks = mask_generator.generate(image) + masks = mask_generator.generate(image) + + # Save an example + plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100) + plt.imshow(image) + ms = show_anns(masks) + ms_ref = torch.load("dog_mask_fast.pt") + torch.testing.assert_allclose(ms, ms_ref) + print("Masks match reference") + # # torch.save(ms, "dog_mask_fast.pt") + plt.axis('off') + plt.tight_layout() + plt.savefig('dog_mask_fast.png', format='png') + + # Benchmark + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(10): + masks = mask_generator.generate(image) + end_event.record() + torch.cuda.synchronize() + print(start_event.elapsed_time(end_event) / 10.) 
+ + # Save a GPU trace + profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image) + + # Write out memory usage + max_memory_allocated_bytes = torch.cuda.max_memory_allocated() + _, total_memory = torch.cuda.mem_get_info() + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 + print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}") diff --git a/examples/sam2_amg_server/example.html b/examples/sam2_amg_server/example.html new file mode 100644 index 0000000000..0122c23de3 --- /dev/null +++ b/examples/sam2_amg_server/example.html @@ -0,0 +1,57 @@ +!DOCTYPE html> + + + + + Upload and Display Image from FastAPI Response + + + +

[The body of examples/sam2_amg_server/example.html is not reproduced faithfully here: the HTML markup was lost during extraction, leaving only the visible page text. The page is an image-upload form headed "Upload an Image and Display the Response", with a "Received Image Preview:" section that shows the PNG returned by the FastAPI server (presumably the /upload endpoint defined in server.py below).]
+ Received Image + + + + diff --git a/examples/sam2_amg_server/sam2_hiera_l.yaml b/examples/sam2_amg_server/sam2_hiera_l.yaml new file mode 100644 index 0000000000..918667f50c --- /dev/null +++ b/examples/sam2_amg_server/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + 
compile_image_encoder: False diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py new file mode 100644 index 0000000000..7e28f63238 --- /dev/null +++ b/examples/sam2_amg_server/server.py @@ -0,0 +1,142 @@ +import itertools +import uvicorn +import fire +import tempfile +import logging +import sys +import time +from pathlib import Path +from typing import List, Optional + +import torch +import torch._dynamo.config +import torch._inductor.config +from fastapi.responses import Response +from fastapi import FastAPI, File, UploadFile +from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware +from io import BytesIO +import shutil +from pydantic import BaseModel +import cv2 + +import matplotlib.pyplot as plt +import numpy as np + +from torch._inductor import config as inductorconfig +inductorconfig.triton.unique_kernel_names = True +inductorconfig.coordinate_descent_tuning = True +inductorconfig.coordinate_descent_check_all_directions = True + +# torch.set_float32_matmul_precision('high') + +def show_anns(anns): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(False) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) + img[:,:,3] = 0 + ms = [] + for ann in sorted_anns: + m = ann['segmentation'] + ms.append(torch.as_tensor(m)) + color_mask = np.concatenate([np.random.random(3), [0.35]]) + img[m] = color_mask + ax.imshow(img) + return torch.stack(ms) + +class GenerateRequest(BaseModel): + prompt: str + num_steps: Optional[int] = 30 + seed: Optional[int] = 42 + +def main(): + + from sam2.build_sam import build_sam2 + from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + + device = "cuda" + sam2_checkpoint = "checkpoints/sam2_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" + logging.basicConfig(level=logging.INFO) + + logging.info(f"Loading model: {sam2_checkpoint}") + sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + sam2.to(device=device) + + mask_generator = SAM2AutomaticMaskGenerator(sam2) #, points_per_batch=None) + + ## TODO: Using CUDA graphs can cause numerical differences? 
+ # mask_generator.predictor.model.image_encoder = torch.compile( + # mask_generator.predictor.model.image_encoder, + # # mode="max-autotune-no-cudagraphs", + # mode="max-autotune", + # fullgraph=True, + # dynamic=False, + # ) + + # torch._dynamo.config.capture_dynamic_output_shape_ops = True + # mask_generator._process_batch = torch.compile( + # mask_generator._process_batch, + # mode="max-autotune-no-cudagraphs", + # fullgraph=True, + # dynamic=True, + # ) + + example_image = cv2.imread('dog.jpg') + example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB) + t = time.time() + with torch.backends.cuda.sdp_kernel(enable_cudnn=True): + logging.info(f"Running warmup.") + masks = mask_generator.generate(example_image) + masks = mask_generator.generate(example_image) + masks = mask_generator.generate(example_image) + logging.info(f"Warmup took {time.time() - t}s.") + + app = FastAPI() + + # Allow all origins (you can restrict it in production) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.post("/upload") + async def upload_image(image: UploadFile = File(...)): + # Save the uploaded image to a temporary location + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}") + with open(temp_file.name, "wb") as b: + shutil.copyfileobj(image.file, b) + + # Read the image back into memory to send as response + example_image = cv2.imread(temp_file.name) + t = time.time() + with torch.backends.cuda.sdp_kernel(enable_cudnn=True): + masks = mask_generator.generate(example_image) + print(f"Took {time.time() - t} to generate a mask for input image.") + # Save an example + plt.figure(figsize=(example_image.shape[1]/100., example_image.shape[0]/100.), dpi=100) + plt.imshow(example_image) + show_anns(masks) + plt.axis('off') + plt.tight_layout() + plt.savefig(temp_file.name, format='png') + + # Read the image back into memory to send as response + with open(temp_file.name, "rb") as f: + image_data = f.read() + + # Return the image as a StreamingResponse + return StreamingResponse(BytesIO(image_data), media_type="image/png") + + + uvicorn.run(app, host="127.0.0.1", port=5000, log_level="info") + +if __name__ == "__main__": + fire.Fire(main) From 900f9acccb8bc3c102d3a7095d1b1f0355057bec Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 8 Oct 2024 17:05:56 -0700 Subject: [PATCH 4/6] Update README.md (#1036) --- torchao/prototype/awq/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/awq/README.md b/torchao/prototype/awq/README.md index e7b7f782f7..1040610db5 100644 --- a/torchao/prototype/awq/README.md +++ b/torchao/prototype/awq/README.md @@ -15,7 +15,9 @@ Evaluation perplexity numbers were calculated using the script in awq/example.py The following tests were performed using LM eval and groupsize = 128 -| Model | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge | + +| Model | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge | +|--------------------|--------------|------------|-----------------|------------|---------------| | Llama-3-8B-Instruct| bfloat16 | 10.936 | 0.540 | 0.783 | 0.567 | | | awq-hqq-int4 | 11.383 | 0.522 | 0.772 | 0.543 | | | awq-uint4 | 11.409 | 0.519 | 0.756 | 0.577 | From 49b1fb61c8b8eceda755579a2fd92c756d822de2 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 9 Oct 2024 11:29:51 -0700 Subject: [PATCH 5/6] Allow deprecated declarations what using Parallel 
ExecuTorch Differential Revision: D64020498 Pull Request resolved: https://github.com/pytorch/ao/pull/1031 --- torchao/experimental/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 64a0400015..2c75b9e1d0 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -50,6 +50,7 @@ if(TORCHAO_OP_TARGET STREQUAL "aten") add_library(torchao_ops_${TORCHAO_OP_TARGET} SHARED) elseif(TORCHAO_OP_TARGET STREQUAL "executorch") add_library(torchao_ops_${TORCHAO_OP_TARGET} STATIC) + add_compile_options("-Wno-error=deprecated") else() message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. Please choose one of: aten, executorch.") endif() From 9e0a59f5866a02e756c9478db5a5b28c4e7b342a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 9 Oct 2024 17:40:24 -0400 Subject: [PATCH 6/6] Make module swap the main QAT flow again (#1037) Summary: Following https://github.com/pytorch/ao/issues/987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic since this is not needed in both the long term and the short term plans for QAT. In the short term, we will continue to use a full module swap flow, and only migrate to the long term flow once there is general distributed support for tensor subclasses and when tensor subclass composability provides meaningful benefits. Test Plan: python test/quantization/test_qat.py [ghstack-poisoned] --- test/quantization/test_qat.py | 141 ++----- .../quantization/prototype/qat/__init__.py | 13 +- .../prototype/qat/_module_swap_api.py | 364 +---------------- torchao/quantization/prototype/qat/api.py | 229 +---------- torchao/quantization/prototype/qat/linear.py | 377 ++++++++++++++++++ torchao/quantization/prototype/qat/utils.py | 79 ---- 6 files changed, 418 insertions(+), 785 deletions(-) create mode 100644 torchao/quantization/prototype/qat/linear.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 72ffc23ab6..e1e670d5da 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,15 +18,11 @@ from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, ) -from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, -) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, _GenericFakeQuantize, - _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, ) from torchao.quantization.quant_api import ( int4_weight_only, @@ -164,7 +160,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat._module_swap_api import ( + from torchao.quantization.prototype.qat.linear import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -196,7 +192,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -219,45 +215,17 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - # TODO: compare against 
quantize_ API instead @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer - - group_size = 16 - torch.manual_seed(self.SEED) - m = M() - m2 = copy.deepcopy(m) - qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) - qat_model = qat_quantizer.prepare(m) - ptq_model = ptq_quantizer.quantize(m2) - - # Compare model values - torch.manual_seed(self.SEED) - x = m.example_inputs() - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - def test_qat_8da4w_quantizer_module_swap(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size) + module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) subclass_model = subclass_quantizer.prepare(m) module_swap_model = module_swap_quantizer.prepare(m2) @@ -288,20 +256,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - def _copy_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor) - - def _assert_matches_subclass_weights( - self, - nn_linear: torch.nn.Linear, - subclass_linear: AffineFakeQuantizedTensor, - ): - torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ @@ -313,16 +267,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) - def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): - self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor)) - self.assertEqual(m.weight.fake_quant_enabled, enabled) - self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)) - (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enabled: - self.assertIsNotNone(handle) - else: - self.assertIsNone(handle) - group_size = 16 torch.manual_seed(self.SEED) m = M() @@ -331,14 +275,14 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=False) - 
assert_fake_quant_enabled(qat_model.linear2, enabled=False) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=False) + self.assertFalse(qat_model.linear1._fake_quant_enabled) + self.assertFalse(qat_model.linear2._fake_quant_enabled) + self.assertFalse(qat_model.sub.linear._fake_quant_enabled) # Disabled fake quant is just a normal linear - self._copy_subclass_weights(m2.linear1, qat_model.linear1) - self._copy_subclass_weights(m2.linear2, qat_model.linear2) - self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear) + m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -348,16 +292,16 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - assert_fake_quant_enabled(qat_model.linear1, enabled=True) - assert_fake_quant_enabled(qat_model.linear2, enabled=True) - assert_fake_quant_enabled(qat_model.sub.linear, enabled=True) + self.assertTrue(qat_model.linear1._fake_quant_enabled) + self.assertTrue(qat_model.linear2._fake_quant_enabled) + self.assertTrue(qat_model.sub.linear._fake_quant_enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model2 = quantizer2.prepare(m3) - qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor - qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor - qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor + qat_model2.linear1.weight = qat_model.linear1.weight + qat_model2.linear2.weight = qat_model.linear2.weight + qat_model2.sub.linear.weight = qat_model.sub.linear.weight torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -382,9 +326,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self._copy_subclass_weights(nn_model.linear1, qat_model.linear1) - self._copy_subclass_weights(nn_model.linear2, qat_model.linear2) - self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + nn_model.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) + nn_model.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight) + nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) @@ -406,9 +350,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1) - self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2) - self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) + torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) def _test_qat_quantized_gradients(self, quantizer): """ @@ -542,7 +486,7 @@ def 
test_qat_4w_primitives(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -567,39 +511,6 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer - - group_size = 32 - inner_k_tiles = 8 - device = torch.device("cuda") - dtype = torch.bfloat16 - torch.manual_seed(self.SEED) - m = M().to(device).to(dtype) - m2 = copy.deepcopy(m) - qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, - ) - qat_model = qat_quantizer.prepare(m) - ptq_model = m2 - quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles))) - - # Compare model values - torch.manual_seed(self.SEED) - x = [i.to(device).to(dtype) for i in m.example_inputs()] - x2 = copy.deepcopy(x) - qat_out = qat_model(*x) - ptq_out = ptq_model(*x2) - self._assert_close_4w(qat_out, ptq_out) - - # Convert QAT model and compare model values - converted_model = qat_quantizer.convert(qat_model) - converted_out = converted_model(*x) - torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_4w_quantizer_gradients(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer @@ -608,9 +519,9 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - def test_qat_4w_quantizer_module_swap(self): + def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap + from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -622,7 +533,7 @@ def test_qat_4w_quantizer_module_swap(self): subclass_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap( + module_swap_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) subclass_model = subclass_quantizer.prepare(m) diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index 62740839b7..09ea6e708d 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -1,17 +1,14 @@ from .api import ( + ComposableQATQuantizer, +) +from .linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, - int4_weight_only_fake_quantize, - 
int8_dynamic_activation_int4_weight_fake_quantize, - ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, -) - -from ._module_swap_api import ( Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, ) from .embedding import ( Int4WeightOnlyEmbeddingQATQuantizer, @@ -22,8 +19,6 @@ "disable_8da4w_fake_quant", "enable_4w_fake_quant", "enable_8da4w_fake_quant", - "int4_weight_only_fake_quantize", - "int8_dynamic_activation_int4_weight_fake_quantize", "ComposableQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int4WeightOnlyEmbeddingQATQuantizer" diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py index a9239a03d5..0b44974f21 100644 --- a/torchao/quantization/prototype/qat/_module_swap_api.py +++ b/torchao/quantization/prototype/qat/_module_swap_api.py @@ -1,355 +1,11 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any - -import torch -import torch.nn.functional as F - -from torchao.quantization.GPTQ import ( - _check_linear_int4_k, - _replace_linear_int4, - _replace_linear_8da4w, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor, - Int8DynActInt4WeightLinear, - WeightOnlyInt4Linear, -) -from torchao.quantization.quant_primitives import ZeroPointDomain -from torchao.quantization.utils import get_group_qparams_symmetric -from .api import ( - Int8DynActInt4WeightQATQuantizer, - Int4WeightOnlyQATQuantizer, -) -from .utils import ( - _choose_qparams_per_token_asymmetric, - _fake_quantize_per_channel_group, - _fake_quantize_per_token, - _get_qmin_qmax, +# For backward compatibility only +# These will be removed in the future + +from .linear import ( + Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap, + Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap, + enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap, + disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap, + enable_4w_fake_quant as enable_4w_fake_quant_module_swap, + disable_4w_fake_quant as disable_4w_fake_quant_module_swap, ) - - -# TODO: make module swap the main flow again, and remove the quantize_ flow -# TODO: rename this file to linear.py - -# ========================================================= -# | Linear int8 dynamic activations + int4 weight QAT | -# ========================================================= - - -class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_8da4w( - model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) - return model - - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. 
- """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int8 dynamic per token fake - quantized activations with int4 fake quantized grouped per channel weights. - - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? - self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - - -def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. - """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() - - -def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. 
- """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() - - -# =================================== -# | Linear int4 weight-only QAT | -# =================================== - - -class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. - """ - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - _convert_qat_linear_4w(model) - return model - - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): - """ - This module implements a linear layer with int4 fake quantized grouped - per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, - which uses the efficient int4 tinygemm kernel. 
- - args: - groupsize: the number of elements in each quantized group for weights - precision: precision of weights - scales_precision: precision of per group scales and zero points - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device: torch.device = None, - groupsize: int = 256, - inner_k_tiles: int = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device=device, - dtype=precision, - ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True - - def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled - - def disable_fake_quant(self): - self.enable_fake_quant(False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - - -def enable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() - - -def disable_4w_fake_quant_module_swap(mod: torch.nn.Module): - """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. - """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index e1c5221e1e..93717271bb 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,34 +4,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List, Optional +from typing import Any, List import torch -import torch.nn.functional as F -from torchao.dtypes import ( - TensorCoreTiledLayoutType, -) -from torchao.quantization.quant_api import ( - _get_linear_subclass_inserter, - _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int4_weight, - quantize_, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import _get_per_token_block_size -from .affine_fake_quantized_tensor import to_affine_fake_quantized -from .utils import ( - _enable_fake_quant, - _get_qat_linear_subclass_inserter, - _is_linear_with_fq_weight, - _unwrap_affine_fake_quantized_tensor, -) class ComposableQATQuantizer(TwoStepQuantizer): @@ -70,207 +47,3 @@ def convert( for quantizer in self.quantizers: model = quantizer.convert(model) return model - - -# ================= -# | 8da4w QAT | -# ================= - -def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): - """ - Applies int8 dynamic per token asymmetric activation fake quantization and - int4 per group weight symmetric fake quantization to linear. Please see - :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) - """ - # avoid circular dep - from torchao.dtypes import to_affine_quantized_intx - - def _apply_weight_fake_quant(weight: torch.Tensor): - mapping_type = MappingType.SYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - quant_min = -8 - quant_max = 7 - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - ) - - def _apply_input_activation_fake_quant(x: torch.Tensor): - mapping_type = MappingType.ASYMMETRIC - target_dtype = torch.int8 - return to_affine_fake_quantized( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - ) - - return _get_qat_linear_subclass_inserter( - _apply_weight_fake_quant, - _apply_input_activation_fake_quant, - ) - -class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have int8 - dynamic per token fake quantized activations and int4 fake quantized - grouped per channel weights. 
- """ - - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - self.groupsize: int = groupsize - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_( - model, - int8_dynamic_activation_int4_weight_fake_quantize(group_size=self.groupsize), - ) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize) - quantize_(model, quantize_fn) - return model - - -def enable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=True) - -def disable_8da4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int8 dynamic activations + int4 weight. - """ - _enable_fake_quant(mod, enable=False) - - -# ================== -# | int4wo QAT | -# ================== - -def int4_weight_only_fake_quantize(group_size=128): - """ - Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. - Please see :func:`~torchao.quantization.int4_weight_only` for more details. - - Example usage:: - - from torchao.quantization import quantize_ - quantize_(model, int4_weight_only_fake_quantize(group_size=32)) - """ - def _apply_fake_quant(weight): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_fake_quantized( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - return _get_qat_linear_subclass_inserter(_apply_fake_quant) - -class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): - """ - Quantizer for performing QAT on a model, where linear layers have - int4 fake quantized grouped per channel weights. 
- """ - - def __init__( - self, - groupsize: int = 256, - inner_k_tiles: Optional[int] = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__() - assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] - self.inner_k_tiles = inner_k_tiles - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - - def prepare( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - quantize_(model, int4_weight_only_fake_quantize(group_size=self.groupsize)) - return model - - def convert( - self, - model: torch.nn.Module, - *args: Any, - **kwargs: Any - ) -> torch.nn.Module: - unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) - filter_fn = _is_linear_with_fq_weight - model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) - layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles) - quantize_fn = int4_weight_only(self.groupsize, layout_type) - quantize_(model, quantize_fn) - return model - -def enable_4w_fake_quant(mod: torch.nn.Module): - """ - Enable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=True) - -def disable_4w_fake_quant(mod: torch.nn.Module): - """ - Disable fake quantization for int4 weight only. - """ - _enable_fake_quant(mod, enable=False) diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py new file mode 100644 index 0000000000..07276ba84c --- /dev/null +++ b/torchao/quantization/prototype/qat/linear.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, + _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, +) +from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +# ========================================================= +# | Linear int8 dynamic activations + int4 weight QAT | +# ========================================================= + + +class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have int8 + dynamic per token fake quantized activations and int4 fake quantized + grouped per channel weights. 
+ """ + + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.scales_precision, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_8da4w(model) + return model + + +def _convert_qat_linear_8da4w(module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=child.groupsize, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + _convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int8 dynamic per token fake + quantized activations with int4 fake quantized grouped per channel weights. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + assert not bias, "require bias=False" + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + # TODO: make this configurable? 
+ self.zero_points_precision = torch.int32 + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # activations: int8 dynamic asymmetric quant + if self._fake_quant_enabled: + (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( + x, self.scales_precision, self.zero_points_precision, + ) + (act_qmin, act_qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token( + x, act_scales, act_zp, act_qmin, act_qmax, + ) + else: + x_fq = x + + # weights: int4 grouped per channel symmetric quant + if self._fake_quant_enabled: + (weight_scales, weight_zp) = get_group_qparams_symmetric( + self.weight, 4, self.groupsize, self.scales_precision, + ) + # TODO: pass zp dtype to `get_group_qparams_symmetric` instead + weight_zp = weight_zp.to(self.zero_points_precision) + (weight_qmin, weight_qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group( + self.weight, + weight_scales, + weight_zp, + weight_qmin, + weight_qmax, + self.groupsize, + ) + else: + w_fq = self.weight + return F.linear(x_fq, w_fq) + + +def enable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.enable_fake_quant() + + +def disable_8da4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + """ + if isinstance(mod, Int8DynActInt4WeightQATLinear): + mod.disable_fake_quant() + + +# =================================== +# | Linear int4 weight-only QAT | +# =================================== + + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_4w(model) + return model + + +def _convert_qat_linear_4w(module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _fake_quantize_per_channel_group( + self.weight, + scales, + zero_points, + qmin, + qmax, + self.groupsize, + ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. 
+ """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 1e4b61b8ac..354475e655 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,85 +181,6 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _forward_pre_hook_handler( - mod: torch.nn.Linear, - prehook: Callable, - handler: torch.utils.hooks.RemovableHandle, -): - """ - Store a 2-tuple (prehook function, handler) as an attribute on the given linear module. - """ - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handler)) - -def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor): - """ - Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - assert isinstance(t, AffineFakeQuantizedTensor) - return t.original_tensor - -def _is_linear_with_fq_weight(mod: torch.nn.Module, *args): - """ - Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights. - """ - # avoid circular dependencies - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"): - return False - weight = mod.weight - return isinstance(weight, AffineFakeQuantizedTensor) - -def _enable_fake_quant(mod: torch.nn.Module, enable: bool): - """ - Enable or disable fake quantization in the activations and weights of a `nn.Linear` module. - """ - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - if not _is_linear_with_fq_weight(mod): - return - weight = mod.weight - assert isinstance(weight, AffineFakeQuantizedTensor) - weight.fake_quant_enabled = enable - - # Enable/disable input fake quant - if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK): - (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) - if enable and handle is None: - handle = mod.register_forward_pre_hook(prehook) - elif not enable and handle is not None: - handle.remove() - handle = None - setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - -def _get_qat_linear_subclass_inserter( - weight_constructor: Callable, - input_constructor: Optional[Callable] = None, -) -> Callable: - """ - Return a function that inserts wraps the weight and/or input activation of a - linear module in tensor subclasses. - - Args: - weight_constructor: constructor of the weight subclass, accepts a tensor - input_constructor: (optional) constructor of the input subclass, accepts a tensor - """ - def insert_subclass(lin): - lin.weight = torch.nn.Parameter(weight_constructor(lin.weight), requires_grad=True) - if input_constructor is not None: - prehook = lambda _, args: tuple([input_constructor(args[0])] + list(args[1:])) - handle = lin.register_forward_pre_hook(prehook) - setattr(lin, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) - return lin - - return insert_subclass - def _get_qmin_qmax(n_bit: int): qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1