From 230c1fc3711e741fe7d3df6270517a8bf8927543 Mon Sep 17 00:00:00 2001 From: Vitalii Dziuba Date: Mon, 18 Nov 2024 17:48:50 -0800 Subject: [PATCH] Add static quantization for SUM in Quantizer PiperOrigin-RevId: 697818956 --- ai_edge_quantizer/algorithm_manager.py | 2 + .../naive_min_max_quantize.py | 15 +++ .../sum_test.py | 105 +++++++++++++++ ai_edge_quantizer/calibrator_test.py | 2 +- ai_edge_quantizer/default_policy.py | 6 +- ai_edge_quantizer/qtyping.py | 1 + .../tests/end_to_end_tests/slice_test.py | 5 +- .../tests/end_to_end_tests/sum_test.py | 124 ++++++++++++++++++ .../tests/models/single_sum.tflite | Bin 0 -> 1016 bytes .../utils/tfl_flatbuffer_utils.py | 1 + 10 files changed, 255 insertions(+), 6 deletions(-) create mode 100644 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py create mode 100644 ai_edge_quantizer/tests/end_to_end_tests/sum_test.py create mode 100644 ai_edge_quantizer/tests/models/single_sum.tflite diff --git a/ai_edge_quantizer/algorithm_manager.py b/ai_edge_quantizer/algorithm_manager.py index d7e7ba1..5a72a70 100644 --- a/ai_edge_quantizer/algorithm_manager.py +++ b/ai_edge_quantizer/algorithm_manager.py @@ -90,6 +90,7 @@ class AlgorithmName(str, enum.Enum): _TFLOpName.SPLIT, _TFLOpName.LOGISTIC, # Sigmoid _TFLOpName.SLICE, + _TFLOpName.SUM, ), ( naive_min_max_quantize.materialize_input, @@ -116,6 +117,7 @@ class AlgorithmName(str, enum.Enum): naive_min_max_quantize.materialize_split, naive_min_max_quantize.materialize_softmax_and_logistic, naive_min_max_quantize.materialize_slice, + naive_min_max_quantize.materialize_sum, ), ): register_quantized_op( diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py index e5e3e91..411f77d 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py @@ -325,6 +325,21 @@ def materialize_slice( ) +def materialize_sum( + op_info: qtyping.OpInfo, + graph_info: qtyping.GraphInfo, + tensor_name_to_qsv: dict[str, Any], +) -> list[qtyping.TensorTransformationParams]: + """Materialize tensors in tfl.sum.""" + return utils.materialize_standard_op( + op_info, + graph_info, + tensor_name_to_qsv, + constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE, + inputs_to_ignore=[1], # Axis index does not need to be quantized. + ) + + def materialize_fc_conv( op_info: qtyping.OpInfo, graph_info: qtyping.GraphInfo, diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py new file mode 100644 index 0000000..7cff7df --- /dev/null +++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/sum_test.py @@ -0,0 +1,105 @@ +# Copyright 2024 The AI Edge Quantizer Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.platform import googletest +from ai_edge_quantizer import qtyping +from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize +from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils +from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_flatbuffer_utils + +_TFLOpName = qtyping.TFLOperationName +_ComputePrecision = qtyping.ComputePrecision +_TensorQuantConfig = qtyping.TensorQuantizationConfig +_QuantTransformation = qtyping.QuantTransformation +_OpTestInfo = naive_min_max_test_utils.OpTestInfo + +_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile( + "../../../tests/models" +) +_DEFAULT_WEIGHT_QUANT_SETTING = ( + naive_min_max_test_utils.DEFAULT_WEIGHT_QUANT_SETTING +) + + +class SumTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest): + + def setUp(self): + super().setUp() + np.random.seed(666) + self._test_model_path = os.path.join( + _TEST_DATA_PREFIX_PATH, "single_sum.tflite" + ) + self._op_test_info = _OpTestInfo( + test_model=tfl_flatbuffer_utils.read_model(self._test_model_path), + op_tensor_names={}, + input_range=(np.array([[-10]]), np.array([[8]])), + output_range=(np.array([[10]]), np.array([[88]])), + ) + # The test model has one subgraph for now. + self._graph_info = qtyping.GraphInfo( + subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors, + buffers=self._op_test_info.test_model.buffers, + ) + + @parameterized.parameters( + 8, + 16, + ) + def test_materialize_sum_succeeds(self, num_bits): + activation_tensor_config = _TensorQuantConfig( + num_bits=num_bits, + symmetric=True, + granularity=qtyping.QuantGranularity.TENSORWISE, + ) + op_quant_config = qtyping.OpQuantizationConfig( + activation_tensor_config=activation_tensor_config, + weight_tensor_config=_DEFAULT_WEIGHT_QUANT_SETTING, + compute_precision=_ComputePrecision.INTEGER, # SRQ. + ) + # Read from Model Explorer. + subgraph0 = self._op_test_info.test_model.subgraphs[0] + subgraph_op_id = 0 + op = subgraph0.operators[subgraph_op_id] + op_info = qtyping.OpInfo( + op=op, + op_name=qtyping.TFLOperationName.SUM, + subgraph_op_index=subgraph_op_id, + op_quant_config=op_quant_config, + ) + + # Test settings. + op_tensor_names = {} + op_tensor_names["input"] = "serving_default_input_1:0" + op_tensor_names["input2"] = "model/tf.math.reduce_sum/Sum/reduction_indices" + op_tensor_names["output"] = "PartitionedCall:0" + self._op_test_info.op_tensor_names = op_tensor_names + self._test_no_weights_op( + op_info, + self._graph_info, + self._op_test_info, + naive_min_max_quantize.materialize_sum, + same_input_output_params=True, + inputs_to_ignore=[1], # Ignore axis tensor. + ) + + +if __name__ == "__main__": + googletest.main() diff --git a/ai_edge_quantizer/calibrator_test.py b/ai_edge_quantizer/calibrator_test.py index c6d677b..454889e 100644 --- a/ai_edge_quantizer/calibrator_test.py +++ b/ai_edge_quantizer/calibrator_test.py @@ -290,7 +290,7 @@ def test_toy_gemma2_calibration_success(self): self._toy_gemma2_calibration_dataset, model_recipe_manager=recipe_mngr, ) - self.assertLen(calib.get_model_qsvs(), 274) + self.assertLen(calib.get_model_qsvs(), 282) if __name__ == "__main__": diff --git a/ai_edge_quantizer/default_policy.py b/ai_edge_quantizer/default_policy.py index 2de0a3f..65458f3 100644 --- a/ai_edge_quantizer/default_policy.py +++ b/ai_edge_quantizer/default_policy.py @@ -165,7 +165,8 @@ "INPUT", "OUTPUT", "SLICE", - "EMBEDDING_LOOKUP" + "EMBEDDING_LOOKUP", + "SUM" ], "static_wi8_ai8": [ "ADD", @@ -191,7 +192,8 @@ "INPUT", "OUTPUT", "SLICE", - "EMBEDDING_LOOKUP" + "EMBEDDING_LOOKUP", + "SUM" ], "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"], "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"], diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py index 35e51f6..9dd692c 100644 --- a/ai_edge_quantizer/qtyping.py +++ b/ai_edge_quantizer/qtyping.py @@ -58,6 +58,7 @@ class TFLOperationName(str, enum.Enum): SPLIT = 'SPLIT' LOGISTIC = 'LOGISTIC' SLICE = 'SLICE' + SUM = 'SUM' class QuantizeMode(enum.Enum): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/slice_test.py b/ai_edge_quantizer/tests/end_to_end_tests/slice_test.py index d99bd11..42f9106 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/slice_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/slice_test.py @@ -78,7 +78,6 @@ def test_slice_model_full_integer(self, recipe_path, tensor_type): quantized_model = tfl_flatbuffer_utils.read_model( quantization_result.quantized_model ) - quantization_result.export_model('/tmp/quantized_slice.tflite') self.assertLen(quantized_model.subgraphs, 1) subgraph = quantized_model.subgraphs[0] subgraph_tensors = subgraph.tensors @@ -88,10 +87,10 @@ def test_slice_model_full_integer(self, recipe_path, tensor_type): size_tensor = subgraph_tensors[subgraph.inputs[1]] output_tensor = subgraph_tensors[subgraph.outputs[0]] # See schema_py_generated.py for type code. - self.assertEqual(input_tensor.type, tensor_type) # float32. + self.assertEqual(input_tensor.type, tensor_type) self.assertEqual(begin_tensor.type, 2) # int32. self.assertEqual(size_tensor.type, 2) # int32. - self.assertEqual(output_tensor.type, tensor_type) # float32. + self.assertEqual(output_tensor.type, tensor_type) comparison_result = self._quantizer.validate( error_metrics='mse', test_data=_get_test_data(num_samples=1) diff --git a/ai_edge_quantizer/tests/end_to_end_tests/sum_test.py b/ai_edge_quantizer/tests/end_to_end_tests/sum_test.py new file mode 100644 index 0000000..7e8447d --- /dev/null +++ b/ai_edge_quantizer/tests/end_to_end_tests/sum_test.py @@ -0,0 +1,124 @@ +# Copyright 2024 The AI Edge Quantizer Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""E2E tests for the quantizer for model with transpose.""" + +from typing import Any + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.platform import googletest +from ai_edge_quantizer import qtyping +from ai_edge_quantizer import quantizer +from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_flatbuffer_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils + +_OpExecutionMode = qtyping.OpExecutionMode +_OpName = qtyping.TFLOperationName +_TensorQuantConfig = qtyping.TensorQuantizationConfig +_OpQuantConfig = qtyping.OpQuantizationConfig + +_RNG = np.random.default_rng(66) + + +def _get_dummy_data( + num_samples: int, dtype: np.dtype = np.float32 +) -> list[dict[str, Any]]: + data = [] + for _ in range(num_samples): + data.append({'input_1': _RNG.uniform(size=(2, 3)).astype(dtype)}) + return data + + +def _get_calibration_data( + num_samples: int = 128, dtype: np.dtype = np.float32 +) -> list[dict[str, Any]]: + calibration_samples = _get_dummy_data(num_samples, dtype) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data + + +def _get_test_data( + num_samples: int = 8, dtype: np.dtype = np.float32 +) -> list[dict[str, Any]]: + return _get_calibration_data(num_samples, dtype) + + +class SumTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.float_model_path = test_utils.get_path_to_datafile( + '../models/single_sum.tflite' + ) + self._quantizer = quantizer.Quantizer(self.float_model_path) + + @parameterized.named_parameters( + dict( + testcase_name='int8_quantized', + recipe_path='../../recipes/default_a8w8_recipe.json', + tensor_type=9, + tol=1e-4, + ), + dict( + testcase_name='int16_quantized', + recipe_path='../../recipes/default_a16w8_recipe.json', + tensor_type=7, + tol=2.5, # TODO(b/379757798): Update tolerance after bug is fixed. + ), + ) + def test_sum_model_full_integer(self, recipe_path, tensor_type, tol): + recipe_path = test_utils.get_path_to_datafile(recipe_path) + self._quantizer.load_quantization_recipe(recipe_path) + self.assertTrue(self._quantizer.need_calibration) + + data = _get_calibration_data() + calibration_result = self._quantizer.calibrate(data) + + quantization_result = self._quantizer.quantize(calibration_result) + + # Check input/output tensor type. + quantized_model = tfl_flatbuffer_utils.read_model( + quantization_result.quantized_model + ) + self.assertLen(quantized_model.subgraphs, 1) + subgraph = quantized_model.subgraphs[0] + subgraph_tensors = subgraph.tensors + self.assertLen(subgraph.inputs, 1) + input_tensor = subgraph_tensors[subgraph.inputs[0]] + output_tensor = subgraph_tensors[subgraph.outputs[0]] + # See schema_py_generated.py for type code. + self.assertEqual(input_tensor.type, tensor_type) + self.assertEqual(output_tensor.type, tensor_type) + + comparison_result = self._quantizer.validate( + error_metrics='mse', + test_data=_get_test_data(num_samples=1), + ) + self._check_comparison_result(comparison_result, output_tolerance=tol) + + def _check_comparison_result(self, comparison_result, output_tolerance): + # TODO: b/357959309 - Use comparison result directly for testing. + comparison_result = comparison_result.get_all_tensor_results() + output_mse = comparison_result['PartitionedCall:0'] + self.assertLess(output_mse, output_tolerance) + + +if __name__ == '__main__': + googletest.main() diff --git a/ai_edge_quantizer/tests/models/single_sum.tflite b/ai_edge_quantizer/tests/models/single_sum.tflite new file mode 100644 index 0000000000000000000000000000000000000000..2e4337cc77a9f91fa07f21c2acca58253e298727 GIT binary patch literal 1016 zcmZWou}&L75FHW|Bd~w~N`i3V!u8osQUIc{3`igh5m>1r9Vc#qm3(J(-a{!Pt{$47ps4uiCBY+ece z4_THyfTleIm_siL%(0iKFg;NIFpY8!KL34{Wz=Or$0w=rTf*Fz0R0!JHCouS(4C#_ z*K50b8#~*6cdgf6ZTH&fsbc#YE8>{P*ht5IlEy|y%0E@fP!D2x1phhk6*vPv0tR4C z0xs%(An%{^n!K(}UmXxs;Y?VI@s!H-j7xS-mp}e>KmYytp1rtPF#fsMTJoAAzX~qa zYX^}E=vP5vw88dN*a`amNC9$vLP#>rv(*nfg&ZuJd0|Rxqmra^b{{bkBjp6_R literal 0 HcmV?d00001 diff --git a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py index 546189c..6058760 100644 --- a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +++ b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py @@ -60,6 +60,7 @@ _TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT, _TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC, _TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE, + _TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM, }) TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(