From 0c93cde0d3af3d211030d18a7b44cb20fa312316 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 8 Feb 2019 17:04:45 -0800 Subject: [PATCH] add dtype --- src/operator/mxnet_op.h | 2 +- src/operator/nn/softmax-inl.h | 151 +++++++++++++++++-------- src/operator/nn/softmax.cc | 90 +++++++++++++-- tests/python/unittest/test_operator.py | 61 ++++++---- 4 files changed, 226 insertions(+), 78 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index a2c1e9ad38f1..83c36defa844 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -254,7 +254,7 @@ inline int get_num_threads(const int N) { case mshadow::kFloat32: \ { \ typedef float DType; \ - typedef float AType; \ + typedef double AType; \ {__VA_ARGS__} \ } \ break; \ diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 1d034dbf82b9..b1409142cfa2 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -52,8 +52,8 @@ struct log_softmax_fwd { }; -template -inline void Softmax(Stream *s, DType *in, DType *out, +template +inline void Softmax(Stream *s, DType *in, OType *out, Shape shape, int axis, const DType temperature) { index_t M = shape[axis]; index_t N = shape.Size()/M; @@ -75,8 +75,7 @@ inline void Softmax(Stream *s, DType *in, DType *out, AType sum = AType(0); DType in_val; - // By default temperature is 1.0, and only in reinforcement training - // users would set it to other values. + // By default temperature is 1.0. // Adding a branch here to save the CPU 'divide-by-1' computation at runtime if (temperature == 1.0) { for (index_t j = 0; j < M; ++j) { @@ -119,8 +118,9 @@ struct log_softmax_bwd { }; -template -inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, +template +inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, DType *igrad, Shape shape, int axis, const DType temperature) { index_t M = shape[axis]; @@ -139,8 +139,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]); } - // By default temperature is 1.0, and only in reinforcement training - // users would set it to other values. + // By default temperature is 1.0. 
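// [Editor's note] The two changes above share one mixed-precision pattern: values are
// read as DType, sums are accumulated in a wider AType (the mxnet_op.h hunk makes the
// float32 case accumulate in double), and results are written through a separate
// output type OType, the new template parameter threaded through Softmax/SoftmaxGrad.
// The fragment below is a minimal, self-contained sketch of that pattern with
// hypothetical names; it is an illustration, not the MXNet kernel itself.
#include <algorithm>
#include <cmath>
#include <cstddef>

template <typename DType, typename AType, typename OType>
void softmax_row_sketch(const DType* in, OType* out, std::size_t n) {
  // Subtract the row maximum for numerical stability, as the real kernel does.
  AType row_max = static_cast<AType>(in[0]);
  for (std::size_t i = 1; i < n; ++i)
    row_max = std::max(row_max, static_cast<AType>(in[i]));
  // Accumulate the normalizer in the wider AType rather than in DType.
  AType sum = AType(0);
  for (std::size_t i = 0; i < n; ++i)
    sum += std::exp(static_cast<AType>(in[i]) - row_max);
  // Emit the result in OType, which may be wider than the input type.
  for (std::size_t i = 0; i < n; ++i)
    out[i] = static_cast<OType>(std::exp(static_cast<AType>(in[i]) - row_max) / sum);
}
// For example, softmax_row_sketch<float, double, float> mirrors the new float32 path,
// and a half-precision input with AType=OType=float mirrors the float16-input,
// float32-output path that the dtype argument enables.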
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime DType final_result; if (temperature == 1.0) { @@ -163,8 +162,9 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, #ifdef __CUDACC__ -template -__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis, +template +__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis, Shape sshape, Shape stride, const double temperature) { const unsigned x_size = 1 << x_bits; @@ -201,8 +201,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi } } -template -inline void Softmax(Stream *s, DType *in, DType *out, +template +inline void Softmax(Stream *s, DType *in, OType *out, Shape shape, int axis, const double temperature) { const int x_bits = 7; const int x_size = 1 << x_bits; @@ -212,16 +212,16 @@ inline void Softmax(Stream *s, DType *in, DType *out, Shape sshape = shape; sshape[axis] = 1; - softmax_compute_kernel + softmax_compute_kernel <<::GetStream(s)>>>( in, out, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel); } -template -__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, +template +__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad, index_t M, int axis, Shape sshape, Shape stride, const double temperature) { const unsigned x_size = 1 << x_bits; @@ -251,8 +251,9 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad, } -template -inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, +template +inline void SoftmaxGrad(Stream *s, OType *out, OType *ograd, DType *igrad, Shape shape, int axis, const double temperature) { const int x_bits = 7; @@ -263,7 +264,7 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, Shape sshape = shape; sshape[axis] = 1; - softmax_gradient_kernel + softmax_gradient_kernel <<::GetStream(s)>>>( out, ograd, igrad, M, axis, sshape, stride, temperature); MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel); @@ -276,14 +277,70 @@ inline void SoftmaxGrad(Stream *s, DType *out, DType *ograd, struct SoftmaxParam : public dmlc::Parameter { int axis; dmlc::optional temperature; + dmlc::optional dtype; DMLC_DECLARE_PARAMETER(SoftmaxParam) { DMLC_DECLARE_FIELD(axis).set_default(-1) - .describe("The axis along which to compute softmax."); + .describe("The axis along which to compute softmax."); DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional()) - .describe("Temperature parameter in softmax"); + .describe("Temperature parameter in softmax"); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .set_default(dmlc::optional()) + .describe("DType of the output in case this can't be inferred. 
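// [Editor's note] The `dtype` field declared above is consumed by the SoftmaxOpType /
// SoftmaxGradOpType functions added just below: the forward output takes the requested
// dtype, or defaults to the input's type when dtype=None. Because the output type can
// now differ from the input type, the backward op takes the forward input as well
// (hence the switch from ElemwiseGradUseOut to ElemwiseGradUseInOut and the three
// inputs {ograd, data, output}), so that igrad keeps the data's dtype while ograd is
// tied to the output's dtype. The helper below is a hypothetical, simplified
// restatement of the forward rule; it omits the TYPE_ASSIGN_CHECK cross-checks the
// real code performs when both sides are already known. As there, -1 means "unknown".
inline bool resolve_softmax_dtypes(int* in_dtype, int* out_dtype,
                                   int arg_dtype /* -1 when dtype=None */) {
  if (*in_dtype != -1) {
    // Input known: the output follows dtype= if given, else defaults to the input type.
    if (*out_dtype == -1)
      *out_dtype = (arg_dtype != -1) ? arg_dtype : *in_dtype;
    return true;
  }
  if (*out_dtype != -1) {
    // Only the output known: assume the input matches it.
    *in_dtype = *out_dtype;
    return true;
  }
  // Nothing known yet: record dtype= on the output if given; inference is incomplete.
  if (arg_dtype != -1)
    *out_dtype = arg_dtype;
  return false;
}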
" + "Defaults to the same as input's dtype if not defined (dtype=None)."); } }; +inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + + int arg_dtype = param.dtype.has_value()?param.dtype.value():-1, + in_dtype = (*in_attrs)[0], + out_dtype = (*out_attrs)[0]; + + if (out_dtype != -1 && in_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + return true; + } else if (in_dtype != -1) { + if (arg_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); + } + return true; + } else if (out_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + return true; + } else { + if (arg_dtype != -1) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); + } + return false; + } +} + +inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3); + CHECK_EQ(out_attrs->size(), 1); + + int in_dtype = (*in_attrs)[1], + out_dtype = (*in_attrs)[2]; + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); + + return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1; +} + template void SoftmaxCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -299,17 +356,19 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, param.temperature.value() : 1.0; TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { - if (shape.ndim() == 2) { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<2>(), axis, - static_cast(temperature)); - } else { - Softmax( - ctx.get_stream(), inputs[0].dptr(), - outputs[0].dptr(), shape.get<3>(), axis, - static_cast(temperature)); - } + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + if (shape.ndim() == 2) { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<2>(), axis, + static_cast(temperature)); + } else { + Softmax( + ctx.get_stream(), inputs[0].dptr(), + outputs[0].dptr(), shape.get<3>(), axis, + static_cast(temperature)); + } + }); }); } @@ -327,19 +386,21 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, const double temperature = param.temperature.has_value() ? 
param.temperature.value() : 1.0; TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true); - MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (shape.ndim() == 2) { - SoftmaxGrad( - ctx.get_stream(), inputs[1].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<2>(), axis, static_cast(temperature)); - } else { - SoftmaxGrad( - ctx.get_stream(), inputs[1].dptr(), - inputs[0].dptr(), outputs[0].dptr(), - shape.get<3>(), axis, static_cast(temperature)); - } + MXNET_REAL_ACC_TYPE_SWITCH(inputs[2].type_flag_, OType, AType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + if (shape.ndim() == 2) { + SoftmaxGrad( + ctx.get_stream(), inputs[2].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<2>(), axis, static_cast(temperature)); + } else { + SoftmaxGrad( + ctx.get_stream(), inputs[2].dptr(), + inputs[0].dptr(), outputs[0].dptr(), + shape.get<3>(), axis, static_cast(temperature)); + } + }); }); }); } diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 81e775cac526..1d6cef58263c 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -67,7 +67,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs, } #endif -MXNET_OPERATOR_REGISTER_UNARY(softmax) +NNVM_REGISTER_OP(softmax) .describe(R"code(Applies the softmax function. The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1. @@ -102,15 +102,39 @@ Example:: .set_attr("FComputeEx", SoftmaxComputeExCPU) .set_attr("FInferStorageType", SoftmaxStorageType) #endif -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_softmax"}) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_softmax"}) +.set_attr("FInferType", SoftmaxOpType) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "The input array.") .add_arguments(SoftmaxParam::__FIELDS__()); -MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax) +NNVM_REGISTER_OP(_backward_softmax) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"ograd", "data", "output"}; + }) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", SoftmaxGradOpType) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}, {2, 0}}; + }) +.add_argument("ograd", "NDArray-or-Symbol", "gradient of output") +.add_argument("data", "NDArray-or-Symbol", "input") +.add_argument("output", "NDArray-or-Symbol", "output") .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxGradCompute); -MXNET_OPERATOR_REGISTER_UNARY(softmin) +NNVM_REGISTER_OP(softmin) .describe(R"code(Applies the softmin function. 
The resulting array contains elements in the range (0,1) and the elements along the given axis sum @@ -141,15 +165,39 @@ Example:: return std::vector{"output"}; }) .set_attr("FCompute", SoftmaxCompute) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_softmin"}) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_softmin"}) +.set_attr("FInferType", SoftmaxOpType) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "The input array.") .add_arguments(SoftmaxParam::__FIELDS__()); -MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin) +NNVM_REGISTER_OP(_backward_softmin) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"ograd", "data", "output"}; + }) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", SoftmaxGradOpType) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}, {2, 0}}; + }) +.add_argument("ograd", "NDArray-or-Symbol", "gradient of output") +.add_argument("data", "NDArray-or-Symbol", "input") +.add_argument("output", "NDArray-or-Symbol", "output") .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxGradCompute); -MXNET_OPERATOR_REGISTER_UNARY(log_softmax) +NNVM_REGISTER_OP(log_softmax) .describe(R"code(Computes the log softmax of the input. This is equivalent to computing softmax followed by log. @@ -168,10 +216,34 @@ Examples:: )code") .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxCompute) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"}) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_log_softmax"}) +.set_attr("FInferType", SoftmaxOpType) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "The input array.") .add_arguments(SoftmaxParam::__FIELDS__()); -MXNET_OPERATOR_REGISTER_BINARY(_backward_log_softmax) +NNVM_REGISTER_OP(_backward_log_softmax) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"ograd", "data", "output"}; + }) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", SoftmaxGradOpType) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}, {2, 0}}; + }) +.add_argument("ograd", "NDArray-or-Symbol", "gradient of output") +.add_argument("data", "NDArray-or-Symbol", "input") +.add_argument("output", "NDArray-or-Symbol", "output") .set_attr_parser(ParamParser) .set_attr("FCompute", SoftmaxGradCompute); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2360e746e44b..4cf0a970e15e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4516,30 +4516,45 @@ def softmax_forward(input_data, true_output): softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0])) @with_seed() -def test_softmax_fp16(): - def check_fp16_fp32_almost_equal(input_data): - fp16_input = input_data.astype('float16') - fp32_input = input_data.astype('float32') - fp16_input.attach_grad() - fp32_input.attach_grad() +def test_softmax_dtype(): + def check_dtypes_almost_equal(op_name, + atol, rtol, + grad_atol, grad_rtol, + idtype, 
ref_dtype, odtype=None): + op = getattr(mx.nd, op_name) + input_data = mx.random.uniform(shape=(100, 500)) + dtype_input = input_data.astype(idtype) + ref_input = input_data.astype(ref_dtype) + dtype_input.attach_grad() + ref_input.attach_grad() with mx.autograd.record(): - fp16_softmax = fp16_input.softmax(axis=-1) - fp32_softmax = fp32_input.softmax(axis=-1) - fp16_softmax.backward() - fp32_softmax.backward() - assert_almost_equal(fp16_softmax.asnumpy(), fp32_softmax.asnumpy(), rtol=1e-5, atol=1e-5) - assert_almost_equal(fp16_input.grad.asnumpy(), fp32_input.grad.asnumpy(), rtol=1e-5, atol=1e-5) - - with mx.autograd.record(): - fp16_log_softmax = fp16_input.log_softmax(axis=-1) - fp32_log_softmax = fp32_input.log_softmax(axis=-1) - fp16_log_softmax.backward() - fp32_log_softmax.backward() - assert_almost_equal(fp16_log_softmax.asnumpy(), fp32_log_softmax.asnumpy(), rtol=1e-2, atol=1e-2) - assert_almost_equal(fp16_input.grad.asnumpy(), fp32_input.grad.asnumpy(), rtol=1e-2, atol=1e-2) - - for _ in range(5): - check_fp16_fp32_almost_equal(mx.random.uniform(shape=(100, 500))) + dtype_softmax = op(dtype_input, axis=-1, dtype=odtype) + ref_softmax = op(ref_input, axis=-1, dtype=odtype) + dtype_softmax_np = dtype_softmax.asnumpy() + ref_softmax_np = ref_softmax.asnumpy() + assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol) + dtype_softmax.backward() + ref_softmax.backward() + dtype_grad_np = dtype_input.grad.asnumpy() + ref_grad_np = ref_input.grad.asnumpy() + assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol) + + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64') + check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2, + 'float16', 'float32', 'float32') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64') + check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3, + 'float32', 'float64', 'float64') @with_seed() def test_pick():
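# [Editor's note] A minimal usage sketch of the `dtype` argument this patch adds to
# softmax/softmin/log_softmax (it assumes a build that includes the patch); the
# test_softmax_dtype test above exercises exactly this float16-input, float32-output
# combination.
import mxnet as mx

x = mx.nd.random.uniform(shape=(4, 10)).astype('float16')
x.attach_grad()
with mx.autograd.record():
    # Compute and return the softmax in float32 even though the input is float16.
    y = mx.nd.softmax(x, axis=-1, dtype='float32')
y.backward()
print(y.dtype)       # <class 'numpy.float32'>
print(x.grad.dtype)  # the gradient keeps the input's dtype: <class 'numpy.float16'>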