Further reduce memory usage in Quantizer
PiperOrigin-RevId: 679627364
v-dziuba authored and copybara-github committed Oct 4, 2024
1 parent 1681b00 commit 8b11f46
Showing 10 changed files with 42 additions and 38 deletions.
3 changes: 2 additions & 1 deletion ai_edge_quantizer/calibrator.py
@@ -40,9 +40,10 @@ class Calibrator:

   def __init__(
       self,
-      float_tflite: Union[str, bytearray],
+      float_tflite: Union[str, bytes],
   ):
     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+
     if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
       raise ValueError(
           "The input model for calibration is not a float model. Please check"
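
For context, a minimal sketch of how the widened signature might be used; the file path and the surrounding calling code are assumptions, not part of this commit:

from ai_edge_quantizer import calibrator

# Illustrative only: the Calibrator can now take raw .tflite bytes as well as a path.
with open("float_model.tflite", "rb") as f:   # hypothetical path
    calib = calibrator.Calibrator(f.read())   # bytes accepted alongside str
calib_from_path = calibrator.Calibrator("float_model.tflite")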
15 changes: 3 additions & 12 deletions ai_edge_quantizer/model_modifier.py
@@ -16,35 +16,27 @@
 """Model Modifier class that produce the final quantized TFlite model."""

 import copy
-from typing import Union

 import numpy as np

 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
-from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import


 class ModelModifier:
   """Model Modifier class that produce the final quantized TFlite model."""

-  # TODO: b/336599483 - support byte array as input
-  def __init__(self, float_tflite: Union[str, bytearray]):
+  def __init__(self, float_tflite: bytes):
     """Constructor.

     Args:
       float_tflite: the original TFlite model in bytearray or file path
     """

-    if isinstance(float_tflite, str):
-      self._model_bytearray = tfl_flatbuffer_utils.get_model_buffer(
-          float_tflite
-      )
-    else:
-      self._model_bytearray = float_tflite
+    self._model_content = float_tflite

     self._constant_map = []
     self._transformation_instruction_generator = (
@@ -66,9 +58,8 @@ def modify_model(
     Returns:
       a byte buffer that represents the serialized tflite model
     """
-
     quantized_model = copy.deepcopy(
-        flatbuffer_utils.read_model_from_bytearray(self._model_bytearray)
+        flatbuffer_utils.read_model_from_bytearray(self._model_content)
     )

     instructions = self._transformation_instruction_generator.quant_params_to_transformation_insts(
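
Since ModelModifier now accepts only bytes, callers that previously passed a file path are expected to load the content first. A sketch of the assumed calling pattern, mirroring what the updated test does (the path is hypothetical):

from ai_edge_quantizer import model_modifier
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

# Illustrative sketch: load the model bytes with the helper added in this
# commit, then hand them to ModelModifier.
model_bytes = tfl_flatbuffer_utils.get_model_content("/path/to/model.tflite")
modifier = model_modifier.ModelModifier(model_bytes)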
21 changes: 10 additions & 11 deletions ai_edge_quantizer/model_modifier_test.py
@@ -37,10 +37,11 @@ def setUp(self):
     self._model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
     )
-    self._model_modifier = model_modifier.ModelModifier(self._model_path)
-    self._model_buffer: bytearray = tfl_flatbuffer_utils.get_model_buffer(
+
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
         self._model_path
     )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
     self._global_recipe = [
         {
             'regex': '.*',
@@ -62,13 +63,11 @@ def setUp(self):
     ]

   def test_process_constant_map_succeeds(self):
-    constant_size = self._model_modifier._process_constant_map(
-        flatbuffer_utils.read_model_from_bytearray(
-            self._model_modifier._model_bytearray
-        )
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
     )
+    constant_size = self._model_modifier._process_constant_map(model_bytearray)
     self.assertEqual(constant_size, 202540)
-    pass

   def test_modify_model_succeeds_with_recipe(self):
     recipe_manager_instance = recipe_manager.RecipeManager()
@@ -86,7 +85,7 @@ def test_modify_model_succeeds_with_recipe(self):
         tensor_quantization_params
     )
     flatbuffer_utils.convert_bytearray_to_object(new_model_binary)
-    self.assertLess(new_model_binary, self._model_buffer)
+    self.assertLess(new_model_binary, self._model_content)

   def test_modify_model_preserves_original_model(self):
     recipe_manager_instance = recipe_manager.RecipeManager()
@@ -100,9 +99,9 @@ def test_modify_model_preserves_original_model(self):
             recipe_manager_instance
         )
     )
-    self.assertEqual(self._model_modifier._model_bytearray, self._model_buffer)
+    self.assertEqual(self._model_modifier._model_content, self._model_content)
     self._model_modifier.modify_model(tensor_quantization_params)
-    self.assertEqual(self._model_modifier._model_bytearray, self._model_buffer)
+    self.assertEqual(self._model_modifier._model_content, self._model_content)

   def test_modify_model_peak_memory_usage_in_acceptable_range(self):
     """Test ModifyModel peak memory usage."""
@@ -124,7 +123,7 @@ def test_modify_model_peak_memory_usage_in_acceptable_range(self):
     _, mem_peak = tracemalloc.get_traced_memory()

     loosen_mem_use_factor = 4.5
-    self.assertLess(mem_peak / len(self._model_buffer), loosen_mem_use_factor)
+    self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)


 if __name__ == '__main__':
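
The peak-memory assertion above relies on the standard tracemalloc pattern; a generic sketch (not the test's exact body, and the measured call is hypothetical) of how such a peak is obtained:

import tracemalloc

tracemalloc.start()
result = run_operation_under_test()            # hypothetical call being measured
_, mem_peak = tracemalloc.get_traced_memory()  # returns (current, peak) in bytes
tracemalloc.stop()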
4 changes: 2 additions & 2 deletions ai_edge_quantizer/model_validator.py
@@ -224,8 +224,8 @@ def _setup_validation_interpreter(

 # TODO: b/330797129 - Enable multi-threaded evaluation.
 def compare_model(
-    reference_model: Union[str, bytearray],
-    target_model: Union[str, bytearray],
+    reference_model: Union[str, bytes],
+    target_model: Union[str, bytes],
     test_data: dict[str, Iterable[dict[str, Any]]],
     error_metric: str,
     compare_fn: Callable[[Any, Any], float],
3 changes: 2 additions & 1 deletion ai_edge_quantizer/params_generator.py
@@ -30,8 +30,9 @@
 class ParamsGenerator:
   """Generate model tensor level quantization parameters."""

-  def __init__(self, float_tflite: Union[str, bytearray]):
+  def __init__(self, float_tflite: Union[str, bytes]):
     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+
     if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
       raise ValueError(
           'The input model for quantization parameters generation is not a'
6 changes: 3 additions & 3 deletions ai_edge_quantizer/quantizer.py
@@ -125,9 +125,9 @@ def __init__(
       quantization_recipe: Quantization recipe in .json filepath or loaded json
         format.
     """
-    # Turn the `float model` into bytearray for memory efficiency.
-    self.float_model: bytearray = (
-        tfl_flatbuffer_utils.get_model_buffer(float_model)
+    # Use `float model` as bytes for memory efficiency.
+    self.float_model: bytes = (
+        tfl_flatbuffer_utils.get_model_content(float_model)
         if isinstance(float_model, str)
         else float_model
     )
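
The comment in the hunk above states the motivation: the float model is kept as immutable bytes rather than as a bytearray. A minimal sketch of the difference (the file name is hypothetical; the copy claim assumes standard CPython behaviour):

with open("model.tflite", "rb") as f:   # hypothetical path
    model_bytes = f.read()              # `bytes`: can be stored as-is, no extra copy
model_copy = bytearray(model_bytes)     # `bytearray(...)`: allocates a second
                                        # full-size copy of the model buffer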
4 changes: 2 additions & 2 deletions ai_edge_quantizer/tests/end_to_end_tests/transpose_test.py
@@ -87,7 +87,7 @@ def test_transpose_model_full_integer(

     # Check tensor dtypes.
     quantized_model = tfl_flatbuffer_utils.read_model(
-        bytearray(quantization_result.quantized_model)
+        quantization_result.quantized_model
     )
     self.assertLen(quantized_model.subgraphs, 1)
     subgraph = quantized_model.subgraphs[0]
@@ -133,7 +133,7 @@ def test_quantize_integer_transpose(self, recipe_path):
     )
     quantization_result = self._quantizer.quantize(calibration_result)
     quantized_model = tfl_flatbuffer_utils.read_model(
-        bytearray(quantization_result.quantized_model)
+        quantization_result.quantized_model
     )
     self.assertLen(quantized_model.subgraphs, 1)
     subgraph = quantized_model.subgraphs[0]
2 changes: 1 addition & 1 deletion ai_edge_quantizer/utils/test_utils.py
@@ -78,7 +78,7 @@ def create_random_normal_dataset(


 def create_random_normal_input_data(
-    tflite_model: Union[str, bytearray],
+    tflite_model: Union[str, bytes],
     num_samples: int = 4,
     random_seed: int = 666,
 ) -> dict[str, list[dict[str, Any]]]:
15 changes: 14 additions & 1 deletion ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -102,14 +102,27 @@ def read_model(tflite_model: Union[str, bytearray]) -> Any:
   """
   if isinstance(tflite_model, str):
     return flatbuffer_utils.read_model(tflite_model)
-  elif isinstance(tflite_model, bytearray):
+  elif isinstance(tflite_model, bytes) or isinstance(tflite_model, bytearray):
     return flatbuffer_utils.read_model_from_bytearray(tflite_model)
   else:
     raise ValueError(
         "Unsupported tflite_model type: %s" % type(tflite_model).__name__
     )


+def get_model_content(tflite_path: str) -> bytes:
+  """Get the model content (bytes) from the path.
+
+  Args:
+    tflite_path: Path to the .tflite.
+
+  Returns:
+    The model bytes.
+  """
+  with gfile.Open(tflite_path, "rb") as tflite_file:
+    return tflite_file.read()
+
+
 def get_model_buffer(tflite_path: str) -> bytearray:
   """Get the model buffer from the path.
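
An assumed usage sketch of the new helper together with read_model, which now dispatches on str, bytes, and bytearray (the path is hypothetical):

from ai_edge_quantizer.utils import tfl_flatbuffer_utils

model_bytes = tfl_flatbuffer_utils.get_model_content("/tmp/model.tflite")  # bytes, read via gfile
model = tfl_flatbuffer_utils.read_model(model_bytes)              # bytes/bytearray branch
model_from_path = tfl_flatbuffer_utils.read_model("/tmp/model.tflite")  # str branch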
7 changes: 3 additions & 4 deletions ai_edge_quantizer/utils/tfl_interpreter_utils.py
@@ -28,7 +28,7 @@


 def create_tfl_interpreter(
-    tflite_model: Union[str, bytearray],
+    tflite_model: Union[str, bytes],
     allocate_tensors: bool = True,
     use_reference_kernel: bool = False,
 ) -> tfl.Interpreter:
@@ -46,14 +46,13 @@ def create_tfl_interpreter(
   if isinstance(tflite_model, str):
     with gfile.GFile(tflite_model, "rb") as f:
       tflite_model = f.read()
-  else:
-    tflite_model = bytes(tflite_model)
+
   if use_reference_kernel:
     op_resolver = tfl.OpResolverType.BUILTIN_REF
   else:
     op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
   tflite_interpreter = tfl.Interpreter(
-      model_content=tflite_model,
+      model_content=bytes(tflite_model),
       experimental_op_resolver_type=op_resolver,
       experimental_preserve_all_tensors=True,
   )
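
Deferring the bytes() conversion to the Interpreter call keeps the common bytes input copy-free; a sketch of the underlying assumption (standard CPython behaviour of the bytes constructor):

data = b"model-bytes"     # already bytes: bytes(data) does not allocate a new copy
view = bytearray(data)    # bytearray input: bytes(view) copies once, at the call site
payload = bytes(view)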
