diff --git a/ai_edge_quantizer/model_validator.py b/ai_edge_quantizer/model_validator.py
index 7ec9bdf..b9fbe21 100644
--- a/ai_edge_quantizer/model_validator.py
+++ b/ai_edge_quantizer/model_validator.py
@@ -224,7 +224,7 @@ def _setup_validation_interpreter(
   """
   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_reference_kernel=use_reference_kernel
+      tflite_model=model
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
diff --git a/ai_edge_quantizer/utils/tfl_interpreter_utils.py b/ai_edge_quantizer/utils/tfl_interpreter_utils.py
index 4d46874..f064e1b 100644
--- a/ai_edge_quantizer/utils/tfl_interpreter_utils.py
+++ b/ai_edge_quantizer/utils/tfl_interpreter_utils.py
@@ -30,15 +30,16 @@ def create_tfl_interpreter(
     tflite_model: Union[str, bytes],
     allocate_tensors: bool = True,
-    use_reference_kernel: bool = False,
+    use_xnnpack: bool = True,
+    num_threads: int = 4,
 ) -> tfl.Interpreter:
   """Creates a TFLite interpreter from a model file.

   Args:
     tflite_model: Model file path or bytes.
     allocate_tensors: Whether to allocate tensors.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
+    num_threads: The number of threads to use for the interpreter.

   Returns:
     A TFLite interpreter.
@@ -47,12 +48,14 @@ def create_tfl_interpreter(
   with gfile.GFile(tflite_model, "rb") as f:
     tflite_model = f.read()

-  if use_reference_kernel:
-    op_resolver = tfl.OpResolverType.BUILTIN_REF
+  if use_xnnpack:
+    op_resolver = tfl.OpResolverType.BUILTIN_WITH_XNNPACK
   else:
     op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
+
   tflite_interpreter = tfl.Interpreter(
       model_content=bytes(tflite_model),
+      num_threads=num_threads,
       experimental_op_resolver_type=op_resolver,
       experimental_preserve_all_tensors=True,
   )
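
Usage sketch of the changed API (not part of the diff): this assumes the module
is imported under the alias below, and "model.tflite" is a hypothetical path.
With the new defaults, callers get the XNNPACK resolver with 4 threads; passing
use_xnnpack=False selects builtin ops without default delegates, which is the
closest remaining analogue to the removed use_reference_kernel path.

  from ai_edge_quantizer.utils import tfl_interpreter_utils

  # New defaults: XNNPACK op resolver, num_threads=4.
  interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      tflite_model="model.tflite"  # hypothetical model path
  )

  # Opt out of XNNPACK and run single-threaded (builtin ops, no default
  # delegates), e.g. for more deterministic numerics during validation.
  interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      tflite_model="model.tflite",
      use_xnnpack=False,
      num_threads=1,
  )

Note that _setup_validation_interpreter now always uses these defaults, since
its call site no longer forwards a kernel-selection argument.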