Commit

update 2.5.1
yinnengzhong committed Jan 2, 2025
1 parent 09ff5cb commit bd789d1
Showing 12 changed files with 366 additions and 169 deletions.
4 changes: 2 additions & 2 deletions mqbench/convert_deploy.py
@@ -68,7 +68,7 @@ def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_p
input_names = list(dummy_input.keys())
dummy_input = tuple(dummy_input.values())
# Per-channel QuantizeLinear and DequantizeLinear is supported since opset 13
- opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 11
+ opset_version = 13 if kwargs.get('deploy_to_qlinear', False) else 13
with torch.no_grad():
try:
from torch.onnx.utils import ONNXCheckerError
@@ -159,7 +159,7 @@ def deploy_qparams_stpu(model: GraphModule, onnx_model_path, model_name, **kwarg
remove_fakequantize_and_collect_params_stpu(onnx_model_path, model_name)


- def convert_deploy(model: GraphModule, backend_type: BackendType,
+ def convert_deploy(model: GraphModule, backend_type: BackendType,
input_shape_dict=None, dummy_input=None, output_path='./',
model_name='mqbench_qmodel', deploy_to_qlinear=False, **extra_kwargs):
r"""Convert model to onnx model and quantization params depends on backend.
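
Context for the opset change above: convert_onnx now emits opset 13 ONNX even when deploy_to_qlinear is off, since per-channel QuantizeLinear/DequantizeLinear requires opset 13 or newer. The sketch below follows the calibrate-then-deploy flow documented in the MQBench README and ends in this export path; the torchvision ResNet-18, input shape, and model name are illustrative placeholders, not part of this commit.

import torch
import torchvision.models as models

from mqbench.prepare_by_platform import prepare_by_platform, BackendType
from mqbench.utils.state import enable_calibration, enable_quantization
from mqbench.convert_deploy import convert_deploy

model = models.resnet18().eval()
model = prepare_by_platform(model, BackendType.Tensorrt)      # trace and insert fake-quant nodes

enable_calibration(model)                                      # observers collect statistics
with torch.no_grad():
    for _ in range(8):
        model(torch.randn(1, 3, 224, 224))                     # placeholder calibration data

enable_quantization(model)                                     # fake-quant active for evaluation/QAT
convert_deploy(model, BackendType.Tensorrt,
               input_shape_dict={'data': [1, 3, 224, 224]},
               model_name='resnet18_trt')                      # ONNX export now defaults to opset 13
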
2 changes: 1 addition & 1 deletion mqbench/custom_quantizer/academic_quantizer.py
@@ -6,7 +6,7 @@
import torch
from torch.fx import GraphModule
from torch.quantization import propagate_qconfig_
- from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
+ from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict

from mqbench.utils import is_symmetric_quant, getitem2node
from mqbench.utils.logger import logger
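
The import above moves get_flattened_qconfig_dict from torch.quantization.fx.qconfig_utils, which recent PyTorch releases no longer expose, to a copy vendored under mqbench.quantization.qconfig_mapping_utils. As a rough illustration of what the helper does (an assumption about its shape, not the vendored implementation), flattening turns a nested qconfig specification into the flat mapping that propagate_qconfig_ consumes:

from torch.quantization import get_default_qconfig

def flatten_qconfig_dict_sketch(qconfig_dict):
    """Illustrative only: collapse global / object_type / module_name entries
    into one {key: qconfig} dict usable by torch.quantization.propagate_qconfig_."""
    flattened = {}
    if '' in qconfig_dict:
        flattened[''] = qconfig_dict['']                      # global default
    for obj_type, qconfig in qconfig_dict.get('object_type', []):
        flattened[obj_type] = qconfig                         # per-module-class override
    for module_name, qconfig in qconfig_dict.get('module_name', []):
        flattened[module_name] = qconfig                      # per-submodule override
    return flattened

qconfig = get_default_qconfig('fbgemm')
assert flatten_qconfig_dict_sketch({'': qconfig}) == {'': qconfig}   # the call pattern used by the quantizers
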
37 changes: 21 additions & 16 deletions mqbench/custom_quantizer/model_quantizer.py
@@ -23,9 +23,7 @@
from torch.quantization.utils import (
get_combined_dict
)
- from torch.quantization.fx.qconfig_utils import (
-     get_flattened_qconfig_dict
- )
+ from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
from torch.quantization.quantize_fx import (
_fuse_fx
)
@@ -34,8 +32,12 @@
from mqbench.utils.logger import logger
from mqbench.utils.registry import register_model_quantizer
from mqbench.prepare_by_platform import BackendType


+ from torch.ao.quantization.backend_config import (
+     BackendConfig,
+ )
+ from torch.ao.quantization.backend_config.utils import (
+     get_module_to_qat_module,
+ )
@register_model_quantizer(BackendType.Tensorrt)
@register_model_quantizer(BackendType.NNIE)
class ModelQuantizer(object):
@@ -60,9 +62,9 @@ def __init__(self, extra_quantizer_dict, extra_fuse_dict):
self.exclude_node_name = extra_quantizer_dict.get('exclude_node_name', [])
self.extra_fuse_dict = extra_fuse_dict

- def prepare(self, model: GraphModule, qconfig):
-     model = _fuse_fx(model, self.extra_fuse_dict)
-     model = self._weight_quant(model, qconfig)
+ def prepare(self, model: GraphModule, qconfig, is_qat, backend_config):
+     model = _fuse_fx(model, is_qat, self.extra_fuse_dict, backend_config)
+     model = self._weight_quant(model, qconfig, backend_config)
model = self._insert_fake_quantize_for_act_quant(model, qconfig)
return model

@@ -119,11 +121,11 @@ def _fix_succ_recursivly(self, args, target_node, inserted_node):
else:
raise NotImplementedError('{} can not be handled now.'.format(type(args)))

- def _weight_quant(self, model: GraphModule, qconfig):
+ def _weight_quant(self, model: GraphModule, qconfig, backend_config):
logger.info("Replace module to qat module.")
flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig})
propagate_qconfig_(model, flattened_qconfig_dict)
- self._qat_swap_modules(model, self.additional_qat_module_mapping)
+ self._qat_swap_modules(model, self.additional_qat_module_mapping, backend_config)
return model

@property
@@ -245,15 +247,18 @@ def _find_act_quants(self, model: GraphModule) -> List:
node_need_to_quantize_output.append(_node)
return node_need_to_quantize_output

- def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]):
+ def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable], backend_config: BackendConfig):
+     # all_mappings = get_combined_dict(
+     #     get_default_qat_module_mappings(), additional_qat_module_mapping)
all_mappings = get_combined_dict(
-     get_default_qat_module_mappings(), additional_qat_module_mapping)
- root = self._convert(root, all_mappings, inplace=True)
+     get_module_to_qat_module(backend_config), additional_qat_module_mapping)
+ root = self._convert(root, all_mappings, inplace=True, backend_config = backend_config)
return root

- def _convert(self, module, mapping=None, inplace=False, scope=''):
+ def _convert(self, module, mapping=None, inplace=False, backend_config=None, scope=''):
if mapping is None:
- mapping = get_default_static_quant_module_mappings()
+ # mapping = get_default_static_quant_module_mappings()
+ mapping = get_module_to_qat_module(backend_config)

if not inplace:
module = copy.deepcopy(module)
@@ -266,7 +271,7 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
continue
if not isinstance(mod, _FusedModule):
self._convert(mod, mapping, True, new_scope)
- reassign[name] = swap_module(mod, mapping, {})
+ reassign[name] = swap_module(mod, mapping, {}, False)
for key, value in reassign.items():
module._modules[key] = value

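
The quantizer now threads is_qat and a torch.ao BackendConfig through prepare, _weight_quant, _qat_swap_modules and _convert, and derives the float-to-QAT module mapping from the backend config instead of get_default_qat_module_mappings(). Below is a small sketch of the objects involved; the use of the native backend config and the commented prepare call are assumptions about the caller, not code from this commit.

import torch.nn as nn
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.backend_config.utils import get_module_to_qat_module

backend_config = get_native_backend_config()              # BackendConfig describing fusion/QAT swaps
module_to_qat = get_module_to_qat_module(backend_config)  # float module type -> QAT module type

# e.g. nn.Conv2d / nn.Linear map to their QAT counterparts
print(module_to_qat.get(nn.Conv2d), module_to_qat.get(nn.Linear))

# A caller would then pass it along, roughly:
# prepared = quantizer.prepare(graph_module, qconfig, is_qat=True, backend_config=backend_config)
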
2 changes: 1 addition & 1 deletion mqbench/custom_quantizer/openvino_quantizer.py
@@ -6,7 +6,7 @@
import torch
from torch.fx import GraphModule
from torch.quantization import propagate_qconfig_
- from torch.quantization.fx.qconfig_utils import get_flattened_qconfig_dict
+ from mqbench.quantization.qconfig_mapping_utils import get_flattened_qconfig_dict
from torch.quantization.quantize_fx import _fuse_fx

from mqbench.utils import is_symmetric_quant
150 changes: 135 additions & 15 deletions mqbench/custom_symbolic_opset.py
@@ -1,23 +1,143 @@
from torch.onnx import register_custom_op_symbolic

# Register symbolic op for torch.quantize_function op.
import functools
from torch.onnx._internal import jit_utils, registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
from torch.onnx import (
_type_utils,
symbolic_helper,
symbolic_opset9 as opset9,
)
import torch._C._onnx as _C_onnx
import torch
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(
g: jit_utils.GraphContext,
inputs,
scale,
zero_point,
quant_min=-128,
quant_max=127,
):
# NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
# if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
# raise errors.SymbolicValueError(
# "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
# f"Got ({quant_min}, {quant_max})",
# inputs,
# )
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
else:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
if (
_type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
!= _type_utils.JitScalarType.FLOAT
):
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
if (quant_min, quant_max) == (0, 127):
quantized = g.op(
"Clip",
quantized,
opset9.unused(g),
g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
)
return g.op("DequantizeLinear", quantized, scale, zero_point)

- def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
-     return g.op("::LearnablePerTensorAffine", x, scale, zero_point, quant_min, quant_max)
@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(
g: jit_utils.GraphContext,
inputs,
scale,
zero_point,
axis,
quant_min=-128,
quant_max=127,
):
# NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
# if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
# raise errors.SymbolicValueError(
# "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
# f"Got ({quant_min}, {quant_max})",
# inputs,
# )
# ONNX defines zero_point to be int8 or uint8
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
else:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
if (quant_min, quant_max) == (0, 127):
quantized = g.op(
"Clip",
quantized,
opset9.unused(g),
g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
)
return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)

@_onnx_symbolic("aten::_fake_quantize_learnable_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
def _fake_quantize_learnable_per_tensor_affine(
g: jit_utils.GraphContext,
inputs,
scale,
zero_point,
quant_min=-128,
quant_max=127,
grad_factor=0,
):
# NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
# if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
# raise errors.SymbolicValueError(
# "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
# f"Got ({quant_min}, {quant_max})",
# inputs,
# )
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
else:
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
if (
_type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
!= _type_utils.JitScalarType.FLOAT
):
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
if (quant_min, quant_max) == (0, 127):
quantized = g.op(
"Clip",
quantized,
opset9.unused(g),
g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
)
return g.op("DequantizeLinear", quantized, scale, zero_point)

register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)

# def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
# return g.op(x, scale, zero_point, quant_min, quant_max)
#
#
# register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)
#
#
# def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
# return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
#
#
# register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
#
#
# def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
# return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
#
#
# register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)

- def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
-     return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
-
-
- register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
-
-
- def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
-     return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
-
-
- register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)
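
With the registrations above, fake-quantize ops export as standard ONNX QuantizeLinear/DequantizeLinear (QDQ) pairs at opset 13 instead of the old custom ::FixedPerTensorAffine / ::LearnablePerTensorAffine nodes. A minimal, illustrative export check follows (not part of the commit; module, constants, file name and shapes are placeholders):

import torch

class FakeQuantModule(torch.nn.Module):
    def forward(self, x):
        # constant scale/zero_point just to exercise the fake-quantize symbolic
        return torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)

torch.onnx.export(FakeQuantModule(), torch.randn(1, 3, 8, 8), 'fq_check.onnx', opset_version=13)
# Inspecting fq_check.onnx should show a QuantizeLinear -> DequantizeLinear pair
# in place of a custom fake-quantize node.
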
2 changes: 1 addition & 1 deletion mqbench/fake_quantize/fixed.py
@@ -4,7 +4,7 @@
from mqbench.utils.hook import PerChannelLoadHook


- _version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+ _version_under_1100 = int(torch.__version__.split('.')[0]) == 1 and int(torch.__version__.split('.')[1]) < 10

class FixedFakeQuantize(QuantizeBase):
"""This is actually torch.quantization.FakeQuantize.
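
The guard above is a bug fix: the old check looked only at the minor version, so torch 2.x (e.g. '2.5.1', minor 5) was wrongly treated as older than 1.10. A sketch of an equivalent, slightly more defensive check (illustrative only, not the code in this commit):

import torch

# '2.5.1+cu121' -> (2, 5); tuple comparison means 2.x is never "under 1.10"
_major, _minor = (int(p) for p in torch.__version__.split('+')[0].split('.')[:2])
_version_under_1100 = (_major, _minor) < (1, 10)
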
23 changes: 20 additions & 3 deletions mqbench/fake_quantize/lsq.py
@@ -4,7 +4,12 @@
from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.utils import is_symmetric_quant, is_tracing_state
from mqbench.utils.hook import PerChannelLoadHook

+ from torch.onnx import (
+     _type_utils,
+     symbolic_helper,
+     symbolic_opset9 as opset9,
+ )
+ import torch._C._onnx as _C_onnx

class LearnableFakeQuantize(QuantizeBase):
r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
@@ -106,5 +111,17 @@ def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_facto
quant_min, quant_max, grad_factor)

@staticmethod
- def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
-     return g.op("::FakeQuantizeLearnablePerchannelAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max)
+ def symbolic(g, inputs, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
+     if quant_min == 0:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+     else:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+     quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=ch_axis)
+     if (quant_min, quant_max) == (0, 127):
+         quantized = g.op(
+             "Clip",
+             quantized,
+             opset9.unused(g),
+             g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+         )
+     return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=ch_axis)
