From 93084c80f53932e2618eaf39020b3516a03eeb74 Mon Sep 17 00:00:00 2001 From: yinnengzhong Date: Fri, 10 Jan 2025 06:14:32 +0000 Subject: [PATCH] update torch 2.5.1 --- mqbench/convert_deploy.py | 101 +++++---- .../custom_quantizer/academic_quantizer.py | 10 +- mqbench/custom_quantizer/model_quantizer.py | 43 ++-- .../custom_quantizer/onnx_qnn_quantizer.py | 10 +- .../custom_quantizer/openvino_quantizer.py | 12 +- mqbench/custom_quantizer/vitis_quantizer.py | 12 +- mqbench/custom_symbolic_opset.py | 185 +++++++++++++++-- mqbench/deploy/common.py | 16 ++ mqbench/deploy/deploy_linear.py | 114 ++++++----- mqbench/deploy/deploy_nnie.py | 13 +- mqbench/deploy/deploy_onnx_qlinear.py | 42 ++-- mqbench/deploy/deploy_onnx_qnn.py | 191 +++++++++++------- mqbench/deploy/deploy_openvino.py | 57 +++--- mqbench/deploy/deploy_stpu.py | 48 +++-- mqbench/deploy/deploy_tengine.py | 51 +++-- mqbench/fake_quantize/dorefa.py | 2 +- mqbench/fake_quantize/dsq.py | 42 +++- mqbench/fake_quantize/fixed.py | 5 +- mqbench/fake_quantize/lsq.py | 53 +++-- mqbench/fake_quantize/nnie.py | 10 +- mqbench/fake_quantize/tqt.py | 22 +- mqbench/fuser_method_mappings.py | 186 ++++++++--------- mqbench/fusion_method.py | 12 +- mqbench/observer.py | 41 ++-- mqbench/prepare_by_platform.py | 39 +++- mqbench/quantization/qconfig_mapping_utils.py | 36 ++++ 26 files changed, 891 insertions(+), 462 deletions(-) create mode 100644 mqbench/quantization/qconfig_mapping_utils.py diff --git a/mqbench/convert_deploy.py b/mqbench/convert_deploy.py index 93592bff..5effb201 100644 --- a/mqbench/convert_deploy.py +++ b/mqbench/convert_deploy.py @@ -21,9 +21,13 @@ remove_fakequantize_and_collect_params_stpu, ONNXQLinearPass, ONNXQNNPass ) - +import onnx +from onnxsim import simplify +from mqbench.deploy.common import ( + parse_attrs +) __all__ = ['convert_deploy'] - +qmin_max_dict = {} @register_deploy_function(BackendType.STPU) @register_deploy_function(BackendType.Tengine_u8) @register_deploy_function(BackendType.PPLCUDA) @@ -34,6 +38,7 @@ @register_deploy_function(BackendType.NNIE) @register_deploy_function(BackendType.Vitis) @register_deploy_function(BackendType.OPENVINO) +@register_deploy_function(BackendType.QDQ) def convert_merge_bn(model: GraphModule, **kwargs): logger.info("Merge BN for deploy.") nodes = list(model.graph.nodes) @@ -57,6 +62,7 @@ def convert_merge_bn(model: GraphModule, **kwargs): @register_deploy_function(BackendType.NNIE) @register_deploy_function(BackendType.Vitis) @register_deploy_function(BackendType.OPENVINO) +@register_deploy_function(BackendType.QDQ) def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_path, **kwargs): logger.info("Export to onnx.") output_names = kwargs.get('output_names', []) @@ -68,29 +74,54 @@ def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_p input_names = list(dummy_input.keys()) dummy_input = tuple(dummy_input.values()) # Per-channel QuantizeLinear and DequantizeLinear is supported since opset 13 - opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 11 + opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 13 with torch.no_grad(): + # try: + # from torch.onnx.utils import ONNXCheckerError + # try: + torch.onnx.export(model, dummy_input, onnx_model_path, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + custom_opsets={'' : opset_version}) + + + # except ONNXCheckerError: + # pass + # except ImportError: + # 
torch.onnx.export(model, dummy_input, onnx_model_path, + # input_names=input_names, + # output_names=output_names, + # opset_version=opset_version, + # do_constant_folding=True, + # custom_opsets={'' : opset_version}, + # enable_onnx_checker=False) + onnx_model = onnx.load(onnx_model_path) + graph = onnx_model.graph + for node in graph.node: + if len(node.attribute) > 1: + qparams = parse_attrs(node.attribute) + if 'quant_max' in qparams: + qmin_max_dict[node.name] = (qparams['quant_min'], qparams['quant_max']) + new_attributes = [] + for attr in node.attribute: + if attr.name not in ["quant_min", "quant_max"]: + new_attributes.append(attr) + node.ClearField("attribute") + node.attribute.extend(new_attributes) + onnx.save(onnx_model, onnx_model_path) try: - from torch.onnx.utils import ONNXCheckerError - try: - torch.onnx.export(model, dummy_input, onnx_model_path, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=True, - custom_opsets={'' : opset_version}) - except ONNXCheckerError: - pass - except ImportError: - torch.onnx.export(model, dummy_input, onnx_model_path, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - do_constant_folding=True, - custom_opsets={'' : opset_version}, - enable_onnx_checker=False) - + logger.info("simplify model.") + onnx_model = onnx.load(onnx_model_path) + onnx_model_simplified, check = simplify(onnx_model) + onnx.save(onnx_model_simplified, onnx_model_path) + except Exception as e: + logger.info("simplify model fail.") + # onnx.checker.check_model(onnx_model_simplified) + # import onnxruntime as ort + # session = ort.InferenceSession(onnx_model_path) @register_deploy_function(BackendType.Tensorrt) def convert_onnx_qlinear(model: GraphModule, onnx_model_path, model_name, **kwargs): @@ -108,58 +139,58 @@ def deploy_qparams_nnie(model: GraphModule, onnx_model_path, model_name, **kwarg @register_deploy_function(BackendType.OPENVINO) def deploy_qparams_openvino(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for OPENVINO.") - replace_fakequantize_and_collect_params_openvino(onnx_model_path, model_name) + replace_fakequantize_and_collect_params_openvino(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.Tensorrt) def deploy_qparams_tensorrt(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for TensorRT.") - remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='tensorrt') + remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='tensorrt', qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.Vitis) def deploy_qparams_vitis(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for Vitis-DPU.") - remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='vitis') + remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='vitis', qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.SNPE) def deploy_qparams_snpe(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for SNPE.") - remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='snpe') + remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='snpe', qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.PPLW8A16) def 
deploy_qparams_pplw8a16(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for PPLW8A16.") - remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl') + remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl', qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.ONNX_QNN) def deploy_qparams_tvm(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Convert to ONNX QNN.") - ONNXQNNPass(onnx_model_path).run(model_name) + ONNXQNNPass(onnx_model_path).run(model_name, qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.PPLCUDA) def deploy_qparams_ppl_cuda(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for PPL-CUDA.") - remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl-cuda') + remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl-cuda', qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.Tengine_u8) def deploy_qparams_tengine(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for Tengine.") - remove_fakequantize_and_collect_params_tengine(onnx_model_path, model_name) + remove_fakequantize_and_collect_params_tengine(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict) @register_deploy_function(BackendType.STPU) def deploy_qparams_stpu(model: GraphModule, onnx_model_path, model_name, **kwargs): logger.info("Extract qparams for STPU.") - remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name) + remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict) -def convert_deploy(model: GraphModule, backend_type: BackendType, +def convert_deploy(model: GraphModule, backend_type: BackendType, input_shape_dict=None, dummy_input=None, output_path='./', model_name='mqbench_qmodel', deploy_to_qlinear=False, **extra_kwargs): r"""Convert model to onnx model and quantization params depends on backend. 
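[Reviewer note] The deploy_qparams_* wrappers above now receive the module-level qmin_max_dict, which the convert_onnx hunk earlier in this file populates: the opset-13 symbolic functions attach quant_min/quant_max as extra attributes to QuantizeLinear/DequantizeLinear nodes, and convert_onnx caches them per node name before stripping them (ONNX Q/DQ ops do not accept these attributes). A minimal standalone sketch of that pass follows, assuming only that the exported model is on disk; strip_qminmax is an illustrative helper name, not part of the patch, which performs this inline.

    import onnx

    def strip_qminmax(onnx_model_path):
        """Cache and remove quant_min/quant_max attributes from Q/DQ nodes."""
        qmin_max = {}
        model = onnx.load(onnx_model_path)
        for node in model.graph.node:
            attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
            if 'quant_min' in attrs and 'quant_max' in attrs:
                # Remember the clipping range keyed by node name ...
                qmin_max[node.name] = (attrs['quant_min'], attrs['quant_max'])
                # ... then drop the non-standard attributes from the node.
                kept = [a for a in node.attribute if a.name not in ('quant_min', 'quant_max')]
                node.ClearField('attribute')
                node.attribute.extend(kept)
        onnx.save(model, onnx_model_path)
        return qmin_max

In the patch this logic runs inline in convert_onnx and the mapping is kept in the module-level qmin_max_dict, so the later remove/replace passes can still recover per-node (qmin, qmax) even though the exported Q/DQ nodes no longer carry those attributes.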
@@ -186,9 +217,9 @@ def forward(self, input_0, input_1): 'output_path': output_path, 'model_name': model_name, 'onnx_model_path': osp.join(output_path, '{}.onnx'.format(model_name)), - 'deploy_to_qlinear': deploy_to_qlinear + 'deploy_to_qlinear': deploy_to_qlinear, } - kwargs.update(extra_kwargs) + # kwargs.update(extra_kwargs) deploy_model = deepcopy_graphmodule(model) for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]: convert_function(deploy_model, **kwargs) diff --git a/mqbench/custom_quantizer/academic_quantizer.py b/mqbench/custom_quantizer/academic_quantizer.py index 923d7834..ff04f573 100644 --- a/mqbench/custom_quantizer/academic_quantizer.py +++ b/mqbench/custom_quantizer/academic_quantizer.py @@ -6,7 +6,7 @@ import torch from torch.fx import GraphModule from torch.quantization import propagate_qconfig_ -from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict +from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict from mqbench.utils import is_symmetric_quant, getitem2node from mqbench.utils.logger import logger @@ -25,14 +25,14 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict): self.io_module = {} self.post_act_8bit_node_name = [] - def prepare(self, model: GraphModule, qconfig): + def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn): self._get_io_module(model) self._get_post_act_8bit_node_name(model) - model = self._weight_quant(model, qconfig) + model = self._weight_quant(model, qconfig, backend_config, freeze_bn) model = self._insert_fake_quantize_for_act_quant(model, qconfig) return model - def _weight_quant(self, model: GraphModule, qconfig): + def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn): logger.info("Replace module to qat module.") wqconfig_8bit = copy.deepcopy(qconfig) wq_symmetry = True if is_symmetric_quant(qconfig.weight.p.keywords['qscheme']) else False @@ -44,7 +44,7 @@ def _weight_quant(self, model: GraphModule, qconfig): module.qconfig = wqconfig_8bit flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig}) propagate_qconfig_(model, flattened_qconfig_dict) - self._qat_swap_modules(model, self.additional_qat_module_mapping) + self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn) return model @property diff --git a/mqbench/custom_quantizer/model_quantizer.py b/mqbench/custom_quantizer/model_quantizer.py index f0a7c707..1b9e9812 100644 --- a/mqbench/custom_quantizer/model_quantizer.py +++ b/mqbench/custom_quantizer/model_quantizer.py @@ -23,9 +23,7 @@ from torch.quantization.utils import ( get_combined_dict ) -from torch.quantization.fx.qconfig_utils import ( - get_flattened_qconfig_dict -) +from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict from torch.quantization.quantize_fx import ( _fuse_fx ) @@ -34,10 +32,15 @@ from mqbench.utils.logger import logger from mqbench.utils.registry import register_model_quantizer from mqbench.prepare_by_platform import BackendType - - +from torch.ao.quantization.backend_config import ( + BackendConfig, +) +from torch.ao.quantization.backend_config.utils import ( + get_module_to_qat_module, +) @register_model_quantizer(BackendType.Tensorrt) @register_model_quantizer(BackendType.NNIE) +@register_model_quantizer(BackendType.QDQ) class ModelQuantizer(object): """General model quantizer class. 
First, replace common float module to nn.qat.modules to make weight fake @@ -60,9 +63,9 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict): self.exclude_node_name = extra_quantizer_dict.get('exclude_node_name', []) self.extra_fuse_dict = extra_fuse_dict - def prepare(self, model: GraphModule, qconfig): - model = _fuse_fx(model, self.extra_fuse_dict) - model = self._weight_quant(model, qconfig) + def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn): + model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config) + model = self._weight_quant(model, qconfig, backend_config, freeze_bn) model = self._insert_fake_quantize_for_act_quant(model, qconfig) return model @@ -119,11 +122,11 @@ def _fix_succ_recursivly(self, args, target_node, inserted_node): else: raise NotImplementedError('{} can not be handled now.'.format(type(args))) - def _weight_quant(self, model: GraphModule, qconfig): + def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn): logger.info("Replace module to qat module.") flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig}) propagate_qconfig_(model, flattened_qconfig_dict) - self._qat_swap_modules(model, self.additional_qat_module_mapping) + self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn) return model @property @@ -245,15 +248,18 @@ def _find_act_quants(self, model: GraphModule) -> List: node_need_to_quantize_output.append(_node) return node_need_to_quantize_output - def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]): + def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig, freeze_bn: bool): + # all_mappings = get_combined_dict( + # get_default_qat_module_mappings(), additional_qat_module_mapping) all_mappings = get_combined_dict( - get_default_qat_module_mappings(), additional_qat_module_mapping) - root = self._convert(root, all_mappings, inplace=True) + get_module_to_qat_module(backend_config), additional_qat_module_mapping) + root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config, freeze_bn=freeze_bn) return root - def _convert(self, module, mapping=None, inplace=False, scope=''): + def _convert(self, module, mapping=None, inplace=False, backend_config=None, freeze_bn=True, scope=''): if mapping is None: - mapping = get_default_static_quant_module_mappings() + # mapping = get_default_static_quant_module_mappings() + mapping = get_module_to_qat_module(backend_config) if not inplace: module = copy.deepcopy(module) @@ -265,8 +271,11 @@ def _convert(self, module, mapping=None, inplace=False, scope=''): logger.info("Skip quant layer: " + new_scope) continue if not isinstance(mod, _FusedModule): - self._convert(mod, mapping, True, new_scope) - reassign[name] = swap_module(mod, mapping, {}) + self._convert(mod, mapping, True, new_scope, freeze_bn= freeze_bn) + reassign[name] = swap_module(mod, mapping, {}, False) + if freeze_bn: + if (hasattr(reassign[name], 'freeze_bn')): + reassign[name].freeze_bn = True for key, value in reassign.items(): module._modules[key] = value diff --git a/mqbench/custom_quantizer/onnx_qnn_quantizer.py b/mqbench/custom_quantizer/onnx_qnn_quantizer.py index 2a321a15..5dafde9b 100644 --- a/mqbench/custom_quantizer/onnx_qnn_quantizer.py +++ b/mqbench/custom_quantizer/onnx_qnn_quantizer.py @@ -12,7 +12,9 @@ from mqbench.utils.registry import register_model_quantizer from 
mqbench.prepare_by_platform import BackendType from mqbench.custom_quantizer import ModelQuantizer - +from torch.ao.quantization.backend_config import ( + BackendConfig, +) @register_model_quantizer(BackendType.ONNX_QNN) class ONNXQNNQuantizer(ModelQuantizer): @@ -52,14 +54,14 @@ def _find_act_quants(self, model: GraphModule) -> List: node_need_to_quantize_output.append(next_node) return node_need_to_quantize_output - def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]): + def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig, freeze_bn: bool): all_mappings = get_combined_dict( get_default_qat_module_mappings(), additional_qat_module_mapping) # There is no QLinearFC in ONNX for now. del all_mappings[torch.nn.modules.linear.Linear] del all_mappings[torch.nn.intrinsic.modules.fused.LinearReLU] - del all_mappings[qnni.modules.fused.LinearBn1d] - root = self._convert(root, all_mappings, inplace=True) + # del all_mappings[qnni.modules.fused.LinearBn1d] + root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config, freeze_bn=freeze_bn) return root @property diff --git a/mqbench/custom_quantizer/openvino_quantizer.py b/mqbench/custom_quantizer/openvino_quantizer.py index 1509b832..2c5f4ec6 100644 --- a/mqbench/custom_quantizer/openvino_quantizer.py +++ b/mqbench/custom_quantizer/openvino_quantizer.py @@ -6,7 +6,7 @@ import torch from torch.fx import GraphModule from torch.quantization import propagate_qconfig_ -from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict +from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict from torch.quantization.quantize_fx import _fuse_fx from mqbench.utils import is_symmetric_quant @@ -137,10 +137,10 @@ def function_type_to_quant_unsigned(self) -> tuple: def module_type_maybe_unsigned(self) -> tuple: return (torch.nn.Upsample, torch.nn.modules.pooling.MaxPool2d, torch.nn.modules.pooling.AvgPool2d, torch.nn.modules.pooling.AdaptiveAvgPool2d) - def prepare(self, model: GraphModule, qconfig): + def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn): if not self.academic_mode: - model = _fuse_fx(model, self.extra_fuse_dict) - model = self._weight_quant(model, qconfig) + model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config) + model = self._weight_quant(model, qconfig, backend_config, freeze_bn) model = self._insert_fake_quantize_for_act_quant(model, qconfig) return model @@ -199,7 +199,7 @@ def propagated_pattern(prev_node, cur_node): break return node_need_to_quantize_output - def _weight_quant(self, model: GraphModule, qconfig): + def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn): logger.info("Replace module to qat module.") wqconfig_8bit = copy.deepcopy(qconfig) wq_symmetry = True if is_symmetric_quant(qconfig.weight.p.keywords['qscheme']) else False @@ -213,7 +213,7 @@ def _weight_quant(self, model: GraphModule, qconfig): wqconfig_8bit.weight.p.keywords['quant_max'] = 2 ** (numbits - 2) - 1 flattened_qconfig_dict = get_flattened_qconfig_dict({'': wqconfig_8bit}) propagate_qconfig_(model, flattened_qconfig_dict) - self._qat_swap_modules(model, self.additional_qat_module_mapping) + self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn) return model diff --git a/mqbench/custom_quantizer/vitis_quantizer.py b/mqbench/custom_quantizer/vitis_quantizer.py index 
02b1550e..3b6b62b6 100644 --- a/mqbench/custom_quantizer/vitis_quantizer.py +++ b/mqbench/custom_quantizer/vitis_quantizer.py @@ -29,9 +29,9 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict): super().__init__(extra_quantizer_dict, extra_fuse_dict) self.additional_qat_module_mapping = { # Intrinsic modules: - nni.ConvBn2d: qnniqat.ConvBn2d, - nni.ConvBnReLU2d: qnniqat.ConvBnReLU2d, - nni.ConvReLU2d: qnniqat.ConvReLU2d, + # nni.ConvBn2d: qnniqat.ConvBn2d, + # nni.ConvBnReLU2d: qnniqat.ConvBnReLU2d, + # nni.ConvReLU2d: qnniqat.ConvReLU2d, } @property @@ -83,9 +83,9 @@ def function_type_to_quant_output(self) -> List: torch.nn.functional.interpolate, ] - def prepare(self, model: GraphModule, qconfig): - model = _fuse_fx(model, self.extra_fuse_dict) - model = self._weight_quant(model, qconfig) + def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn): + model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config) + model = self._weight_quant(model, qconfig, backend_config, freeze_bn) model = self._insert_fake_quantize_for_act_quant(model, qconfig) prepared = model self._set_quant_type(prepared) diff --git a/mqbench/custom_symbolic_opset.py b/mqbench/custom_symbolic_opset.py index 6fcb1f28..fabd5d46 100644 --- a/mqbench/custom_symbolic_opset.py +++ b/mqbench/custom_symbolic_opset.py @@ -1,23 +1,178 @@ from torch.onnx import register_custom_op_symbolic # Register symbolic op for torch.quantize_function op. +import functools +from torch.onnx._internal import jit_utils, registration +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) +_custom_onnx_symbolic = functools.partial(registration.custom_onnx_symbolic, opset=13) +from torch.onnx import ( + _type_utils, + symbolic_helper, + symbolic_opset9 as opset9, +) +import torch._C._onnx as _C_onnx +import torch +@_custom_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + # raise errors.SymbolicValueError( + # "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). 
" + # f"Got ({quant_min}, {quant_max})", + # inputs, + # ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, quant_min_i = quant_min, quant_max_i = quant_max) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, quant_min_i = quant_min, quant_max_i = quant_max) -def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor): - return g.op("::LearnablePerTensorAffine", x, scale, zero_point, quant_min, quant_max) +@_custom_onnx_symbolic("aten::fake_quantize_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") +def fake_quantize_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + # raise errors.SymbolicValueError( + # "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + # f"Got ({quant_min}, {quant_max})", + # inputs, + # ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis, quant_min_i = quant_min, quant_max_i = quant_max) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis, quant_min_i = quant_min, quant_max_i = quant_max) +@_onnx_symbolic("aten::_fake_quantize_learnable_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f") +def _fake_quantize_learnable_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, + grad_factor=1.0, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + # raise errors.SymbolicValueError( + # "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). 
" + # f"Got ({quant_min}, {quant_max})", + # inputs, + # ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, quant_min_i = quant_min, quant_max_i = quant_max) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, quant_min_i = quant_min, quant_max_i = quant_max) -register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11) +@_custom_onnx_symbolic("aten::_fake_quantize_learnable_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "f") +def _fake_quantize_learnable_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, + grad_factor=1.0, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + # raise errors.SymbolicValueError( + # "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + # f"Got ({quant_min}, {quant_max})", + # inputs, + # ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis, quant_min_i = quant_min, quant_max_i = quant_max) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis, quant_min_i = quant_min, quant_max_i = quant_max) +# def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor): +# return g.op(x, scale, zero_point, quant_min, quant_max) +# +# +# register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11) +# +# +# def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max): +# return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max) +# +# +# register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11) +# +# +# def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max): +# return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max) +# +# +# register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11) -def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max): - return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max) - - 
-register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11) - - -def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max): - return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max) - - -register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11) \ No newline at end of file diff --git a/mqbench/deploy/common.py b/mqbench/deploy/common.py index 6a3382e1..c0e30f33 100644 --- a/mqbench/deploy/common.py +++ b/mqbench/deploy/common.py @@ -39,6 +39,8 @@ def get_initializer(self, initializer_name): def set_initializer(self, initializer_name, value_tensor, raw=True): idx = None + if value_tensor.shape == () and value_tensor.size == 1: + value_tensor.shape = 1 if initializer_name in self.initializer: idx = self.initializer[initializer_name][1] if raw: @@ -217,6 +219,20 @@ def prepare_data(graph): params[node.output[0]] = numpy_helper.to_array(attr.t) return params +def prepare_data_nnie(graph): + params = {} + for init in graph.initializer: + params[init.name] = numpy_helper.to_array(init) + for node in graph.node: + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == "value": + params[node.output[0]] = numpy_helper.to_array(attr.t) + elif node.op_type == "QuantizeLinear": + for attr in node.attribute: + if attr.name == "data_max": + params[node.output[0]] = attr.f + return params def prepare_initializer(graph): named_initializer = {} diff --git a/mqbench/deploy/deploy_linear.py b/mqbench/deploy/deploy_linear.py index ef7e8063..5d99091b 100644 --- a/mqbench/deploy/deploy_linear.py +++ b/mqbench/deploy/deploy_linear.py @@ -17,15 +17,19 @@ ) -PERCHANNEL_FAKEQUANTIZER = ['FakeQuantizeLearnablePerchannelAffine', - 'FixedPerChannelAffine', - 'FakeQuantizeDSQPerchannel'] -PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', - 'FixedPerTensorAffine', - 'FakeQuantizeDSQPertensor', - 'FakeQuantizeTqtAffine'] -ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER +_FAKEQUANTIZER = ['QuantizeLinear'] +PERCHANNEL_FAKEQUANTIZER = [] +PERTENSOR_FAKEQUANTIZER = ['DequantizeLinear', + 'QuantizeLinear'] +ALL_FAKEQUANTIZER = ['QuantizeLinear', 'DequantizeLinear'] + # PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER +def get_dequant_node(node, inp2node): + dequant_node = inp2node[node.output[0]][0][0] + if dequant_node.op_type == 'Clip': + dequant_node = inp2node[dequant_node.output[0]][0][0] + assert dequant_node.op_type == 'DequantizeLinear', "This is not correct fakequant node!" 
+ return dequant_node class LinearQuantizer_process(object): # some method like dorefa need pre-compute weights @@ -59,46 +63,46 @@ def find_redundant_nodes(tensor): find_redundant_nodes(weight) return weight, redundant_nodes - def deal_with_weight_fakequant(self, node, out2node, inp2node, named_initializer): - next_nodes = inp2node[node.output[0]] + def deal_with_weight_fakequant(self, quant_node, dequant_node, out2node, inp2node, named_initializer): + next_nodes = inp2node[dequant_node.output[0]] assert len(next_nodes) == 1 next_node, idx = next_nodes[0] assert next_node.op_type in ['Conv', 'Gemm', 'ConvTranspose'] redundant_nodes = [] - if node.input[0] not in named_initializer: - node.input[0], redundant_nodes = \ - self.weight_preprocess(node.input[0], out2node, inp2node, named_initializer) - next_node.input[idx] = node.input[0] + if quant_node.input[0] not in named_initializer: + quant_node.input[0], redundant_nodes = \ + self.weight_preprocess(quant_node.input[0], out2node, inp2node, named_initializer) + next_node.input[idx] = quant_node.input[0] return redundant_nodes - def deal_with_activation_fakequant(self, node, inp2node): - next_nodes = inp2node[node.output[0]] + def deal_with_activation_fakequant(self, quant_node, dequant_node, inp2node): + next_nodes = inp2node[dequant_node.output[0]] for next_node, idx in next_nodes: - next_node.input[idx] = node.input[0] + next_node.input[idx] = quant_node.input[0] def parse_qparams(self, node, name2data): tensor_name, scale, zero_point = node.input[:3] scale, zero_point = name2data[scale], name2data[zero_point] - if len(node.input) > 3: - qmin, qmax = node.input[-2:] - qmin, qmax = name2data[qmin], name2data[qmax] - elif len(node.attribute) > 0: - qparams = parse_attrs(node.attribute) - qmin = qparams['quant_min'] - qmax = qparams['quant_max'] - else: - logger.info(f'qmin and qmax are not found for <{node.name}>!') - return tensor_name, scale, zero_point, qmin, qmax + # if len(node.input) > 3: + # qmin, qmax = node.input[-2:] + # qmin, qmax = name2data[qmin], name2data[qmax] + # elif len(node.attribute) > 0: + # qparams = parse_attrs(node.attribute) + # qmin = qparams['quant_min'] + # qmax = qparams['quant_max'] + # else: + # logger.info(f'qmin and qmax are not found for <{node.name}>!') + return tensor_name, scale, zero_point - def clip_weight(self, node, name2data, inp2node, named_initializer): - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + def clip_weight(self, quant, dequant, name2data, inp2node, named_initializer, qmin, qmax): + tensor_name, scale, zero_point = self.parse_qparams(quant, name2data) data = name2data[tensor_name] clip_range_min = ((qmin - zero_point) * scale).astype(data.dtype) clip_range_max = ((qmax - zero_point) * scale).astype(data.dtype) if len(scale.shape) > 0 and scale.shape[0] > 1: new_data = [] transposed = False - next_node = inp2node[node.output[0]] + next_node = inp2node[dequant.output[0]] if len(next_node) == 1 and next_node[0][0].op_type == 'ConvTranspose': transposed = True data = data.transpose(1, 0, 2, 3) @@ -131,7 +135,9 @@ def find_the_closest_clip_range(node): logger.info(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.') return clip_ranges - def remove_fakequantize_and_collect_params(self, onnx_path, model_name, backend): + def remove_fakequantize_and_collect_params(self, onnx_path, model_name, backend, qmin_max_dict): + # a_qmin, a_qmax = kwargs['extra_kwargs'][0].p.keywords['quant_min'], 
kwargs['extra_kwargs'][0].p.keywords['quant_max'] + # w_qmin, w_qmax = kwargs['extra_kwargs'][1].p.keywords['quant_min'], kwargs['extra_kwargs'][1].p.keywords['quant_max'] model = onnx.load(onnx_path) graph = model.graph out2node, inp2node = update_inp2node_out2node(graph) @@ -146,16 +152,21 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name, backend) nodes_to_be_removed = [] for node in graph.node: if node.op_type in ALL_FAKEQUANTIZER: + next_node = inp2node[node.output[0]][0][0] + if next_node.op_type == 'Clip' and inp2node[next_node.output[0]][0][0].op_type == 'DequantizeLinear': + nodes_to_be_removed.append(next_node) nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) - if node.op_type in PERCHANNEL_FAKEQUANTIZER: + if node.op_type in _FAKEQUANTIZER and 'axis' in parse_attrs(node.attribute): # fake quantize for weights, suppose per-channel quantize only for weight - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_node(node, inp2node) + redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, named_initializer) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) if backend == 'ppl': - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + tensor_name, scale, zero_point= self.parse_qparams(node, name2data) clip_ranges[tensor_name] = {'step': [float(x) for x in scale], 'zero_point': [int(x) for x in zero_point], 'min': [float(x) for x in scale * (qmin - zero_point)], @@ -167,30 +178,32 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name, backend) logger.info("Vitis-DPU does not support per-channel quatization.") raise NotImplementedError("Vitis-DPU does not support per-channel quatization.") - elif node.op_type in PERTENSOR_FAKEQUANTIZER: - if node.output[0] not in inp2node: - assert node.output[0] in [l.name for l in graph.output] - inp2node[node.output[0]] = [] - next_nodes = inp2node[node.output[0]] + elif node.op_type in _FAKEQUANTIZER and 'axis' not in parse_attrs(node.attribute): + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_node(node, inp2node) + if dequant_node.output[0] not in inp2node: + assert dequant_node.output[0] in [l.name for l in graph.output] + inp2node[dequant_node.output[0]] = [] + next_nodes = inp2node[dequant_node.output[0]] if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for weights - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, named_initializer) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) elif len(next_nodes) == 1 and next_nodes[0][1] == 2 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for bias assert backend == 'vitis' - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, 
named_initializer) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, named_initializer) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) else: # fake quantize for activations - self.deal_with_activation_fakequant(node, inp2node) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + self.deal_with_activation_fakequant(node, dequant_node, inp2node) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) for out in graph.output: - if out.name == node.output[0]: + if out.name == dequant_node.output[0]: out.name = tensor_name if backend == 'tensorrt': @@ -215,7 +228,8 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name, backend) clip_ranges[tensor_name] = float(max(-scale * (qmin - zero_point), scale * (qmax - zero_point))) for node in nodes_to_be_removed: - graph.node.remove(node) + if node in graph.node: + graph.node.remove(node) # delete initializer out2node, inp2node = update_inp2node_out2node(graph) named_initializer = prepare_initializer(graph) diff --git a/mqbench/deploy/deploy_nnie.py b/mqbench/deploy/deploy_nnie.py index ea41f2ba..1bab17af 100644 --- a/mqbench/deploy/deploy_nnie.py +++ b/mqbench/deploy/deploy_nnie.py @@ -10,7 +10,7 @@ from mqbench.deploy.common import ( update_inp2node_out2node, prepare_initializer, - prepare_data, + prepare_data_nnie, OnnxPreprocess, get_constant_inputs ) @@ -51,7 +51,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): model = onnx.load(onnx_path) graph = model.graph out2node, inp2node = update_inp2node_out2node(graph) - name2data = prepare_data(graph) + name2data = prepare_data_nnie(graph) named_initializer = prepare_initializer(graph) preprocess = OnnxPreprocess() @@ -62,8 +62,9 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): nodes_to_be_removed = [] clip_ranges = {} for node in graph.node: - if node.op_type == 'NNIEQuantize': - next_nodes = inp2node[node.output[0]] + if node.op_type == 'QuantizeLinear': + dequant_node = inp2node[node.output[0]][0][0] + next_nodes = inp2node[dequant_node.output[0]] if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for weights next_node, idx = next_nodes[0] @@ -71,14 +72,14 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): # clip weights tensor_name = node.input[0] data = name2data[tensor_name] - clip_range = name2data[node.input[1]] + clip_range = name2data[dequant_node.input[0]] new_data = np.clip(data, -clip_range, clip_range) new_data = numpy_helper.from_array(new_data) named_initializer[tensor_name].raw_data = new_data.raw_data logger.info(f'Clip weights {tensor_name} to range [{-clip_range}, {clip_range}].') else: # fake quantize for activations - clip_ranges[node.input[0]] = name2data[node.input[1]] + clip_ranges[node.input[0]] = name2data[dequant_node.input[0]] for next_node, idx in next_nodes: next_node.input[idx] = node.input[0] diff --git a/mqbench/deploy/deploy_onnx_qlinear.py b/mqbench/deploy/deploy_onnx_qlinear.py index 3c9d46f5..c115f550 100644 --- a/mqbench/deploy/deploy_onnx_qlinear.py +++ b/mqbench/deploy/deploy_onnx_qlinear.py @@ -17,25 +17,25 @@ 
def __init__(self, onnx_model_path): def parse_qparams(self, node, name2data): tensor_name, scale, zero_point = node.input[:3] scale, zero_point = name2data[scale], name2data[zero_point] - if len(node.input) > 3: - qmin, qmax = node.input[-2:] - qmin, qmax = name2data[qmin], name2data[qmax] - elif len(node.attribute) > 0: - qparams = parse_attrs(node.attribute) - qmin = qparams['quant_min'] - qmax = qparams['quant_max'] - else: - logger.info(f'qmin and qmax are not found for <{node.name}>!') - return tensor_name, scale, zero_point, qmin, qmax + # if len(node.input) > 3: + # qmin, qmax = node.input[-2:] + # qmin, qmax = name2data[qmin], name2data[qmax] + # elif len(node.attribute) > 0: + # qparams = parse_attrs(node.attribute) + # qmin = qparams['quant_min'] + # qmax = qparams['quant_max'] + # else: + # logger.info(f'qmin and qmax are not found for <{node.name}>!') + return tensor_name, scale, zero_point - def clip_weight(self, node, name2data, named_initializer): - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + def clip_weight(self, quant, dequant, name2data, named_initializer, qmin, qmax): + tensor_name, scale, zero_point = self.parse_qparams(quant, name2data) data = name2data[tensor_name] clip_range_min = (qmin - zero_point) * scale clip_range_max = (qmax - zero_point) * scale if scale.shape[0] > 1: new_data = [] - next_node = self.onnx_model.get_tensor_consumer(node.output[0])[0] + next_node = self.onnx_model.get_tensor_consumer(dequant.output[0])[0] if next_node.op_type == 'ConvTranspose': for c in range(data.shape[1]): new_data.append(np.clip(data[:, c], clip_range_min[c], clip_range_max[c])) @@ -64,20 +64,14 @@ def wrap_onnx_constant(self, data): else: return np.array(data) - def format_qlinear_dtype_pass(self): + def format_qlinear_dtype_pass(self, qmin_max_dict): name2data = prepare_data(self.onnx_model.graph) named_initializer = prepare_initializer(self.onnx_model.graph) for node in self.onnx_model.graph.node: - if node.op_type in FAKE_QUANTIZE_OP: - if node.op_type == 'FakeQuantizeLearnablePerchannelAffine': - scale, zero_point = node.input[1], node.input[2] - assert node.attribute[0].name == 'quant_max' and node.attribute[1].name == 'quant_min' - qmax = node.attribute[0].i - qmin = node.attribute[1].i - else: - scale, zero_point, qmin, qmax = node.input[-4:] - qmin = self.onnx_model.get_constant(qmin) - qmax = self.onnx_model.get_constant(qmax) + scale, zero_point= node.input[-2:] + qmin, qmax = qmin_max_dict[node.name] + qmin = self.onnx_model.get_constant(qmin) + qmax = self.onnx_model.get_constant(qmax) assert qmax - qmin in (2 ** 8 - 1, 2 ** 8 - 2), "Only 8 bit quantization support deployment to ONNX." # In onnx, quantize linear node value is within [-128, 127]. 
This step is to remove inconsistency for # fake quantize node which clips to [-127, 127] by clipping its value to [-127 * scale, 127 * scale] diff --git a/mqbench/deploy/deploy_onnx_qnn.py b/mqbench/deploy/deploy_onnx_qnn.py index d1ab85d9..0e6ee148 100644 --- a/mqbench/deploy/deploy_onnx_qnn.py +++ b/mqbench/deploy/deploy_onnx_qnn.py @@ -3,11 +3,15 @@ from mqbench.utils.logger import logger from .common import ONNXGraph +from mqbench.deploy.common import parse_attrs +FAKE_QUANTIZE_OP = ['QuantizeLinear', 'DequantizeLinear'] -FAKE_QUANTIZE_OP = ['FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine', 'FakeQuantizeDSQPerchannel', - 'LearnablePerTensorAffine', 'FixedPerTensorAffine', 'FakeQuantizeDSQPertensor'] +def search_and_replace_input(next_node, name, new_name): + for idx, _input_name in enumerate(next_node.input): + if _input_name == name: + next_node.input[idx] = new_name class ONNXQNNPass(object): def __init__(self, onnx_model_path): @@ -15,7 +19,7 @@ def __init__(self, onnx_model_path): @property def qlinear_op_type(self): - return ['QuantizeLinear', 'QLinearConv', 'QLinearAdd', 'QLinearGemm', 'QLinearGlobalAveragePool', + return ['QLinearConv', 'QLinearAdd', 'QLinearGemm', 'QLinearGlobalAveragePool', 'QLinearAveragePool', 'QLinearConcat'] @staticmethod @@ -74,14 +78,18 @@ def node_without_qparams(self): def replace_conv_gemm(self, node, idx, is_conv): # Input scale qlinear_conv_inputs = [] - input_fake_quant_node = self.onnx_model.get_tensor_producer(node.input[0]) - assert input_fake_quant_node.op_type in FAKE_QUANTIZE_OP + input_fake_dequant_node = self.onnx_model.get_tensor_producer(node.input[0]) + input_fake_quant_node = self.onnx_model.get_tensor_producer(input_fake_dequant_node.input[0]) + assert input_fake_quant_node.op_type == 'QuantizeLinear' x_scale, x_zero_point = input_fake_quant_node.input[1], input_fake_quant_node.input[2] # Output scale + node_next_quant = self.onnx_model.get_tensor_consumer(node.output[0])[0] + # node_next_dequant = self.onnx_model.get_tensor_consumer(node_next_quant.output[0])[0] qlinear_conv_output = node.output y_scale, y_zero_point = self.get_node_output_qparams(node) # Weight scale - weight_fake_quant_node = self.onnx_model.get_tensor_producer(node.input[1]) + weight_fake_dequant_node = self.onnx_model.get_tensor_producer(node.input[1]) + weight_fake_quant_node = self.onnx_model.get_tensor_producer(weight_fake_dequant_node.input[0]) w_scale, w_zero_point = weight_fake_quant_node.input[1], weight_fake_quant_node.input[2] weight_name = weight_fake_quant_node.input[0] W = self.quantize_weight(weight_name, w_scale, w_zero_point) @@ -106,6 +114,12 @@ def replace_conv_gemm(self, node, idx, is_conv): **kwargs) self.onnx_model.remove_node_purely(node) self.onnx_model.remove_node_purely(weight_fake_quant_node) + self.onnx_model.remove_node_purely(weight_fake_dequant_node) + # self.onnx_model.remove_node_purely(node_next_quant) + # next_nodes = self.onnx_model.get_tensor_consumer(input_fake_dequant_node.output[0]) + # for next_node in next_nodes: + # search_and_replace_input(next_node, input_fake_dequant_node.output[0], input_fake_quant_node.output[0]) + # self.onnx_model.remove_node_purely(input_fake_dequant_node) self.onnx_model.insert_node_purely(qlinear_conv_node, idx) self.onnx_model.topologize_graph() @@ -113,17 +127,19 @@ def replace_add_to_qlinearadd(self, node, idx): # First input qlinear_add_input = [] qlinear_add_output = node.output - first_input_node = self.onnx_model.get_tensor_producer(node.input[0]) - assert 
first_input_node.op_type in FAKE_QUANTIZE_OP - first_input_quantized = first_input_node.output[0] - first_scale = first_input_node.input[1] - first_zero_point = first_input_node.input[2] + first_input_dequant_node = self.onnx_model.get_tensor_producer(node.input[0]) + first_input_quant_node = self.onnx_model.get_tensor_producer(first_input_dequant_node.input[0]) + assert first_input_quant_node.op_type == 'QuantizeLinear' + first_input_quantized = first_input_dequant_node.output[0] + first_scale = first_input_quant_node.input[1] + first_zero_point = first_input_quant_node.input[2] # Second input - second_input_node = self.onnx_model.get_tensor_producer(node.input[1]) - assert second_input_node.op_type in FAKE_QUANTIZE_OP - second_input_quantized = second_input_node.output[0] - second_scale = second_input_node.input[1] - second_zero_point = second_input_node.input[2] + second_input_dequant_node = self.onnx_model.get_tensor_producer(node.input[1]) + second_input_quant_node = self.onnx_model.get_tensor_producer(second_input_dequant_node.input[0]) + assert second_input_quant_node.op_type == 'QuantizeLinear' + second_input_quantized = second_input_dequant_node.output[0] + second_scale = second_input_quant_node.input[1] + second_zero_point = second_input_quant_node.input[2] # Output output_scale, output_zero_point = self.get_node_output_qparams(node) qlinear_add_input.extend([first_input_quantized, first_scale, first_zero_point, @@ -140,13 +156,22 @@ def replace_add_to_qlinearadd(self, node, idx): **kwargs) self.onnx_model.insert_node_purely(qlinear_add_node, idx) self.onnx_model.remove_node_purely(node) + # first_next_nodes = self.onnx_model.get_tensor_consumer(first_input_dequant_node.output[0]) + # for next_node in first_next_nodes: + # search_and_replace_input(next_node, first_input_dequant_node.output[0], first_input_quant_node.output[0]) + # second_next_nodes = self.onnx_model.get_tensor_consumer(second_input_dequant_node.output[0]) + # for next_node in second_next_nodes: + # search_and_replace_input(next_node, second_input_dequant_node.output[0], second_input_quant_node.output[0]) + # self.onnx_model.remove_node_purely(first_input_dequant_node) + # self.onnx_model.remove_node_purely(second_input_dequant_node) self.onnx_model.topologize_graph() def replace_pool_to_qlinearpool(self, node, idx, is_global): qlinear_pool_input = [] - prev_node = self.onnx_model.get_tensor_producer(node.input[0]) - assert prev_node.op_type in FAKE_QUANTIZE_OP - x_scale, x_zero_point = prev_node.input[1], prev_node.input[2] + prev_dequant_node = self.onnx_model.get_tensor_producer(node.input[0]) + prev_quant_node = self.onnx_model.get_tensor_producer(prev_dequant_node.input[0]) + assert prev_quant_node.op_type == 'QuantizeLinear' + x_scale, x_zero_point = prev_quant_node.input[1], prev_quant_node.input[2] y_scale, y_zero_point = self.get_node_output_qparams(node) qlinear_pool_input.extend([node.input[0], x_scale, x_zero_point, y_scale, y_zero_point]) @@ -161,20 +186,25 @@ def replace_pool_to_qlinearpool(self, node, idx, is_global): node.name + '_quantized', domain='com.microsoft', **kwargs) + # next_nodes = self.onnx_model.get_tensor_consumer(prev_dequant_node.output[0]) + # for next_node in next_nodes: + # search_and_replace_input(next_node, prev_dequant_node.output[0], prev_quant_node.output[0]) self.onnx_model.insert_node_purely(qlinear_pool_node, idx) self.onnx_model.remove_node_purely(node) self.onnx_model.topologize_graph() def get_node_output_qparams(self, node): fake_quantize_node = 
self.onnx_model.get_tensor_consumer(node.output[0])[0] - while fake_quantize_node.op_type not in FAKE_QUANTIZE_OP: + while fake_quantize_node.op_type != 'QuantizeLinear': assert fake_quantize_node.op_type in self.node_without_qparams fake_quantize_node = self.onnx_model.get_tensor_consumer(fake_quantize_node.output[0])[0] return fake_quantize_node.input[1], fake_quantize_node.input[2] def replace_op_pass(self): # Replace Conv / Gemm / Add / AvgPool / Concat / LeakyRelu. + op_types = set() for idx, node in enumerate(self.onnx_model.graph.node): + op_types.add(node.op_type) if node.op_type == 'Conv': self.replace_conv_gemm(node, idx, is_conv=True) if node.op_type == 'Gemm': @@ -193,56 +223,79 @@ def replace_op_pass(self): if node.op_type == 'LeakyRelu': pass + # def replace_qlinear_layer_pass(self): + # # Replace FakeQuantize + # remove_nodes = [] + # for node in self.onnx_model.graph.node: + # if node.op_type in FAKE_QUANTIZE_OP: + # prev_node = self.onnx_model.get_tensor_producer(node.input[0]) + # next_node_list = self.onnx_model.get_tensor_consumer(node.output[0]) + # quantize_node = None + # dequantize_node = None + # output_flag = False + # for next_node in next_node_list: + # if prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \ + # next_node != 'OUTPUT_TOKEN' and next_node.op_type in self.qlinear_op_type: + # search_and_replace_input(next_node, node.output[0], node.input[0]) + # elif prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \ + # next_node == 'OUTPUT_TOKEN': + # if dequantize_node is None: + # output_flag = True + # else: + # if quantize_node is None: + # output_value_info = [f'{node.output[0]}_QuantizeLinear'] + # quantize_node = onnx.helper.make_node("QuantizeLinear", + # node.input[0:3], + # output_value_info, + # ('input' if prev_node == 'INPUT_TOKEN' else prev_node.name) + '_quantized') + # self.onnx_model.insert_node_purely(quantize_node) + # search_and_replace_input(next_node, node.output[0], quantize_node.output[0]) + # if not output_flag: + # self.onnx_model.remove_node_purely(node) + # self.onnx_model.topologize_graph() def replace_qlinear_layer_pass(self): - # Replace FakeQuantize - def search_and_replace_input(next_node, name, new_name): - for idx, _input_name in enumerate(next_node.input): - if _input_name == name: - next_node.input[idx] = new_name + node_detect = True + while node_detect: + node_detect = False + # Replace FakeQuantize + for node in self.onnx_model.graph.node: + if node.op_type in self.qlinear_op_type: + next_node_list = self.onnx_model.get_tensor_consumer(node.output[0]) + for i, next_node in enumerate(next_node_list): + if hasattr(next_node, 'op_type'): + if next_node.op_type == 'QuantizeLinear': + node_detect = True + node.output[0] = next_node.output[0] + # next_dequant_node_list = self.onnx_model.get_tensor_consumer(next_node.output[0]) + # for next_dequant_node in next_dequant_node_list: + # search_and_replace_input(next_dequant_node, next_node.output[0], node.output[0]) + self.onnx_model.remove_node_purely(next_node) + self.onnx_model.topologize_graph() + for i in range(len(node.input)): + pre_node = self.onnx_model.get_tensor_producer(node.input[i]) + if hasattr(pre_node, 'op_type'): + if pre_node.op_type == 'DequantizeLinear': + node_detect = True + pre_quant_node = self.onnx_model.get_tensor_producer(pre_node.input[0]) + pre_node_next_list = self.onnx_model.get_tensor_consumer(pre_node.output[0]) + for pre_node_next_node in pre_node_next_list: + 
search_and_replace_input(pre_node_next_node, pre_node.output[0], pre_quant_node.output[0]) + self.onnx_model.remove_node_purely(pre_node) + self.onnx_model.topologize_graph() - for node in self.onnx_model.graph.node: - if node.op_type in FAKE_QUANTIZE_OP: - prev_node = self.onnx_model.get_tensor_producer(node.input[0]) - next_node_list = self.onnx_model.get_tensor_consumer(node.output[0]) - quantize_node = None - dequantize_node = None - for next_node in next_node_list: - if prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \ - next_node != 'OUTPUT_TOKEN' and next_node.op_type in self.qlinear_op_type: - search_and_replace_input(next_node, node.output[0], node.input[0]) - elif prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \ - next_node == 'OUTPUT_TOKEN': - if dequantize_node is None: - output_value_info = [f'{node.output[0]}_DequantizeLinear'] - dequantize_node = onnx.helper.make_node("DequantizeLinear", - node.input[0:3], - output_value_info, - ('input' if prev_node == 'INPUT_TOKEN' else prev_node.name) + '_dequantized') - self.onnx_model.insert_node_purely(dequantize_node) - else: - if quantize_node is None: - output_value_info = [f'{node.output[0]}_QuantizeLinear'] - quantize_node = onnx.helper.make_node("QuantizeLinear", - node.input[0:3], - output_value_info, - ('input' if prev_node == 'INPUT_TOKEN' else prev_node.name) + '_quantized') - self.onnx_model.insert_node_purely(quantize_node) - search_and_replace_input(next_node, node.output[0], quantize_node.output[0]) - self.onnx_model.remove_node_purely(node) - self.onnx_model.topologize_graph() def merge_relu_pass(self): for node in self.onnx_model.graph.node: if node.op_type == 'Relu': next_node = self.onnx_model.get_tensor_consumer(node.output[0])[0] - assert next_node.op_type in FAKE_QUANTIZE_OP + assert next_node.op_type == 'QuantizeLinear' # Input idx2 is zero point. self.onnx_model.set_initializer(next_node.input[2], np.array([0], dtype=np.uint8), raw=False) self.onnx_model.remove_node_purely(node) next_node.input[0] = node.input[0] if node.op_type == 'Clip': next_node = self.onnx_model.get_tensor_consumer(node.output[0])[0] - assert next_node.op_type in FAKE_QUANTIZE_OP + assert next_node.op_type == 'QuantizeLinear' # Input idx2 is zero point. scale = self.onnx_model.get_initializer(next_node.input[1]) scale = min(scale, 6.0 / 255) @@ -252,15 +305,14 @@ def merge_relu_pass(self): next_node.input[0] = node.input[0] self.onnx_model.topologize_graph() - def format_qlinear_dtype_pass(self): + def format_qlinear_dtype_pass(self, qmin_max_dict): for node in self.onnx_model.graph.node: if node.op_type in FAKE_QUANTIZE_OP: - scale, zero_point, qmin, qmax = node.input[1], node.input[2], node.input[3], node.input[4] - qmin = self.onnx_model.get_constant(qmin) - qmax = self.onnx_model.get_constant(qmax) + scale, zero_point = node.input[1], node.input[2] + qmin, qmax = qmin_max_dict[node.name] assert qmax - qmin == 2 ** 8 - 1, "Only 8 bit quantization support deploy to QNN." 
scale_proto = self.onnx_model.initializer[scale][0] - if scale_proto.raw_data != b'' and scale_proto.dims[0] == 1: + if scale_proto.raw_data != b'' and scale_proto.dims == []: scale_data = self.onnx_model.get_initializer(scale) self.onnx_model.set_initializer(scale, scale_data.astype(np.float32), raw=False) zero_point_proto = self.onnx_model.initializer[zero_point][0] @@ -269,16 +321,17 @@ def format_qlinear_dtype_pass(self): zero_point_data = (zero_point_data - qmin).reshape((1,)) self.onnx_model.set_initializer(zero_point, zero_point_data.astype(np.uint8), raw=False) - def run(self, model_name): - self.format_qlinear_dtype_pass() + + def run(self, model_name, qmin_max_dict): + self.format_qlinear_dtype_pass(qmin_max_dict) self.merge_relu_pass() self.replace_op_pass() self.replace_qlinear_layer_pass() - self.onnx_model.optimize_model() + # self.onnx_model.optimize_model() self.onnx_model.set_opset_version('com.microsoft', 1) - try: - onnx.checker.check_model(self.onnx_model.model) - except onnx.checker.ValidationError as e: - logger.critical('The model is invalid: %s' % e) + # try: + # onnx.checker.check_model(self.onnx_model.model) + # except onnx.checker.ValidationError as e: + # logger.critical('The model is invalid: %s' % e) self.onnx_model.save_onnx_model('{}.onnx'.format(model_name)) diff --git a/mqbench/deploy/deploy_openvino.py b/mqbench/deploy/deploy_openvino.py index bec664cf..ceb3e085 100644 --- a/mqbench/deploy/deploy_openvino.py +++ b/mqbench/deploy/deploy_openvino.py @@ -16,34 +16,33 @@ parse_attrs ) -PERCHANNEL_FAKEQUANTIZER = ['FakeQuantizeLearnablePerchannelAffine', - 'FixedPerChannelAffine', - 'FakeQuantizeDSQPerchannel'] -PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', - 'FixedPerTensorAffine', - 'FakeQuantizeDSQPertensor', - 'FakeQuantizeTqtAffine'] -ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER +ALL_FAKEQUANTIZER = ['QuantizeLinear', 'DequantizeLinear'] +def get_dequant_node(node, inp2node): + dequant_node = inp2node[node.output[0]][0][0] + if dequant_node.op_type == 'Clip': + dequant_node = inp2node[dequant_node.output[0]][0][0] + assert dequant_node.op_type == 'DequantizeLinear', "This is not correct fakequant node!" 
+ return dequant_node class OPENVINO_process(object): def parse_qparams(self, node, name2data): tensor_name, scale, zero_point = node.input[:3] scale, zero_point = name2data[scale], name2data[zero_point] - if len(node.input) > 3: - qmin, qmax = node.input[-2:] - qmin, qmax = name2data[qmin], name2data[qmax] - elif len(node.attribute) > 0: - qparams = parse_attrs(node.attribute) - qmin = qparams['quant_min'] - qmax = qparams['quant_max'] - else: - logger.info(f'qmin and qmax are not found for <{node.name}>!') - qmax = qmin = None - return tensor_name, scale, zero_point, qmin, qmax - - def replace_fakequantize_and_collect_params(self, onnx_path, model_name): + # if len(node.input) > 3: + # qmin, qmax = node.input[-2:] + # qmin, qmax = name2data[qmin], name2data[qmax] + # elif len(node.attribute) > 0: + # qparams = parse_attrs(node.attribute) + # qmin = qparams['quant_min'] + # qmax = qparams['quant_max'] + # else: + # logger.info(f'qmin and qmax are not found for <{node.name}>!') + # qmax = qmin = None + return tensor_name, scale, zero_point + + def replace_fakequantize_and_collect_params(self, onnx_path, model_name, qmin_max_dict): onnx_graph = ONNXGraph(onnx_path) model = onnx_graph.model graph = model.graph @@ -59,10 +58,15 @@ def replace_fakequantize_and_collect_params(self, onnx_path, model_name): insert_initializer_names = set() for node in graph.node: if node.op_type in ALL_FAKEQUANTIZER: + next_node = inp2node[node.output[0]][0][0] + if next_node.op_type == 'Clip' and inp2node[next_node.output[0]][0][0].op_type == 'DequantizeLinear': + nodes_to_be_removed.append(next_node) nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) - - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + if node.op_type == 'QuantizeLinear': + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_node(node, inp2node) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) qmax = int(qmax) qmin = int(qmin) levels = qmax - qmin + 1 @@ -71,7 +75,7 @@ def replace_fakequantize_and_collect_params(self, onnx_path, model_name): levels = 256 qmax = qmax * 2 + 1 qmin = qmin * 2 - output_name = node.output[0] + output_name = dequant_node.output[0] # Create a node (FakeQuantize) fakeq_inputnames = [item % tensor_name for item in ['input_min_%s', 'input_max_%s', 'output_min_%s', 'output_max_%s']] node_def = helper.make_node( @@ -93,7 +97,7 @@ def replace_fakequantize_and_collect_params(self, onnx_path, model_name): input_low_size = input_low.size try: - next_node = inp2node[node.output[0]][0][0] + next_node = inp2node[dequant_node.output[0]][0][0] # node for save weights fake_node = out2node[next_node.input[1]] tensor = name2data[fake_node.input[0]] @@ -116,7 +120,8 @@ def replace_fakequantize_and_collect_params(self, onnx_path, model_name): graph.initializer.append(initializer) for node in nodes_to_be_removed: - graph.node.remove(node) + if node in graph.node: + graph.node.remove(node) graph.node.extend(node_defs) onnx_graph.topologize_graph() onnx_graph.prepare_initializer() diff --git a/mqbench/deploy/deploy_stpu.py b/mqbench/deploy/deploy_stpu.py index e44d975a..a3dd0f01 100644 --- a/mqbench/deploy/deploy_stpu.py +++ b/mqbench/deploy/deploy_stpu.py @@ -7,15 +7,21 @@ from mqbench.deploy.common import (get_constant_inputs, prepare_data, prepare_initializer, insert_initializer, - update_inp2node_out2node) + update_inp2node_out2node, parse_attrs) from mqbench.deploy.deploy_linear import (PERTENSOR_FAKEQUANTIZER, 
LinearQuantizer_process) from mqbench.utils.logger import logger - +ALL_FAKEQUANTIZER = ['QuantizeLinear', 'DequantizeLinear'] +def get_dequant_ndoe(node, inp2node): + dequant_node = inp2node[node.output[0]][0][0] + if dequant_node.op_type == 'Clip': + dequant_node = inp2node[dequant_node.output[0]][0][0] + assert dequant_node.op_type == 'DequantizeLinear', "This is not correct fakequant node!" + return dequant_node class STPU_process(LinearQuantizer_process): - def remove_fakequantize_and_collect_params(self, onnx_path, model_name): + def remove_fakequantize_and_collect_params(self, onnx_path, model_name, qmin_max_dict): model = onnx.load(onnx_path) graph = model.graph name2data = prepare_data(graph) @@ -25,21 +31,28 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): quant_params = OrderedDict() nodes_to_be_removed = [] for node in graph.node: - if node.op_type in PERTENSOR_FAKEQUANTIZER: + if node.op_type in ALL_FAKEQUANTIZER: + next_node = inp2node[node.output[0]][0][0] + if next_node.op_type == 'Clip' and inp2node[next_node.output[0]][0][0].op_type == 'DequantizeLinear': + nodes_to_be_removed.append(next_node) nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) - - if node.output[0] not in inp2node: - assert node.output[0] in [x.name for x in graph.output] - inp2node[node.output[0]] = [] - - next_nodes = inp2node[node.output[0]] + if node.op_type == 'QuantizeLinear' and 'axis' not in parse_attrs(node.attribute): + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_ndoe(node, inp2node) + next_node = inp2node[node.output[0]][0][0] + if dequant_node.output[0] not in inp2node: + assert dequant_node.output[0] in [x.name for x in graph.output] + inp2node[dequant_node.output[0]] = [] + + next_nodes = inp2node[dequant_node.output[0]] if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for weights - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, + named_initializer) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) # [-127 * scale, 127 * scale] quant_params[next_nodes[0][0].name + '_weights'] = { "min": -127 * scale, @@ -47,10 +60,10 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): } else: # fake quantize for activations - self.deal_with_activation_fakequant(node, inp2node) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + self.deal_with_activation_fakequant(node, dequant_node, inp2node) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) for out in graph.output: - if out.name == node.output[0]: + if out.name == dequant_node.output[0]: out.name = tensor_name quant_params[tensor_name] = { "min": -127 * scale, @@ -107,7 +120,8 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): self.update_emin(node, quant_params, named_initializer) # Delete node and init. 
for node in nodes_to_be_removed: - graph.node.remove(node) + if node in graph.node: + graph.node.remove(node) named_initializer = prepare_initializer(graph) for name, initial_data in named_initializer.items(): if name in (out2node.keys() | inp2node.keys()): diff --git a/mqbench/deploy/deploy_tengine.py b/mqbench/deploy/deploy_tengine.py index 1f63dda5..768a4de5 100644 --- a/mqbench/deploy/deploy_tengine.py +++ b/mqbench/deploy/deploy_tengine.py @@ -13,7 +13,8 @@ prepare_initializer, prepare_data, OnnxPreprocess, - get_constant_inputs + get_constant_inputs, + parse_attrs ) import onnx @@ -24,6 +25,12 @@ logger.warn('onnxsim not found, if you want to use deploy_tengine, please install it.') +def get_dequant_node(node, inp2node): + dequant_node = inp2node[node.output[0]][0][0] + if dequant_node.op_type == 'Clip': + dequant_node = inp2node[dequant_node.output[0]][0][0] + assert dequant_node.op_type == 'DequantizeLinear', "This is not correct fakequant node!" + return dequant_node class Tengine_process(LinearQuantizer_process): @@ -31,7 +38,7 @@ class Tengine_process(LinearQuantizer_process): def get_constant(node: onnx.NodeProto): return numpy_helper.to_array(node.attribute[0].t).tolist() - def remove_fakequantize_and_collect_params(self, onnx_path, model_name): + def remove_fakequantize_and_collect_params(self, onnx_path, model_name, qmin_max_dict): model = onnx.load(onnx_path) graph = model.graph out2node, inp2node = update_inp2node_out2node(graph) @@ -46,32 +53,39 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): nodes_to_be_removed = [] for node in graph.node: if node.op_type in ALL_FAKEQUANTIZER: + next_node = inp2node[node.output[0]][0][0] + if next_node.op_type == 'Clip' and inp2node[next_node.output[0]][0][0].op_type == 'DequantizeLinear': + nodes_to_be_removed.append(next_node) nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) - if node.op_type in PERCHANNEL_FAKEQUANTIZER: + if node.op_type == 'QuantizeLinear' and 'axis' in parse_attrs(node.attribute): # fake quantize for weights, suppose per-channel quantize only for weight - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_node(node, inp2node) + redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, named_initializer) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) - elif node.op_type in PERTENSOR_FAKEQUANTIZER: - if node.output[0] not in inp2node: - assert node.output[0] in [x.name for x in graph.output] - inp2node[node.output[0]] = [] - - next_nodes = inp2node[node.output[0]] + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) + elif node.op_type == 'QuantizeLinear' and 'axis' not in parse_attrs(node.attribute): + qmin, qmax = qmin_max_dict[node.name] + dequant_node = get_dequant_node(node, inp2node) + if dequant_node.output[0] not in inp2node: + assert dequant_node.output[0] in [x.name for x in graph.output] + inp2node[dequant_node.output[0]] = [] + + next_nodes = inp2node[dequant_node.output[0]] if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for weights - redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + 
redundant_nodes = self.deal_with_weight_fakequant(node, dequant_node, out2node, inp2node, named_initializer) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) nodes_to_be_removed.extend(redundant_nodes) - self.clip_weight(node, name2data, inp2node, named_initializer) + self.clip_weight(node, dequant_node, name2data, inp2node, named_initializer, qmin, qmax) else: # fake quantize for activations - self.deal_with_activation_fakequant(node, inp2node) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + self.deal_with_activation_fakequant(node, dequant_node, inp2node) + tensor_name, scale, zero_point = self.parse_qparams(node, name2data) for out in graph.output: - if out.name == node.output[0]: + if out.name == dequant_node.output[0]: out.name = tensor_name quant_params[tensor_name] = [ @@ -98,7 +112,8 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name): quant_params[conv_tensor_name] = quant_params[tensor_name] for node in nodes_to_be_removed: - graph.node.remove(node) + if node in graph.node: + graph.node.remove(node) named_initializer = prepare_initializer(graph) for name, initial_data in named_initializer.items(): if name in (out2node.keys() | inp2node.keys()): diff --git a/mqbench/fake_quantize/dorefa.py b/mqbench/fake_quantize/dorefa.py index d570edd5..61c4b4af 100644 --- a/mqbench/fake_quantize/dorefa.py +++ b/mqbench/fake_quantize/dorefa.py @@ -3,7 +3,7 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase -_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 +_version_under_1100 = int(torch.__version__.split('.')[0]) == 1 and int(torch.__version__.split('.')[1]) < 10 class DoReFaFakeQuantize(QuantizeBase): def __init__(self, observer, **observer_kwargs): diff --git a/mqbench/fake_quantize/dsq.py b/mqbench/fake_quantize/dsq.py index 316f1be6..991e5c22 100644 --- a/mqbench/fake_quantize/dsq.py +++ b/mqbench/fake_quantize/dsq.py @@ -5,8 +5,13 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase from mqbench.utils import is_tracing_state from mqbench.utils.hook import PerChannelLoadHook - - +import torch._C._onnx as _C_onnx +from torch.onnx import _type_utils +from torch.onnx import ( + _type_utils, + symbolic_helper, + symbolic_opset9 as opset9, +) def dsq_function_per_tensor(x, scale, zero_point, quant_min, quant_max, alpha): tanh_scale = 1 / (1 - alpha) tanh_k = math.log((tanh_scale + 1) / (tanh_scale - 1)) @@ -84,7 +89,19 @@ def forward(ctx, x, scale, zero_point, quant_min, quant_max, ch_axis, alpha): @staticmethod def symbolic(g, x, scale, zero_point, quant_min, quant_max, ch_axis, alpha): - return g.op("::FakeQuantizeDSQPerchannel", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max, alpha_f=alpha) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", x, scale, zero_point, axis_i=ch_axis) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=ch_axis) class FakeQuantizeDSQPertensor(torch.autograd.Function): @@ -94,4 +111,21 @@ def forward(ctx, x, scale, zero_point, quant_min, quant_max, alpha): @staticmethod def symbolic(g, x, scale, zero_point, quant_min, quant_max, alpha): - return 
g.op("::FakeQuantizeDSQPertensor", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max, alpha_f=alpha) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", x, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) \ No newline at end of file diff --git a/mqbench/fake_quantize/fixed.py b/mqbench/fake_quantize/fixed.py index 1fd2ae2b..1ceb2414 100644 --- a/mqbench/fake_quantize/fixed.py +++ b/mqbench/fake_quantize/fixed.py @@ -2,10 +2,7 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase from mqbench.utils.hook import PerChannelLoadHook - - -_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 - +_version_under_1100 = int(torch.__version__.split('.')[0]) == 1 and int(torch.__version__.split('.')[1]) < 10 class FixedFakeQuantize(QuantizeBase): """This is actually torch.quantization.FakeQuantize. """ diff --git a/mqbench/fake_quantize/lsq.py b/mqbench/fake_quantize/lsq.py index c133b0d6..198585bc 100644 --- a/mqbench/fake_quantize/lsq.py +++ b/mqbench/fake_quantize/lsq.py @@ -4,7 +4,12 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase from mqbench.utils import is_symmetric_quant, is_tracing_state from mqbench.utils.hook import PerChannelLoadHook - +from torch.onnx import ( + _type_utils, + symbolic_helper, + symbolic_opset9 as opset9, +) +import torch._C._onnx as _C_onnx class LearnableFakeQuantize(QuantizeBase): r""" This is an extension of the FakeQuantize module in fake_quantize.py, which @@ -64,14 +69,16 @@ def forward(self, X): grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * self.quant_max) ** 0.5 else: grad_factor = 1.0 - if is_tracing_state(): - X = FakeQuantizeLearnablePerchannelAffine.apply( - X, self.scale, self.zero_point, self.ch_axis, - self.quant_min, self.quant_max, grad_factor) - else: - X = _fake_quantize_learnable_per_channel_affine_training( - X, self.scale, self.zero_point, self.ch_axis, + X = torch._fake_quantize_learnable_per_channel_affine(X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max, grad_factor) + # if is_tracing_state(): + # X = FakeQuantizeLearnablePerchannelAffine.apply( + # X, self.scale, self.zero_point, self.ch_axis, + # self.quant_min, self.quant_max, grad_factor) + # else: + # X = _fake_quantize_learnable_per_channel_affine_training( + # X, self.scale, self.zero_point, self.ch_axis, + # self.quant_min, self.quant_max, grad_factor) else: if self.use_grad_scaling: grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 @@ -99,12 +106,24 @@ def grad_scale(t, scale): return (t - (t * scale)).detach() + (t * scale) -class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function): - @staticmethod - def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): - return _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis, - quant_min, quant_max, grad_factor) - - @staticmethod - def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, 
grad_factor): - return g.op("::FakeQuantizeLearnablePerchannelAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max) +# class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function): +# @staticmethod +# def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): +# return _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis, +# quant_min, quant_max, grad_factor) +# +# @staticmethod +# def symbolic(g, inputs, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): +# if quant_min == 0: +# zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) +# else: +# zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) +# quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=ch_axis) +# if (quant_min, quant_max) == (0, 127): +# quantized = g.op( +# "Clip", +# quantized, +# opset9.unused(g), +# g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), +# ) +# return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=ch_axis) \ No newline at end of file
diff --git a/mqbench/fake_quantize/nnie.py b/mqbench/fake_quantize/nnie.py index d05c18d3..60d4f17e 100644 --- a/mqbench/fake_quantize/nnie.py +++ b/mqbench/fake_quantize/nnie.py @@ -2,8 +2,10 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase from mqbench.utils import no_jit_trace - - +from torch.onnx import ( + symbolic_helper, +) +from torch.onnx import register_custom_op_symbolic class NNIEFakeQuantize(QuantizeBase): def __init__(self, observer, **observer_kwargs): super(NNIEFakeQuantize, self).__init__(observer, **observer_kwargs) @@ -40,5 +42,7 @@ def backward(ctx, grad_output): return grad_input, None @staticmethod + @symbolic_helper.parse_args("v", "f") def symbolic(g, x, data_max): - return g.op("::NNIEQuantize", x, data_max) \ No newline at end of file + quantized = g.op("QuantizeLinear", x, data_max_f=data_max) + return g.op("DequantizeLinear", quantized, data_max_f=data_max) \ No newline at end of file
diff --git a/mqbench/fake_quantize/tqt.py b/mqbench/fake_quantize/tqt.py index 260563ad..45c27224 100644 --- a/mqbench/fake_quantize/tqt.py +++ b/mqbench/fake_quantize/tqt.py @@ -2,7 +2,9 @@ from mqbench.fake_quantize.quantize_base import QuantizeBase from mqbench.utils import is_symmetric_quant - +import torch._C._onnx as _C_onnx +from torch.onnx import _type_utils +from torch.onnx import symbolic_opset9 as opset9 class TqtFakeQuantize(QuantizeBase): def __init__(self, observer, scale=1., zero_point=0., **observer_kwargs): @@ -114,4 +116,21 @@ def backward(ctx, grad_outputs): @staticmethod def symbolic(g, x, scale, zero_point, quant_min, quant_max, mth): - return g.op("::FakeQuantizeTqtAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", x, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) diff --git a/mqbench/fuser_method_mappings.py
b/mqbench/fuser_method_mappings.py index d4baf662..57387cf1 100644 --- a/mqbench/fuser_method_mappings.py +++ b/mqbench/fuser_method_mappings.py @@ -1,35 +1,94 @@ import torch import torch.nn as nn -from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion +# from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion from torch.quantization.fx.quantization_types import QuantizerCls from torch.fx.graph import Node +from collections import namedtuple import mqbench.nn as qnn import mqbench.nn.intrinsic as qnni import mqbench.nn.intrinsic.qat as qnniqat from mqbench.utils.fusion import fuse_deconv_bn_eval from mqbench.nn.modules import FrozenBatchNorm2d +from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler +from torch.ao.quantization.backend_config import ( + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn.functional as F -class ConvExtendBnReLUFusion(ConvBNReLUFusion): - def __init__(self, quantizer: QuantizerCls, node: Node): - super(ConvBNReLUFusion, self).__init__(quantizer, node) - self.relu_node = None - self.bn_node = None - if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ - (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU): - self.relu_node = node - assert isinstance(node.args[0], Node) - node = node.args[0] - assert node.op == 'call_module' - if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, FrozenBatchNorm2d]: - self.bn_node = node - self.bn = quantizer.modules[self.bn_node.target] - assert isinstance(node.args[0], Node) - node = node.args[0] - assert node.op == 'call_module' - self.conv_node = node - self.conv = quantizer.modules[self.conv_node.target] + + +def _get_custom_conv_configs(dtype_configs): + """ + Return all configs related to conv modules and ops. 
+ """ + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + # 1 conv transpose + bn + relu + conv_configs.append( + BackendPatternConfig((nn.ConvTranspose2d, nn.BatchNorm2d, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_deconv_bn_relu) + .set_fused_module(qnni.ConvTransposeBnReLU2d) + ) + + conv_configs.append( + BackendPatternConfig(qnni.ConvTransposeBnReLU2d) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_observation_type(observation_type) + .set_root_module(nn.ConvTranspose2d) + .set_reference_quantized_module(nnqr.ConvTranspose2d) + .set_qat_module(qnniqat.ConvTransposeBnReLU2d) + ) + # 2 conv transpose + bn + conv_configs.append( + BackendPatternConfig((nn.ConvTranspose2d, nn.BatchNorm2d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_deconv_bn) + .set_fused_module(qnni.ConvTransposeBn2d) + ) + conv_configs.append( + BackendPatternConfig(qnni.ConvTransposeBn2d) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_observation_type(observation_type) + .set_root_module(nn.ConvTranspose2d) + .set_reference_quantized_module(nnqr.ConvTranspose2d) + .set_qat_module(qnniqat.ConvTransposeBn2d) + ) + # 3 conv transpose + conv_configs.append( + BackendPatternConfig(nn.ConvTranspose2d) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_observation_type(observation_type) + .set_root_module(nn.ConvTranspose2d) + .set_reference_quantized_module(nnqr.ConvTranspose2d) + .set_qat_module(qnn.qat.ConvTranspose2d) + ) + # 4 linear bn + conv_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_linear_bn) + .set_fused_module(qnni.LinearBn1d) + ) + conv_configs.append( + BackendPatternConfig(qnni.LinearBn1d) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_observation_type(observation_type) + .set_root_module(nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(qnniqat.LinearBn1d) + ) + + return conv_configs def fuse_linear_bn(linear, bn): @@ -56,7 +115,7 @@ def fuse_linear_bn(linear, bn): return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) -def fuse_deconv_bn(deconv, bn): +def fuse_deconv_bn(is_qat, deconv, bn): assert deconv.training == bn.training, \ 'DeConv and BN must be in the same mode (train or eval)' @@ -69,7 +128,7 @@ def fuse_deconv_bn(deconv, bn): return fuse_deconv_bn_eval(deconv, bn) -def fuse_deconv_bn_relu(deconv, bn, relu): +def fuse_deconv_bn_relu(is_qat, deconv, bn, relu): assert deconv.training == bn.training == relu.training, \ "DeConv and BN both must be in the same mode (train or eval)." @@ -82,7 +141,7 @@ def fuse_deconv_bn_relu(deconv, bn, relu): return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu) -def fuse_conv_freezebn(conv, bn): +def fuse_conv_freezebn(is_qat, conv, bn): assert bn.training is False, "Freezebn must be eval." if conv.training: @@ -94,7 +153,7 @@ def fuse_conv_freezebn(conv, bn): return nn.utils.fuse_conv_bn_eval(conv, bn) -def fuse_conv_freezebn_relu(conv, bn, relu): +def fuse_conv_freezebn_relu(is_qat, conv, bn, relu): assert conv.training == relu.training and bn.training is False, \ "Conv and relu both must be in the same mode (train or eval) and bn must be eval." 
@@ -108,7 +167,7 @@ def fuse_conv_freezebn_relu(conv, bn, relu): return nn.intrinsic.ConvReLU2d(fused_conv, relu) -def fuse_deconv_freezebn(deconv, bn): +def fuse_deconv_freezebn(is_qat, deconv, bn): assert bn.training is False, "Freezebn must be eval." if deconv.training: @@ -120,7 +179,7 @@ def fuse_deconv_freezebn(deconv, bn): return fuse_deconv_bn_eval(deconv, bn) -def fuse_deconv_freezebn_relu(deconv, bn, relu): +def fuse_deconv_freezebn_relu(is_qat, deconv, bn, relu): assert deconv.training == relu.training and bn.training is False, \ "Conv and relu both must be in the same mode (train or eval) and bn must be eval." @@ -133,77 +192,4 @@ def fuse_deconv_freezebn_relu(deconv, bn, relu): return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu) -fuse_custom_config_dict = { - "additional_fuser_method_mapping": { - (torch.nn.Linear, torch.nn.BatchNorm1d): fuse_linear_bn, - (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d): fuse_deconv_bn, - (torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_deconv_bn_relu, - (torch.nn.ConvTranspose2d, torch.nn.ReLU): qnni.ConvTransposeReLU2d, - (nn.Conv2d, FrozenBatchNorm2d, nn.ReLU): fuse_conv_freezebn_relu, - (nn.Conv2d, FrozenBatchNorm2d): fuse_conv_freezebn, - (nn.ConvTranspose2d, FrozenBatchNorm2d, nn.ReLU): fuse_deconv_freezebn_relu, - (nn.ConvTranspose2d, FrozenBatchNorm2d): fuse_deconv_freezebn, - }, - "additional_fusion_pattern": { - (torch.nn.BatchNorm1d, torch.nn.Linear): - ConvBNReLUFusion, - (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d): - ConvBNReLUFusion, - (torch.nn.ReLU, torch.nn.ConvTranspose2d): - ConvBNReLUFusion, - (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)): - ConvBNReLUFusion, - (torch.nn.functional.relu, torch.nn.ConvTranspose2d): - ConvBNReLUFusion, - (torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)): - ConvBNReLUFusion, - (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)): - ConvExtendBnReLUFusion, - (FrozenBatchNorm2d, torch.nn.Conv2d): - ConvExtendBnReLUFusion, - (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)): - ConvExtendBnReLUFusion, - (FrozenBatchNorm2d, torch.nn.ConvTranspose2d): - ConvExtendBnReLUFusion, - }, - "additional_qat_module_mappings": { - nn.ConvTranspose2d: qnn.qat.ConvTranspose2d, - qnni.LinearBn1d: qnniqat.LinearBn1d, - qnni.ConvTransposeBn2d: qnniqat.ConvTransposeBn2d, - qnni.ConvTransposeReLU2d: qnniqat.ConvTransposeReLU2d, - qnni.ConvTransposeBnReLU2d: qnniqat.ConvTransposeBnReLU2d, - qnni.ConvFreezebn2d: qnniqat.ConvFreezebn2d, - qnni.ConvFreezebnReLU2d: qnniqat.ConvFreezebnReLU2d, - qnni.ConvTransposeFreezebn2d: qnniqat.ConvTransposeFreezebn2d, - qnni.ConvTransposeFreezebnReLU2d: qnniqat.ConvTransposeFreezebnReLU2d, - nn.Embedding: qnn.qat.Embedding, - }, -} - - -def _sort_fusion_patterns(pats): - keys = [] - for key in pats.keys(): - if pats[key] is ModuleReLUFusion: - keys.append(key) - for key in keys: - pats.move_to_end(key) - - -# Sinse additional_fuser_method_mapping will not be set because fuser.py:54 -# do not pass this dict. -from torch.quantization.fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD -from torch.quantization.fx.pattern_utils import DEFAULT_FUSION_PATTERNS -from torch.quantization.quantization_mappings import DEFAULT_QAT_MODULE_MAPPINGS - -DEFAULT_OP_LIST_TO_FUSER_METHOD.update( - fuse_custom_config_dict['additional_fuser_method_mapping']) -DEFAULT_FUSION_PATTERNS.update( - fuse_custom_config_dict['additional_fusion_pattern']) -# Make longer matched pattern prior. -# i.e. 
Conv + BN + Relu should match ConvBnRelu before BNRelu. -# Any thing registered in class ConvBNReLUFusion should be -# proir than class ModuleReLUFusion. -_sort_fusion_patterns(DEFAULT_FUSION_PATTERNS) -DEFAULT_QAT_MODULE_MAPPINGS.update( - fuse_custom_config_dict['additional_qat_module_mappings']) + diff --git a/mqbench/fusion_method.py b/mqbench/fusion_method.py index 2bbf693c..09e99310 100644 --- a/mqbench/fusion_method.py +++ b/mqbench/fusion_method.py @@ -1,13 +1,13 @@ import torch import torch.nn.intrinsic.qat as nniqat from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval -from torch.quantization.fx.utils import _parent_name +from torch.ao.quantization.utils import _parent_name import mqbench.nn.intrinsic as qnni import mqbench.nn.intrinsic.qat as qnniqat import mqbench.nn.qat as qnnqat from mqbench.utils.registry import register_convert_function -from mqbench.fuser_method_mappings import fuse_deconv_bn_eval +from mqbench.utils.fusion import fuse_deconv_bn_eval from mqbench.quantization.default_bias_fake_quant import bias_fake_quantizer @@ -30,7 +30,7 @@ def convert_qnniqat_linearbn(model, fused_node): if fused_module.bias is not None: linear.bias = fused_module.bias # Merge Linear + BN - fused_linear = fuse_linear_bn_eval(linear.eval(), fused_module.bn) + fused_linear = fuse_linear_bn_eval(linear.eval(), fused_module.bn.eval()) # We need nn.qat.linear here to export weight quantize node. linear.qconfig = fused_module.qconfig linear = torch.nn.qat.Linear.from_float(linear) @@ -69,7 +69,7 @@ def convert_nniqat_convbn(model, fused_node): conv.weight = fused_module.weight if fused_module.bias is not None: conv.bias = fused_module.bias - fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn) + fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn.eval()) # We need nn.qat.conv here to export weight quantize node. fused_conv.qconfig = fused_module.qconfig fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv) @@ -146,7 +146,7 @@ def convert_qnniqat_deconvbn(model, fused_node): deconv.weight = fused_module.weight if fused_module.bias is not None: deconv.bias = fused_module.bias - fused_deconv = fuse_deconv_bn_eval(deconv.eval(), fused_module.bn) + fused_deconv = fuse_deconv_bn_eval(deconv.eval(), fused_module.bn.eval()) # We need nn.qat.conv here to export weight quantize node. fused_deconv.qconfig = fused_module.qconfig fused_deconv = qnnqat.ConvTranspose2d.from_float(fused_deconv) @@ -226,7 +226,7 @@ def convert_qnniqat_convbn(model, fused_node): conv.weight = fused_module.weight if fused_module.bias is not None: conv.bias = fused_module.bias - fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn) + fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn.eval()) # We need nn.qat.conv here to export weight quantize node. 
fused_conv.qconfig = fused_module.qconfig fused_conv = qnnqat.Conv2d.from_float(fused_conv) diff --git a/mqbench/observer.py b/mqbench/observer.py index 74c316f3..31888efb 100644 --- a/mqbench/observer.py +++ b/mqbench/observer.py @@ -123,7 +123,7 @@ def forward(self, x_orig): return x_orig x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) else: x_dim = x.size() new_axis_list = [i for i in range(len(x_dim))] @@ -131,7 +131,7 @@ def forward(self, x_orig): new_axis_list[0] = self.ch_axis y = x.permute(new_axis_list) y = torch.flatten(y, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = torch.aminmax(y, dim = 1) self.min_val = torch.min(self.min_val, min_val_cur) self.max_val = torch.max(self.max_val, max_val_cur) @@ -162,10 +162,10 @@ def forward(self, x_orig): return x_orig x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) else: logger.warn('The per-tensor observer does not support per-channel min-max!') - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) self.min_val = min_val_cur self.max_val = max_val_cur @@ -232,7 +232,7 @@ def forward(self, x_orig): return x_orig x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) else: x_dim = x.size() new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 @@ -240,7 +240,7 @@ def forward(self, x_orig): new_axis_list[0] = self.ch_axis y = x.permute(new_axis_list) y = torch.flatten(y, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) if self.max_val.numel() <= 1 and self.max_val.isinf(): self.min_val = min_val_cur @@ -270,10 +270,10 @@ def forward(self, x_orig): return x_orig x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) else: logger.warn('The per-tensor observer does not support per-channel min-max!') - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) self.min_val = min_val_cur self.max_val = max_val_cur @@ -344,7 +344,7 @@ def forward(self, x_orig): if x_orig.numel() == 0: return x_orig x = x_orig.to(self.min_val.dtype) - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) max_hist_range = torch.max(-min_val_cur, max_val_cur) hist = torch.histc(torch.abs(x), bins=self.bins, min=0., max=max_hist_range) cur_total = 0 @@ -381,7 +381,7 @@ def forward(self, x_orig): return x_orig x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) mean = x.mean() std = x.std() else: @@ -391,7 +391,7 @@ def forward(self, x_orig): new_axis_list[0] = self.ch_axis y = x.permute(new_axis_list) y = torch.flatten(y, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) mean = y.mean(1) std = y.std(1) @@ -422,7 +422,7 @@ def forward(self, x_orig): x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: self.tensor_norm = x.abs().mean() - self.min_val, self.max_val = torch._aminmax(x) + self.min_val, self.max_val = torch.aminmax(x) else: # compute channel-wise mean x_dim = x.size() @@ -432,7 +432,7 
@@ def forward(self, x_orig): y = x.permute(new_axis_list) y = torch.flatten(y, start_dim=1) self.tensor_norm = y.abs().mean(1) - self.min_val, self.max_val = torch._aminmax(y, 1) + self.min_val, self.max_val = torch.aminmax(y, dim=1) return x @@ -468,7 +468,7 @@ def forward(self, x_orig): if self.ch_axis == -1: self.mean = x.mean() self.std = x.std() - self.min_val, self.max_val = torch._aminmax(x) + self.min_val, self.max_val = torch.aminmax(x) else: # compute channel-wise mean x_dim = x.size() @@ -479,7 +479,7 @@ def forward(self, x_orig): y = torch.flatten(y, start_dim=1) self.mean = y.mean(1) self.std = y.std(1) - self.min_val, self.max_val = torch._aminmax(y) + self.min_val, self.max_val = torch.aminmax(y) return x @@ -487,10 +487,11 @@ def calculate_qparams(self): scale = torch.maximum((self.mean - 3 * self.std).abs(), (self.mean + 3 * self.std).abs()) / (self.quant_max - self.quant_min + 1) sync_tensor(scale) - sync_tensor(zero_point) + # sync_tensor(zero_point) if self.pot_scale: scale = pot_quantization(scale) zero_point = torch.zeros_like(self.mean) + sync_tensor(zero_point) if not is_symmetric_quant(self.qscheme): if self.min_val >= 0.: zero_point = self.quant_min - torch.round(self.min_val / scale) @@ -560,7 +561,7 @@ def forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse(x, min_val_cur, max_val_cur, iter=95) else: x_dim = x.size() @@ -569,7 +570,7 @@ def forward(self, x_orig): new_axis_list[0] = self.ch_axis x_channel = x.permute(new_axis_list) y = torch.flatten(x_channel, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) min_val_cur, max_val_cur = self.mse_perchannel(x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) self.min_val = torch.min(self.min_val, min_val_cur) @@ -639,7 +640,7 @@ def forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse(x, min_val_cur, max_val_cur, iter=95) else: x_dim = x.size() @@ -648,7 +649,7 @@ def forward(self, x_orig): new_axis_list[0] = self.ch_axis x_channel = x.permute(new_axis_list) y = torch.flatten(x_channel, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) min_val_cur, max_val_cur = self.mse_perchannel(x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) if self.max_val.numel() <= 1 and self.max_val.isinf(): diff --git a/mqbench/prepare_by_platform.py b/mqbench/prepare_by_platform.py index ccd0d5be..c7b1f0e1 100644 --- a/mqbench/prepare_by_platform.py +++ b/mqbench/prepare_by_platform.py @@ -9,7 +9,7 @@ from torch.fx.graph_module import GraphModule from torch.quantization.quantize_fx import _swap_ff_with_fxff from torch.quantization import QConfig - +from mqbench.fuser_method_mappings import _get_custom_conv_configs from mqbench.fake_quantize import ( LearnableFakeQuantize, @@ -33,13 +33,25 @@ MSEObserver, EMAMSEObserver, ) -from mqbench.fuser_method_mappings import fuse_custom_config_dict +# from mqbench.fuser_method_mappings import fuse_custom_config_dict from mqbench.utils.logger import logger from mqbench.utils.registry import DEFAULT_MODEL_QUANTIZER from mqbench.scheme import QuantizeScheme - +from 
torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +get_tensorrt_backend_config, +DTypeConfig +) +from torch.ao.quantization.backend_config.native import weighted_op_quint8_dtype_config __all__ = ['prepare_by_platform'] +weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) class BackendType(Enum): Academic = 'Academic' Tensorrt = 'Tensorrt' @@ -54,6 +66,7 @@ class BackendType(Enum): Tensorrt_NLP = "Tensorrt_NLP" Academic_NLP = "Academic_NLP" STPU = "STPU" + QDQ = "QDQ" ParamsTable = { @@ -129,6 +142,15 @@ class BackendType(Enum): default_act_quantize=FixedFakeQuantize, default_weight_observer=MinMaxObserver, default_act_observer=EMAMinMaxObserver), + BackendType.QDQ: dict(qtype='affine', # noqa: E241 + w_qscheme=QuantizeScheme(symmetry=True, per_channel=True, pot_scale=False, bit=8, + symmetric_range=True), + a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, + symmetric_range=True), + default_weight_quantize=LearnableFakeQuantize, + default_act_quantize=LearnableFakeQuantize, + default_weight_observer=MinMaxObserver, + default_act_observer=EMAMinMaxObserver), } ParamsTable[BackendType.Tensorrt_NLP] = ParamsTable[BackendType.Tensorrt] ParamsTable[BackendType.Academic_NLP] = ParamsTable[BackendType.Academic] @@ -341,8 +363,10 @@ def _get_attrs(target, attrs): def prepare_by_platform( model: torch.nn.Module, deploy_backend: BackendType, + is_qat: bool = False, prepare_custom_config_dict: Dict[str, Any] = {}, - custom_tracer: Tracer = None): + custom_tracer: Tracer = None, + freeze_bn: bool = True): """ Args: model (torch.nn.Module): @@ -367,7 +391,8 @@ def prepare_by_platform( # Get Qconfig extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {}) qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict) - + backend_config = get_native_backend_config() + backend_config.set_backend_pattern_configs(_get_custom_conv_configs(weighted_op_qint8_dtype_config)) _swap_ff_with_fxff(model) # Preserve attr. preserve_attr_dict = dict() @@ -396,12 +421,12 @@ def prepare_by_platform( graph_module = GraphModule(modules, graph, name) # Model fusion. extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {}) - extra_fuse_dict.update(fuse_custom_config_dict) + # extra_fuse_dict.update(fuse_custom_config_dict) # Prepare import mqbench.custom_quantizer # noqa: F401 extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {}) quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict) - prepared = quantizer.prepare(graph_module, qconfig) + prepared = quantizer.prepare(graph_module, qconfig, is_qat, backend_config, freeze_bn) # Restore attr. if 'preserve_attr' in prepare_custom_config_dict: for submodule_name in prepare_custom_config_dict['preserve_attr']: diff --git a/mqbench/quantization/qconfig_mapping_utils.py b/mqbench/quantization/qconfig_mapping_utils.py new file mode 100644 index 00000000..e0397e82 --- /dev/null +++ b/mqbench/quantization/qconfig_mapping_utils.py @@ -0,0 +1,36 @@ +def get_flattened_qconfig_dict(qconfig_dict): + """ flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. 
+ + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened = dict() + if '' in qconfig_dict: + flattened[''] = qconfig_dict[''] + + def flatten_key(key): + if key in qconfig_dict: + for (obj, qconfig) in qconfig_dict[key].items(): + flattened[obj] = qconfig + + flatten_key('object_type') + flatten_key('module_name') + return flattened \ No newline at end of file
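
To illustrate the helper above: the implementation iterates the inner mappings with .items(), so "object_type" and "module_name" are expected to be dicts rather than the lists of pairs shown in the docstring. A small usage sketch; the QConfig value and the module name are placeholders:

import torch
from torch.ao.quantization import get_default_qconfig
from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict

qconfig = get_default_qconfig("fbgemm")            # placeholder QConfig
qconfig_dict = {
    "": qconfig,                                   # global default
    "object_type": {torch.nn.Conv2d: qconfig},     # dict, since the helper calls .items()
    "module_name": {"backbone.conv1": qconfig},    # placeholder module name
}
flattened = get_flattened_qconfig_dict(qconfig_dict)
# flattened == {"": qconfig, torch.nn.Conv2d: qconfig, "backbone.conv1": qconfig}
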