update torch 2.5.1
yinnengzhong committed Jan 10, 2025
1 parent 09ff5cb commit 93084c8
Showing 26 changed files with 891 additions and 462 deletions.
101 changes: 66 additions & 35 deletions mqbench/convert_deploy.py
@@ -21,9 +21,13 @@
remove_fakequantize_and_collect_params_stpu,
ONNXQLinearPass, ONNXQNNPass
)

import onnx
from onnxsim import simplify
from mqbench.deploy.common import (
parse_attrs
)
__all__ = ['convert_deploy']

qmin_max_dict = {}
@register_deploy_function(BackendType.STPU)
@register_deploy_function(BackendType.Tengine_u8)
@register_deploy_function(BackendType.PPLCUDA)
@@ -34,6 +38,7 @@
@register_deploy_function(BackendType.NNIE)
@register_deploy_function(BackendType.Vitis)
@register_deploy_function(BackendType.OPENVINO)
@register_deploy_function(BackendType.QDQ)
def convert_merge_bn(model: GraphModule, **kwargs):
logger.info("Merge BN for deploy.")
nodes = list(model.graph.nodes)
@@ -57,6 +62,7 @@ def convert_merge_bn(model: GraphModule, **kwargs):
@register_deploy_function(BackendType.NNIE)
@register_deploy_function(BackendType.Vitis)
@register_deploy_function(BackendType.OPENVINO)
@register_deploy_function(BackendType.QDQ)
def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_path, **kwargs):
logger.info("Export to onnx.")
output_names = kwargs.get('output_names', [])
@@ -68,29 +74,54 @@ def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_path, **kwargs):
input_names = list(dummy_input.keys())
dummy_input = tuple(dummy_input.values())
# Per-channel QuantizeLinear and DequantizeLinear is supported since opset 13
opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 11
opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 13
with torch.no_grad():
# try:
# from torch.onnx.utils import ONNXCheckerError
# try:
torch.onnx.export(model, dummy_input, onnx_model_path,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
custom_opsets={'' : opset_version})


# except ONNXCheckerError:
# pass
# except ImportError:
# torch.onnx.export(model, dummy_input, onnx_model_path,
# input_names=input_names,
# output_names=output_names,
# opset_version=opset_version,
# do_constant_folding=True,
# custom_opsets={'' : opset_version},
# enable_onnx_checker=False)
onnx_model = onnx.load(onnx_model_path)
graph = onnx_model.graph
for node in graph.node:
if len(node.attribute) > 1:
qparams = parse_attrs(node.attribute)
if 'quant_max' in qparams:
qmin_max_dict[node.name] = (qparams['quant_min'], qparams['quant_max'])
new_attributes = []
for attr in node.attribute:
if attr.name not in ["quant_min", "quant_max"]:
new_attributes.append(attr)
node.ClearField("attribute")
node.attribute.extend(new_attributes)
onnx.save(onnx_model, onnx_model_path)
try:
from torch.onnx.utils import ONNXCheckerError
try:
torch.onnx.export(model, dummy_input, onnx_model_path,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
custom_opsets={'' : opset_version})
except ONNXCheckerError:
pass
except ImportError:
torch.onnx.export(model, dummy_input, onnx_model_path,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
do_constant_folding=True,
custom_opsets={'' : opset_version},
enable_onnx_checker=False)

logger.info("simplify model.")
onnx_model = onnx.load(onnx_model_path)
onnx_model_simplified, check = simplify(onnx_model)
onnx.save(onnx_model_simplified, onnx_model_path)
except Exception as e:
logger.info("simplify model fail.")
# onnx.checker.check_model(onnx_model_simplified)
# import onnxruntime as ort
# session = ort.InferenceSession(onnx_model_path)
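
For context on the attribute-stripping step added above: parse_attrs, imported from mqbench.deploy.common, is assumed to flatten a node's AttributeProto list into a plain dict so that quant_min/quant_max can be cached in qmin_max_dict before those attributes are removed ahead of onnx-simplifier. A minimal sketch of that assumed behavior (parse_attrs_sketch and the model path below are illustrative, not part of this commit):

```python
# Hedged sketch of what parse_attrs is assumed to do; the real helper lives in
# mqbench.deploy.common. onnx.helper.get_attribute_value decodes each AttributeProto.
import onnx
from onnx import helper


def parse_attrs_sketch(attribute_protos):
    """Turn a node's repeated AttributeProto field into {name: python_value}."""
    return {attr.name: helper.get_attribute_value(attr) for attr in attribute_protos}


# Illustrative use against an exported model (the path is hypothetical).
onnx_model = onnx.load("mqbench_qmodel.onnx")
for node in onnx_model.graph.node:
    qparams = parse_attrs_sketch(node.attribute)
    if "quant_min" in qparams and "quant_max" in qparams:
        print(node.name, qparams["quant_min"], qparams["quant_max"])
```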

@register_deploy_function(BackendType.Tensorrt)
def convert_onnx_qlinear(model: GraphModule, onnx_model_path, model_name, **kwargs):
@@ -108,58 +139,58 @@ def deploy_qparams_nnie(model: GraphModule, onnx_model_path, model_name, **kwargs):
@register_deploy_function(BackendType.OPENVINO)
def deploy_qparams_openvino(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for OPENVINO.")
replace_fakequantize_and_collect_params_openvino(onnx_model_path, model_name)
replace_fakequantize_and_collect_params_openvino(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.Tensorrt)
def deploy_qparams_tensorrt(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for TensorRT.")
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='tensorrt')
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='tensorrt', qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.Vitis)
def deploy_qparams_vitis(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for Vitis-DPU.")
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='vitis')
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='vitis', qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.SNPE)
def deploy_qparams_snpe(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for SNPE.")
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='snpe')
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='snpe', qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.PPLW8A16)
def deploy_qparams_pplw8a16(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for PPLW8A16.")
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl')
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl', qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.ONNX_QNN)
def deploy_qparams_tvm(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Convert to ONNX QNN.")
ONNXQNNPass(onnx_model_path).run(model_name)
ONNXQNNPass(onnx_model_path).run(model_name, qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.PPLCUDA)
def deploy_qparams_ppl_cuda(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for PPL-CUDA.")
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl-cuda')
remove_fakequantize_and_collect_params(onnx_model_path, model_name, backend='ppl-cuda', qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.Tengine_u8)
def deploy_qparams_tengine(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for Tengine.")
remove_fakequantize_and_collect_params_tengine(onnx_model_path, model_name)
remove_fakequantize_and_collect_params_tengine(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict)


@register_deploy_function(BackendType.STPU)
def deploy_qparams_stpu(model: GraphModule, onnx_model_path, model_name, **kwargs):
logger.info("Extract qparams for STPU.")
remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name)
remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name, qmin_max_dict = qmin_max_dict)


def convert_deploy(model: GraphModule, backend_type: BackendType,
def convert_deploy(model: GraphModule, backend_type: BackendType,
input_shape_dict=None, dummy_input=None, output_path='./',
model_name='mqbench_qmodel', deploy_to_qlinear=False, **extra_kwargs):
r"""Convert model to onnx model and quantization params depends on backend.
@@ -186,9 +217,9 @@ def forward(self, input_0, input_1):
'output_path': output_path,
'model_name': model_name,
'onnx_model_path': osp.join(output_path, '{}.onnx'.format(model_name)),
'deploy_to_qlinear': deploy_to_qlinear
'deploy_to_qlinear': deploy_to_qlinear,
}
kwargs.update(extra_kwargs)
# kwargs.update(extra_kwargs)
deploy_model = deepcopy_graphmodule(model)
for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]:
convert_function(deploy_model, **kwargs)
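
A minimal end-to-end usage sketch of the updated convert_deploy. The toy model, input name, shapes, and stand-in calibration pass are assumptions for illustration, not part of this commit:

```python
# Hedged usage sketch: prepare a model with MQBench, run a stand-in calibration
# pass, then export through convert_deploy (backend choice and shapes assumed).
import torch
import torch.nn as nn
from mqbench.prepare_by_platform import prepare_by_platform, BackendType
from mqbench.convert_deploy import convert_deploy
from mqbench.utils.state import enable_calibration, enable_quantization

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = prepare_by_platform(model, BackendType.Tensorrt)  # insert fake-quant nodes

enable_calibration(model)                # observers collect ranges
_ = model(torch.randn(4, 3, 224, 224))   # stand-in for real calibration data
enable_quantization(model)               # fake quantization active for export

model.eval()
convert_deploy(model, BackendType.Tensorrt,
               input_shape_dict={'input': [1, 3, 224, 224]},
               output_path='./', model_name='mqbench_qmodel')
```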
10 changes: 5 additions & 5 deletions mqbench/custom_quantizer/academic_quantizer.py
@@ -6,7 +6,7 @@
import torch
from torch.fx import GraphModule
from torch.quantization import propagate_qconfig_
from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict

from mqbench.utils import is_symmetric_quant, getitem2node
from mqbench.utils.logger import logger
@@ -25,14 +25,14 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict):
self.io_module = {}
self.post_act_8bit_node_name = []

def prepare(self, model: GraphModule, qconfig):
def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn):
self._get_io_module(model)
self._get_post_act_8bit_node_name(model)
model = self._weight_quant(model, qconfig)
model = self._weight_quant(model, qconfig, backend_config, freeze_bn)
model = self._insert_fake_quantize_for_act_quant(model, qconfig)
return model

def _weight_quant(self, model: GraphModule, qconfig):
def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn):
logger.info("Replace module to qat module.")
wqconfig_8bit = copy.deepcopy(qconfig)
wq_symmetry = True if is_symmetric_quant(qconfig.weight.p.keywords['qscheme']) else False
@@ -44,7 +44,7 @@ def _weight_quant(self, model: GraphModule, qconfig):
module.qconfig = wqconfig_8bit
flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig})
propagate_qconfig_(model, flattened_qconfig_dict)
self._qat_swap_modules(model, self.additional_qat_module_mapping)
self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn)
return model

@property
43 changes: 26 additions & 17 deletions mqbench/custom_quantizer/model_quantizer.py
@@ -23,9 +23,7 @@
from torch.quantization.utils import (
get_combined_dict
)
from torch.quantization.fx.qconfig_utils import (
get_flattened_qconfig_dict
)
from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
from torch.quantization.quantize_fx import (
_fuse_fx
)
@@ -34,10 +32,15 @@
from mqbench.utils.logger import logger
from mqbench.utils.registry import register_model_quantizer
from mqbench.prepare_by_platform import BackendType


from torch.ao.quantization.backend_config import (
BackendConfig,
)
from torch.ao.quantization.backend_config.utils import (
get_module_to_qat_module,
)
@register_model_quantizer(BackendType.Tensorrt)
@register_model_quantizer(BackendType.NNIE)
@register_model_quantizer(BackendType.QDQ)
class ModelQuantizer(object):
"""General model quantizer class.
First, replace common float module to nn.qat.modules to make weight fake
@@ -60,9 +63,9 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict):
self.exclude_node_name = extra_quantizer_dict.get('exclude_node_name', [])
self.extra_fuse_dict = extra_fuse_dict

def prepare(self, model: GraphModule, qconfig):
model = _fuse_fx(model, self.extra_fuse_dict)
model = self._weight_quant(model, qconfig)
def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn):
model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config)
model = self._weight_quant(model, qconfig, backend_config, freeze_bn)
model = self._insert_fake_quantize_for_act_quant(model, qconfig)
return model

@@ -119,11 +122,11 @@ def _fix_succ_recursivly(self, args, target_node, inserted_node):
else:
raise NotImplementedError('{} can not be handled now.'.format(type(args)))

def _weight_quant(self, model: GraphModule, qconfig):
def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn):
logger.info("Replace module to qat module.")
flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig})
propagate_qconfig_(model, flattened_qconfig_dict)
self._qat_swap_modules(model, self.additional_qat_module_mapping)
self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn)
return model

@property
@@ -245,15 +248,18 @@ def _find_act_quants(self, model: GraphModule) -> List:
node_need_to_quantize_output.append(_node)
return node_need_to_quantize_output

def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]):
def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig, freeze_bn: bool):
# all_mappings = get_combined_dict(
# get_default_qat_module_mappings(), additional_qat_module_mapping)
all_mappings = get_combined_dict(
get_default_qat_module_mappings(), additional_qat_module_mapping)
root = self._convert(root, all_mappings, inplace=True)
get_module_to_qat_module(backend_config), additional_qat_module_mapping)
root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config, freeze_bn=freeze_bn)
return root

def _convert(self, module, mapping=None, inplace=False, scope=''):
def _convert(self, module, mapping=None, inplace=False, backend_config=None, freeze_bn=True, scope=''):
if mapping is None:
mapping = get_default_static_quant_module_mappings()
# mapping = get_default_static_quant_module_mappings()
mapping = get_module_to_qat_module(backend_config)

if not inplace:
module = copy.deepcopy(module)
@@ -265,8 +271,11 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
logger.info("Skip quant layer: " + new_scope)
continue
if not isinstance(mod, _FusedModule):
self._convert(mod, mapping, True, new_scope)
reassign[name] = swap_module(mod, mapping, {})
self._convert(mod, mapping, True, new_scope, freeze_bn= freeze_bn)
reassign[name] = swap_module(mod, mapping, {}, False)
if freeze_bn:
if (hasattr(reassign[name], 'freeze_bn')):
reassign[name].freeze_bn = True
for key, value in reassign.items():
module._modules[key] = value
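
The swap path above now derives its float-to-QAT mapping from the supplied BackendConfig via get_module_to_qat_module instead of the global default mapping, and sets freeze_bn on swapped modules that expose the attribute, which is presumably intended to stop further batch-norm statistics updates on intrinsic QAT modules such as ConvBn2d. A hedged sketch of inspecting that mapping for torch's native backend (the exact classes printed depend on the installed torch version):

```python
# Hedged sketch: the BackendConfig-driven float -> QAT module mapping that
# _qat_swap_modules now merges with any user-supplied additional mapping.
import torch.nn as nn
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.backend_config.utils import get_module_to_qat_module

backend_config = get_native_backend_config()
module_to_qat = get_module_to_qat_module(backend_config)

print(module_to_qat.get(nn.Conv2d))   # e.g. a qat Conv2d class on torch 2.x
print(module_to_qat.get(nn.Linear))   # e.g. a qat Linear class
```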

10 changes: 6 additions & 4 deletions mqbench/custom_quantizer/onnx_qnn_quantizer.py
@@ -12,7 +12,9 @@
from mqbench.utils.registry import register_model_quantizer
from mqbench.prepare_by_platform import BackendType
from mqbench.custom_quantizer import ModelQuantizer

from torch.ao.quantization.backend_config import (
BackendConfig,
)

@register_model_quantizer(BackendType.ONNX_QNN)
class ONNXQNNQuantizer(ModelQuantizer):
@@ -52,14 +54,14 @@ def _find_act_quants(self, model: GraphModule) -> List:
node_need_to_quantize_output.append(next_node)
return node_need_to_quantize_output

def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]):
def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig, freeze_bn: bool):
all_mappings = get_combined_dict(
get_default_qat_module_mappings(), additional_qat_module_mapping)
# There is no QLinearFC in ONNX for now.
del all_mappings[torch.nn.modules.linear.Linear]
del all_mappings[torch.nn.intrinsic.modules.fused.LinearReLU]
del all_mappings[qnni.modules.fused.LinearBn1d]
root = self._convert(root, all_mappings, inplace=True)
# del all_mappings[qnni.modules.fused.LinearBn1d]
root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config, freeze_bn=freeze_bn)
return root

@property
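
A hedged sketch of the mapping surgery done in _qat_swap_modules above: start from the default QAT mapping and drop the entries that have no ONNX QLinear counterpart. Using pop with a default instead of del is one way to tolerate entries such as LinearBn1d that may be absent under newer torch versions, which is presumably why that deletion is commented out above; the snippet is illustrative, not the module's code:

```python
# Hedged sketch: build a QAT swap mapping without Linear-type entries, mirroring
# the "There is no QLinearFC in ONNX for now" restriction noted above.
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.ao.quantization.quantization_mappings import get_default_qat_module_mappings

all_mappings = dict(get_default_qat_module_mappings())
all_mappings.pop(nn.Linear, None)       # no QLinearFC in ONNX for now
all_mappings.pop(nni.LinearReLU, None)  # pop(..., None) tolerates missing keys
print(nn.Linear in all_mappings)        # False
```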
12 changes: 6 additions & 6 deletions mqbench/custom_quantizer/openvino_quantizer.py
@@ -6,7 +6,7 @@
import torch
from torch.fx import GraphModule
from torch.quantization import propagate_qconfig_
from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
from torch.quantization.quantize_fx import _fuse_fx

from mqbench.utils import is_symmetric_quant
@@ -137,10 +137,10 @@ def function_type_to_quant_unsigned(self) -> tuple:
def module_type_maybe_unsigned(self) -> tuple:
return (torch.nn.Upsample, torch.nn.modules.pooling.MaxPool2d, torch.nn.modules.pooling.AvgPool2d, torch.nn.modules.pooling.AdaptiveAvgPool2d)

def prepare(self, model: GraphModule, qconfig):
def prepare(self, model: GraphModule, qconfig, is_qat, backend_config, freeze_bn):
if not self.academic_mode:
model = _fuse_fx(model, self.extra_fuse_dict)
model = self._weight_quant(model, qconfig)
model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config)
model = self._weight_quant(model, qconfig, backend_config, freeze_bn)
model = self._insert_fake_quantize_for_act_quant(model, qconfig)
return model

@@ -199,7 +199,7 @@ def propagated_pattern(prev_node, cur_node):
break
return node_need_to_quantize_output

def _weight_quant(self, model: GraphModule, qconfig):
def _weight_quant(self, model: GraphModule, qconfig, backend_config, freeze_bn):
logger.info("Replace module to qat module.")
wqconfig_8bit = copy.deepcopy(qconfig)
wq_symmetry = True if is_symmetric_quant(qconfig.weight.p.keywords['qscheme']) else False
@@ -213,7 +213,7 @@ def _weight_quant(self, model: GraphModule, qconfig):
wqconfig_8bit.weight.p.keywords['quant_max'] = 2 ** (numbits - 2) - 1
flattened_qconfig_dict = get_flattened_qconfig_dict({'': wqconfig_8bit})
propagate_qconfig_(model, flattened_qconfig_dict)
self._qat_swap_modules(model, self.additional_qat_module_mapping)
self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config, freeze_bn)
return model
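
Across these quantizers the commit replaces torch.quantization.fx.qconfig_utils.get_flattened_qconfig_dict, which newer torch releases no longer expose, with a copy under mqbench.quantization.qconfig_mapping_utils. A hedged sketch of the role that helper plays in the _weight_quant methods above, using the {'': qconfig} global-default form seen throughout this diff (the toy model is illustrative):

```python
# Hedged sketch: a flattened {prefix: qconfig} dict is what propagate_qconfig_
# consumes; the '' key acts as a global default that every submodule inherits.
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig, propagate_qconfig_

qconfig = get_default_qconfig("fbgemm")
flattened = {"": qconfig}  # the {'': qconfig} form used in _weight_quant above

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
propagate_qconfig_(model, flattened)
print(model[0].qconfig is not None)  # True: the conv picked up the global qconfig
```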

