Skip to content

Commit

Permalink
Add static quantization for SUM in Quantizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697818956
  • Loading branch information
v-dziuba authored and copybara-github committed Nov 19, 2024
1 parent 3c79e16 commit 230c1fc
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 6 deletions.
2 changes: 2 additions & 0 deletions ai_edge_quantizer/algorithm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class AlgorithmName(str, enum.Enum):
_TFLOpName.SPLIT,
_TFLOpName.LOGISTIC, # Sigmoid
_TFLOpName.SLICE,
_TFLOpName.SUM,
),
(
naive_min_max_quantize.materialize_input,
Expand All @@ -116,6 +117,7 @@ class AlgorithmName(str, enum.Enum):
naive_min_max_quantize.materialize_split,
naive_min_max_quantize.materialize_softmax_and_logistic,
naive_min_max_quantize.materialize_slice,
naive_min_max_quantize.materialize_sum,
),
):
register_quantized_op(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ def materialize_slice(
)


def materialize_sum(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.sum.

  Delegates to `utils.materialize_standard_op`, requesting the
  `SAME_AS_INPUT_SCALE` constraint and excluding the op's second input
  (index 1) from materialization.

  Args:
    op_info: Description of the SUM op to materialize; forwarded unchanged.
    graph_info: Graph-level information; forwarded unchanged.
    tensor_name_to_qsv: Mapping from tensor name to its QSV entry; forwarded
      unchanged.

  Returns:
    The transformation params produced by `utils.materialize_standard_op`.
  """
  # Input 1 is the axis index, which does not need to be quantized.
  non_quantized_inputs = [1]
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=non_quantized_inputs,
  )


def materialize_fc_conv(
op_info: qtyping.OpInfo,
graph_info: qtyping.GraphInfo,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TFLOpName = qtyping.TFLOperationName
_ComputePrecision = qtyping.ComputePrecision
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_QuantTransformation = qtyping.QuantTransformation
_OpTestInfo = naive_min_max_test_utils.OpTestInfo

_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
"../../../tests/models"
)
_DEFAULT_WEIGHT_QUANT_SETTING = (
naive_min_max_test_utils.DEFAULT_WEIGHT_QUANT_SETTING
)


class SumTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):
  """Unit test for `materialize_sum` using the `single_sum.tflite` model."""

  def setUp(self):
    """Loads the test model and prepares the shared op/graph test info."""
    super().setUp()
    np.random.seed(666)
    self._test_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, "single_sum.tflite"
    )
    loaded_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
    self._op_test_info = _OpTestInfo(
        test_model=loaded_model,
        # Filled in by the individual test case.
        op_tensor_names={},
        input_range=(np.array([[-10]]), np.array([[8]])),
        output_range=(np.array([[10]]), np.array([[88]])),
    )
    # The test model has one subgraph for now.
    first_subgraph = self._op_test_info.test_model.subgraphs[0]
    self._graph_info = qtyping.GraphInfo(
        subgraph_tensors=first_subgraph.tensors,
        buffers=self._op_test_info.test_model.buffers,
    )

  @parameterized.parameters(8, 16)
  def test_materialize_sum_succeeds(self, num_bits):
    """Materializes the SUM op with symmetric tensor-wise activations."""
    op_quant_config = qtyping.OpQuantizationConfig(
        activation_tensor_config=_TensorQuantConfig(
            num_bits=num_bits,
            symmetric=True,
            granularity=qtyping.QuantGranularity.TENSORWISE,
        ),
        weight_tensor_config=_DEFAULT_WEIGHT_QUANT_SETTING,
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    # Read from Model Explorer.
    sum_op_index = 0
    sum_op = self._op_test_info.test_model.subgraphs[0].operators[sum_op_index]
    op_info = qtyping.OpInfo(
        op=sum_op,
        op_name=_TFLOpName.SUM,
        subgraph_op_index=sum_op_index,
        op_quant_config=op_quant_config,
    )

    # Test settings.
    self._op_test_info.op_tensor_names = {
        "input": "serving_default_input_1:0",
        "input2": "model/tf.math.reduce_sum/Sum/reduction_indices",
        "output": "PartitionedCall:0",
    }
    self._test_no_weights_op(
        op_info,
        self._graph_info,
        self._op_test_info,
        naive_min_max_quantize.materialize_sum,
        same_input_output_params=True,
        inputs_to_ignore=[1],  # Ignore axis tensor.
    )


if __name__ == "__main__":
googletest.main()
2 changes: 1 addition & 1 deletion ai_edge_quantizer/calibrator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_toy_gemma2_calibration_success(self):
self._toy_gemma2_calibration_dataset,
model_recipe_manager=recipe_mngr,
)
self.assertLen(calib.get_model_qsvs(), 274)
self.assertLen(calib.get_model_qsvs(), 282)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions ai_edge_quantizer/default_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@
"INPUT",
"OUTPUT",
"SLICE",
"EMBEDDING_LOOKUP"
"EMBEDDING_LOOKUP",
"SUM"
],
"static_wi8_ai8": [
"ADD",
Expand All @@ -191,7 +192,8 @@
"INPUT",
"OUTPUT",
"SLICE",
"EMBEDDING_LOOKUP"
"EMBEDDING_LOOKUP",
"SUM"
],
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
Expand Down
1 change: 1 addition & 0 deletions ai_edge_quantizer/qtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class TFLOperationName(str, enum.Enum):
SPLIT = 'SPLIT'
LOGISTIC = 'LOGISTIC'
SLICE = 'SLICE'
SUM = 'SUM'


class QuantizeMode(enum.Enum):
Expand Down
5 changes: 2 additions & 3 deletions ai_edge_quantizer/tests/end_to_end_tests/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def test_slice_model_full_integer(self, recipe_path, tensor_type):
quantized_model = tfl_flatbuffer_utils.read_model(
quantization_result.quantized_model
)
quantization_result.export_model('/tmp/quantized_slice.tflite')
self.assertLen(quantized_model.subgraphs, 1)
subgraph = quantized_model.subgraphs[0]
subgraph_tensors = subgraph.tensors
Expand All @@ -88,10 +87,10 @@ def test_slice_model_full_integer(self, recipe_path, tensor_type):
size_tensor = subgraph_tensors[subgraph.inputs[1]]
output_tensor = subgraph_tensors[subgraph.outputs[0]]
# See schema_py_generated.py for type code.
self.assertEqual(input_tensor.type, tensor_type) # float32.
self.assertEqual(input_tensor.type, tensor_type)
self.assertEqual(begin_tensor.type, 2) # int32.
self.assertEqual(size_tensor.type, 2) # int32.
self.assertEqual(output_tensor.type, tensor_type) # float32.
self.assertEqual(output_tensor.type, tensor_type)

comparison_result = self._quantizer.validate(
error_metrics='mse', test_data=_get_test_data(num_samples=1)
Expand Down
124 changes: 124 additions & 0 deletions ai_edge_quantizer/tests/end_to_end_tests/sum_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""E2E tests for the quantizer for model with sum."""

from typing import Any

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils

_OpExecutionMode = qtyping.OpExecutionMode
_OpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_OpQuantConfig = qtyping.OpQuantizationConfig

_RNG = np.random.default_rng(66)


def _get_dummy_data(
    num_samples: int, dtype: np.dtype = np.float32
) -> list[dict[str, Any]]:
  """Generates random input samples for the model input named 'input_1'.

  Args:
    num_samples: Number of samples to generate.
    dtype: Type each generated array is cast to.

  Returns:
    A list of `num_samples` dicts, each mapping 'input_1' to a (2, 3) array
    drawn with `_RNG.uniform` and cast to `dtype`.
  """
  return [
      {'input_1': _RNG.uniform(size=(2, 3)).astype(dtype)}
      for _ in range(num_samples)
  ]


def _get_calibration_data(
    num_samples: int = 128, dtype: np.dtype = np.float32
) -> dict[str, list[dict[str, Any]]]:
  """Builds calibration data keyed by the default signature key.

  Args:
    num_samples: Number of random samples to generate.
    dtype: Type each generated input array is cast to.

  Returns:
    A single-entry dict mapping `tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY`
    to the list of samples produced by `_get_dummy_data`. (The previous
    annotation claimed a list was returned; the value has always been a dict.)
  """
  calibration_samples = _get_dummy_data(num_samples, dtype)
  return {
      tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples,
  }


def _get_test_data(
    num_samples: int = 8, dtype: np.dtype = np.float32
) -> dict[str, list[dict[str, Any]]]:
  """Builds validation data in the same format as the calibration data.

  Args:
    num_samples: Number of random samples to generate.
    dtype: Type each generated input array is cast to.

  Returns:
    The dict returned by `_get_calibration_data(num_samples, dtype)`. (The
    previous annotation claimed a list was returned; the value is a dict.)
  """
  return _get_calibration_data(num_samples, dtype)


class SumTest(parameterized.TestCase):
  """End-to-end quantization test for the `single_sum.tflite` model."""

  def setUp(self):
    """Creates a quantizer for the float SUM model."""
    super().setUp()
    self.float_model_path = test_utils.get_path_to_datafile(
        '../models/single_sum.tflite'
    )
    self._quantizer = quantizer.Quantizer(self.float_model_path)

  @parameterized.named_parameters(
      {
          'testcase_name': 'int8_quantized',
          'recipe_path': '../../recipes/default_a8w8_recipe.json',
          'tensor_type': 9,
          'tol': 1e-4,
      },
      {
          'testcase_name': 'int16_quantized',
          'recipe_path': '../../recipes/default_a16w8_recipe.json',
          'tensor_type': 7,
          # TODO(b/379757798): Update tolerance after bug is fixed.
          'tol': 2.5,
      },
  )
  def test_sum_model_full_integer(self, recipe_path, tensor_type, tol):
    """Calibrates, quantizes and validates the SUM model for one recipe.

    Args:
      recipe_path: Recipe file path, resolved via
        `test_utils.get_path_to_datafile`.
      tensor_type: Type code expected on the quantized model's input and
        output tensors.
      tol: Upper bound (exclusive) for the output tensor's error value.
    """
    self._quantizer.load_quantization_recipe(
        test_utils.get_path_to_datafile(recipe_path)
    )
    self.assertTrue(self._quantizer.need_calibration)

    calibration_result = self._quantizer.calibrate(_get_calibration_data())
    quantization_result = self._quantizer.quantize(calibration_result)

    # Check input/output tensor type.
    quantized_model = tfl_flatbuffer_utils.read_model(
        quantization_result.quantized_model
    )
    self.assertLen(quantized_model.subgraphs, 1)
    subgraph = quantized_model.subgraphs[0]
    tensors = subgraph.tensors
    self.assertLen(subgraph.inputs, 1)
    # Resolve both tensors before asserting on either of them.
    input_tensor, output_tensor = (
        tensors[subgraph.inputs[0]],
        tensors[subgraph.outputs[0]],
    )
    # See schema_py_generated.py for type code.
    self.assertEqual(input_tensor.type, tensor_type)
    self.assertEqual(output_tensor.type, tensor_type)

    comparison_result = self._quantizer.validate(
        error_metrics='mse',
        test_data=_get_test_data(num_samples=1),
    )
    self._check_comparison_result(comparison_result, output_tolerance=tol)

  def _check_comparison_result(self, comparison_result, output_tolerance):
    """Asserts the 'PartitionedCall:0' result is below `output_tolerance`."""
    # TODO: b/357959309 - Use comparison result directly for testing.
    per_tensor_results = comparison_result.get_all_tensor_results()
    self.assertLess(per_tensor_results['PartitionedCall:0'], output_tolerance)


if __name__ == '__main__':
googletest.main()
Binary file added ai_edge_quantizer/tests/models/single_sum.tflite
Binary file not shown.
1 change: 1 addition & 0 deletions ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
_TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT,
_TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
_TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
_TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
})

TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
Expand Down

0 comments on commit 230c1fc

Please sign in to comment.