diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index e334fe7ec9b2..64ce73ba1cf7 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -23,11 +23,17 @@ * \brief * \author Ziheng Jiang, Jun Wu */ +#include +#include "quantization_utils.h" #include "../nn/fully_connected-inl.h" namespace mxnet { namespace op { +namespace quantized_fc { +enum QuantizedfcOpResource {kTempSpace}; +} + bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, std::vector *in_shape, std::vector *out_shape) { @@ -79,6 +85,151 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs, return true; } +bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } + + for (auto &v : *out_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + + for (auto &v : *in_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + return true; +} + +struct QuantizedSumInitKernelWithBias { + // init sum data with bias for matrix b (n) + MSHADOW_XINLINE static void Map(int i, int32_t *out, + const int8_t *bias, const float *min_out, + const float *max_out, const float *min_bias, + const float *max_bias) { + typedef int32_t T1; + typedef int8_t T2; + using mshadow::red::limits::MinValue; + using mshadow::red::limits::MaxValue; + float float_for_one_out_quant = + MaxAbs(*min_out, *max_out) / static_cast(MaxValue()); + float float_for_one_bias_quant = + MaxAbs(*min_bias, *max_bias) / static_cast(MaxValue()); + if (float_for_one_out_quant != 0) { + out[i] = bias[i] * float_for_one_bias_quant / + float_for_one_out_quant; + } else { + LOG(INFO) << "float_for_one_out_quant is 0," + << " need to check the why MaxAbs(*min_out, *max_out) of out_data is 0!"; + out[i] = 0; + } + } +}; + + +template +void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { +#if MSHADOW_USE_MKL == 1 + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + using namespace mshadow; + using namespace mxnet_op; + size_t num_inputs = param.no_bias ? 2 : 3; + CHECK_EQ(in_data.size(), num_inputs * 3); + CHECK_EQ(out_data.size(), 3U); + const NDArray& data = in_data[0]; + const NDArray& weight = in_data[1]; + const NDArray& out = out_data[0]; + TShape dshape = data.shape(); + TShape wshape = weight.shape(); + TShape oshape = out.shape(); + auto output_temp = out.data().dptr(); + auto weight_temp = weight.data().dptr(); + auto data_temp = data.data().dptr(); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const float alpha = 1.0f; + const float beta = 1.0f; + const CBLAS_OFFSET offsetc = CblasFixOffset; + const MKL_INT8 oa = 0; + const MKL_INT8 ob = 0; + MKL_INT32 oc = 0; + const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim()); + Stream *s = ctx.get_stream(); + // cblas_gemm_s8u8s32 required first matrix must be uint8 + // shift data from int8(from -128 to 127) to uint8 (from 0 to 255) + int shift = 128; + Tensor shiftdata = + ctx.requested[quantized_fc::kTempSpace].get_space_typed( + Shape1(m * k), s); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < m * k; ++i) { + shiftdata.dptr_[i] = data_temp[i] + shift; + } + + Kernel::Launch(s, 1, + out_data[1].data().dptr(), out_data[2].data().dptr(), + in_data[num_inputs].data().dptr(), in_data[num_inputs+1].data().dptr(), + in_data[num_inputs+2].data().dptr(), in_data[num_inputs+3].data().dptr()); + if (!param.no_bias) { + const NDArray& bias = in_data[2]; + Kernel::Launch(s, n, out.data().dptr(), + bias.data().dptr(), out_data[1].data().dptr(), + out_data[2].data().dptr(), in_data[7].data().dptr(), + in_data[8].data().dptr()); + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < m * n; ++i) { + output_temp[i] = 0; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < n; ++i) { + for (int j = 0; j < k; ++j) { + output_temp[i] -= shift * weight_temp[i * k + j]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = n; i < m * n; ++i) { + output_temp[i] = output_temp[i % n]; + } + cblas_gemm_s8u8s32(CblasRowMajor, + CblasNoTrans, + CblasTrans, + offsetc, + m, + n, + k, + alpha, + shiftdata.dptr_, + k, + oa, + weight.data().dptr(), + k, + ob, + beta, + out.data().dptr(), + n, + &oc); +#else + LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32" + << " which is only supported by MKL BLAS." + << " Please build MXNet with USE_BLAS=mkl to leverage this operator."; +#endif +} + NNVM_REGISTER_OP(_contrib_quantized_fully_connected) .describe(R"code(Fully Connected operator for input, weight and bias data type of int8, and accumulates in type int32 for the output. For each argument, two more arguments of type @@ -112,7 +263,14 @@ and max thresholds representing the threholds for quantizing the float32 output }) .set_attr("FInferShape", QuantizedFullyConnectedShape) .set_attr("FInferType", QuantizedFullyConnectedType) +.set_attr("FInferStorageType", QuantizedFullyConnectedStorageType) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.set_attr("FComputeEx", + QuantizedFullyConnectedForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .add_argument("data", "NDArray-or-Symbol", "Input data.") .add_argument("weight", "NDArray-or-Symbol", "weight.") .add_argument("bias", "NDArray-or-Symbol", "bias.") @@ -135,6 +293,5 @@ NNVM_REGISTER_OP(FullyConnected) } return node; }); - } // namespace op } // namespace mxnet diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index ca8070cfc224..a27d5b322e3d 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -26,6 +26,7 @@ from mxnet.module import Module from mxnet.io import NDArrayIter import unittest +import operator def is_test_for_gpu(): return mx.current_context().device_type == 'gpu' @@ -278,8 +279,15 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p def test_quantized_fc(): def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): if mx.current_context().device_type != 'gpu': - print('skipped testing quantized_fc on cpu since it is not supported yet') - return + hasMKL = False; + for key in os.environ.keys(): + if operator.eq(key, "BUILD_TAG"): + if os.environ['BUILD_TAG'].find("MKL") != -1: + hasMKL = True + break + if hasMKL == False: + print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library') + return elif qdtype == 'uint8' and is_test_for_gpu(): print('skipped testing quantized_fc for gpu uint8 since it is not supported yet') return @@ -291,16 +299,16 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') if qdtype == 'uint8': data_low = 0.0 - data_high = 127.0 + data_high = 63.0 else: - data_low = -127.0 - data_high = 127.0 + data_low = -63.0 + data_high = 63.0 fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32') - fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[1]).astype('int32') if not no_bias: - fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[2]).astype('int32') output = fc_fp32_exe.forward()[0] @@ -343,6 +351,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): check_quantized_fc((32, 111, 2, 2), 100, True, qdtype) check_quantized_fc((32, 512, 2, 2), 100, False, qdtype) check_quantized_fc((32, 111, 2, 2), 100, False, qdtype) + check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype) + check_quantized_fc((256, 111, 2, 2), 800, False, qdtype) + check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype) + check_quantized_fc((256, 111, 2, 2), 800, True, qdtype) @with_seed() def test_quantized_flatten():