[RFC] Long Term QAT Flow #987
Comments
Summary: Following #987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic since this is not needed in both the long term and the short term plans for QAT. In the short term, we will continue to use a full module swap flow, and only migrate to the long term flow once there is general distributed support for tensor subclasses and when tensor subclass composability provides meaningful benefits. Test Plan: python test/quantization/test_qat.py
Yeah, this is referring to how the data is represented during fake quantization in the training phase (after prepare but before convert), not the final quantized data (after convert). I'm open to renaming suggestions if you have any.
Maybe just "Fake Quantization Implementation"?
Just curious: what problems do you observe with tensor subclass + DDP? I have used this combination in my other projects and it seems to work as expected (i.e. no errors, correct results).
To clarify, I'm not sure that "module swap for data representation" makes sense. Should this say "use plain torch.Tensor for data representation"? For FSDP1 composability, from what I understand, as long as you don't have model parameter wrappers you can still use tensor subclasses for data representation, and thus get the benefits of integrating with other distributed paradigms (TP/SP) and of easily using low precision gemms.
That's great to know. Can you share the links?
Sounds good. For FSDP1, the issue I ran into was that moving the model to a different device moves only the outer tensor but not the inner tensor, and this is fundamental to how FSDP1 assigns
@andrewor14 Train script: https://github.com/gau-nernst/quantized-training/blob/main/llm_pretrain.py. DDP stuff is pretty standard, no changes. The subclasses are swapped by
* Make module swap the main QAT flow again
* Move and rename GranularityType -> Granularity
Summary: Move GranularityType to quant_primitives.py to be consistent with other similar fields like MappingType and ZeroPointDomain.
Test Plan: CI
* Make module swap the main QAT flow again
* Add generic fake quantized linear for QAT
**Summary:** This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:
```
import torch
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 asymmetric per-token activations, int4 grouped weights
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```
The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wished to experiment with different fake quantization settings, e.g. a different group size or a different bit width. Now we can express this easily using a single linear module.
**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
Summary: pytorch/ao#1091 moved QAT out of prototype in torchao. This is a BC-breaking change so torchtune also needs to update its QAT imports. Additionally, after pytorch/ao#987 we decided that QAT in torchao will use module swaps to insert fake quantizes, so there is no need to have a separate module swap quantizer, so this commit removes the `*ModuleSwapQuantizer` option. Test Plan: pytest -m integration_test tests/recipes/test_qat_distributed.py should work
* CLI: Remove unsafe access of unused args * Annotate the args conditional on subcommands in functions * Typo in generate.py
Currently torchao QAT has two APIs: tensor subclasses and module swap. The original plan was to deprecate and eventually remove the old module swap API in favor of the tensor subclass API. However, users are starting to rely on the module swap API for production use cases due to gaps in the tensor subclass API. In this RFC, we discuss the long-term plans for these two APIs in torchao.
API Today
We use a quantizer API today to abstract the implementation details from the user. Currently we support both tensor subclass and module swap APIs using different quantizers:
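As a rough illustration of the quantizer abstraction, the sketch below shows a module swap quantizer with the `prepare` and `convert` steps mentioned in the discussion. `ToyQATLinear`, `ToyModuleSwapQuantizer` and `_swap_linears` are hypothetical names made up for this example and are not torchao classes; the int4 per-tensor fake quantization is likewise an assumption chosen for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyQATLinear(nn.Linear):
    """Hypothetical QAT linear: fake quantizes its weight on every forward."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qmin, qmax = -8, 7  # symmetric int4 range (assumption for this sketch)
        w = self.weight
        scale = w.detach().abs().max().clamp(min=1e-8) / qmax
        w_dq = torch.clamp(torch.round(w / scale), qmin, qmax) * scale
        # Straight-through estimator: forward uses w_dq, backward sees identity.
        w_fq = w + (w_dq - w).detach()
        return F.linear(x, w_fq, self.bias)


def _swap_linears(model: nn.Module, src: type, dst: type) -> nn.Module:
    """Replace every module of exact type `src` with `dst`, reusing parameters."""
    for name, child in list(model.named_children()):
        if type(child) is src:
            new = dst(child.in_features, child.out_features,
                      bias=child.bias is not None)
            new.weight, new.bias = child.weight, child.bias
            setattr(model, name, new)
        else:
            _swap_linears(child, src, dst)
    return model


class ToyModuleSwapQuantizer:
    """Hypothetical quantizer hiding the swap logic behind prepare/convert."""

    def prepare(self, model: nn.Module) -> nn.Module:
        # Inject fake quantization by swapping whole modules.
        return _swap_linears(model, nn.Linear, ToyQATLinear)

    def convert(self, model: nn.Module) -> nn.Module:
        # Simplification: a real convert step would emit quantized modules;
        # here we only remove the fake quantization again.
        return _swap_linears(model, ToyQATLinear, nn.Linear)


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
quantizer = ToyModuleSwapQuantizer()
model = quantizer.prepare(model)  # fine-tuning would happen between these calls
model = quantizer.convert(model)
```

The point of the abstraction is that user code only calls `prepare` and `convert`, so the injection mechanism underneath can change without touching the training script.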
Module Swap vs Tensor Subclass
Although tensor subclasses are generally adopted in torchao, the main gaps today are (1) the lack of general distributed support, and (2) a steep learning curve. For these two reasons, some users prefer the module swap flow, and have begun implementing new features in this flow, such as embedding quantization and static QAT.
To summarize the pros and cons of both approaches:
We can separate tensor subclass usage into two categories:
* Injection: tensor subclass injection means we look for `nn.Linear` modules and swap out the weight tensor, while module swap injection means we look for `nn.Linear` modules and swap out the whole module with our custom `QATLinear`. Today, the tensor subclass flow in torchao uses the former, while the module swap flow uses the latter.
* Fake quantization implementation: we can use `AffineFakeQuantizedTensor` to encode the desired fake quantization configurations, or we can use plain `torch.Tensor`.
Long Term Flow
We propose to use module swap for injection and tensor subclass for implementing fake quantization in the long term. This has the following pros and cons compared to the alternatives:
Note: In the short term, we will continue to use plain `torch.Tensor`s for fake quantization due to the lack of general distributed support for tensor subclasses. The distributed strategies we should support before migrating to the long term flow include DDP and FSDP1. Additionally, we should migrate only if tensor subclass composability provides meaningful performance benefits, such as faster fake quantization through efficient int8 kernels.
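To make the short term approach concrete, here is a minimal sketch of fake quantization on a plain `torch.Tensor`, with no tensor subclass involved. The function name `fake_quantize_per_group`, the symmetric per-group int4 scheme, and the straight-through estimator are assumptions for illustration, not the torchao implementation.

```python
import torch


def fake_quantize_per_group(w: torch.Tensor, group_size: int = 8,
                            bits: int = 4) -> torch.Tensor:
    """Symmetric per-group quantize-dequantize that returns a plain float tensor.

    Assumes w.numel() is divisible by group_size.
    """
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    groups = w.reshape(-1, group_size)
    # One scale per group, chosen so the largest magnitude maps to qmax.
    scale = groups.detach().abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(groups / scale), qmin, qmax)
    dq = (q * scale).reshape(w.shape)
    # Straight-through estimator: forward returns dq, backward is identity.
    return w + (dq - w).detach()


w = torch.randn(4, 16, requires_grad=True)
w_fq = fake_quantize_per_group(w)
w_fq.sum().backward()  # gradients reach w as if no rounding had happened
```

Because the result is an ordinary float tensor, it passes through wrappers such as DDP like any other activation or weight, which is the property the note above relies on; the trade-off is that the computation stays in high precision rather than using int8 kernels.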