From 3755bf69f4274d8afa90cb5be752a14cd2a159a6 Mon Sep 17 00:00:00 2001 From: Vitalii Dziuba Date: Thu, 5 Dec 2024 13:22:49 -0800 Subject: [PATCH] Enable OpResolver with XNNPACK for quantizer PiperOrigin-RevId: 703228017 --- ai_edge_quantizer/utils/tfl_interpreter_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ai_edge_quantizer/utils/tfl_interpreter_utils.py b/ai_edge_quantizer/utils/tfl_interpreter_utils.py index 4d46874..f064e1b 100644 --- a/ai_edge_quantizer/utils/tfl_interpreter_utils.py +++ b/ai_edge_quantizer/utils/tfl_interpreter_utils.py @@ -30,15 +30,16 @@ def create_tfl_interpreter( tflite_model: Union[str, bytes], allocate_tensors: bool = True, - use_reference_kernel: bool = False, + use_xnnpack: bool = True, + num_threads: int = 4, ) -> tfl.Interpreter: """Creates a TFLite interpreter from a model file. Args: tflite_model: Model file path or bytes. allocate_tensors: Whether to allocate tensors. - use_reference_kernel: Whether to use the reference kernel for the - interpreter. + use_xnnpack: Whether to use the XNNPACK delegate for the interpreter. + num_threads: The number of threads to use for the interpreter. Returns: A TFLite interpreter. @@ -47,12 +48,14 @@ def create_tfl_interpreter( with gfile.GFile(tflite_model, "rb") as f: tflite_model = f.read() - if use_reference_kernel: - op_resolver = tfl.OpResolverType.BUILTIN_REF + if use_xnnpack: + op_resolver = tfl.OpResolverType.BUILTIN_WITH_XNNPACK else: op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES + tflite_interpreter = tfl.Interpreter( model_content=bytes(tflite_model), + num_threads=num_threads, experimental_op_resolver_type=op_resolver, experimental_preserve_all_tensors=True, )