Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable OpResolver with XNNPACK for quantizer #174

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ai_edge_quantizer/model_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _setup_validation_interpreter(
"""

interpreter = utils.create_tfl_interpreter(
tflite_model=model, use_reference_kernel=use_reference_kernel
tflite_model=model
)
utils.invoke_interpreter_signature(
interpreter, signature_input, signature_key
Expand Down
13 changes: 8 additions & 5 deletions ai_edge_quantizer/utils/tfl_interpreter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@
def create_tfl_interpreter(
tflite_model: Union[str, bytes],
allocate_tensors: bool = True,
use_reference_kernel: bool = False,
use_xnnpack: bool = True,
num_threads: int = 4,
) -> tfl.Interpreter:
"""Creates a TFLite interpreter from a model file.
Args:
tflite_model: Model file path or bytes.
allocate_tensors: Whether to allocate tensors.
use_reference_kernel: Whether to use the reference kernel for the
interpreter.
use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
num_threads: The number of threads to use for the interpreter.
Returns:
A TFLite interpreter.
Expand All @@ -47,12 +48,14 @@ def create_tfl_interpreter(
with gfile.GFile(tflite_model, "rb") as f:
tflite_model = f.read()

if use_reference_kernel:
op_resolver = tfl.OpResolverType.BUILTIN_REF
if use_xnnpack:
op_resolver = tfl.OpResolverType.BUILTIN_WITH_XNNPACK
else:
op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES

tflite_interpreter = tfl.Interpreter(
model_content=bytes(tflite_model),
num_threads=num_threads,
experimental_op_resolver_type=op_resolver,
experimental_preserve_all_tensors=True,
)
Expand Down
Loading