From b3b952f9d5490ee2707209ab866e6c3f094e2046 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Sun, 14 Apr 2019 22:04:52 -0700 Subject: [PATCH] fp16 safe norm operator (#14616) * use safe aggregation for norm * safe norm with DataType, AccuType and OutType * new test for backward * change back to MSHADOW_TYPE_SWITCH * remove dead debug outputs * Allow integer types --- src/operator/mshadow_op.h | 68 ++++- src/operator/mxnet_op.h | 83 +++++- src/operator/tensor/broadcast_reduce-inl.cuh | 61 +++-- src/operator/tensor/broadcast_reduce-inl.h | 38 ++- src/operator/tensor/broadcast_reduce_op.h | 257 +++++++++++++----- .../tensor/broadcast_reduce_op_value.cc | 2 +- src/operator/tensor/matrix_op-inl.h | 4 +- tests/python/unittest/test_operator.py | 65 +++-- 8 files changed, 430 insertions(+), 148 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index c27a98ac1940..d9d6151c06bf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -945,13 +945,13 @@ struct nanprod { /*! \brief compute l2 norm */ struct nrm2 { /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*) + template + MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*) sum_of_squares += src * src; } /*! \brief do stable reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*) + template + MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*) if (src != 0) { DType abs = mshadow_op::abs::Map(src); if (scale < abs) { @@ -1012,6 +1012,66 @@ struct nrm2 { } }; +/*! \brief sum reducer */ +struct sum { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*) + dst += src; + } + /*! \brief do stable reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) + DType y = src - residual; + DType t = dst + y; + residual = (t - dst) - y; + dst = t; + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + DType t1 = dst_val + src_val; + DType e = t1 - dst_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) + /*! + *\brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return 1; + } + /*! 
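The `sum` reducer added above is the fp16-safe aggregation path: the stable `Reduce` keeps a running compensation term (`residual`) for the low-order bits lost in each addition, and `Merge` combines two partial sums with a two-sum step. A minimal pure-float16 NumPy sketch of the same compensated arithmetic (function names here are illustrative, not part of the patch):

import numpy as np

def kahan_reduce(dst, src, residual):
    # mirrors sum::Reduce(dst, src, residual): add src while carrying the
    # rounding error of the addition in `residual`
    y = src - residual
    t = dst + y
    residual = (t - dst) - y
    return t, residual

def kahan_merge(dst_val, dst_res, src_val, src_res):
    # mirrors sum::Merge: two-sum of the partial sums plus both residuals
    t1 = dst_val + src_val
    e = t1 - dst_val
    t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_res + src_res
    dst_val = t1 + t2
    dst_res = t2 - (dst_val - t1)
    return dst_val, dst_res

vals = np.random.uniform(0, 1, 100000).astype(np.float16)
naive = np.float16(0)
comp, res = np.float16(0), np.float16(0)
for v in vals:
    naive = naive + v                 # plain float16 running sum
    comp, res = kahan_reduce(comp, v, res)
# the naive sum stalls a little above 2048 (the float16 spacing there exceeds
# the values being added); the compensated sum stays close to the reference
print(float(naive), float(comp), vals.astype(np.float64).sum())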
+ *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*) + SetInitValue(initv); + residual = 0; + } +}; + struct nanprod_grad : public mxnet_op::tunable { template MSHADOW_XINLINE static DType Map(DType a, DType b) { diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index d8fc5031e4ff..a937f839c9bb 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -273,20 +273,87 @@ inline int get_num_threads(const int N) { } \ break; \ case mshadow::kUint8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not uint8"; \ + { \ + typedef uint8_t DType; \ + typedef uint8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + typedef int8_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types not int8"; \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + typedef int32_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32"; \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + typedef int64_t AType; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int64"; \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + typedef double AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + typedef float AType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + typedef uint32_t AType; \ + } \ break; \ case mshadow::kInt8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not int8"; \ + { \ + typedef int8_t DType; \ + typedef int32_t AType; \ + } \ break; \ case mshadow::kInt32: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32"; \ + { \ + typedef int32_t DType; \ + typedef int64_t AType; \ + } \ break; \ case mshadow::kInt64: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int64"; \ + { \ + typedef int64_t DType; \ + typedef int64_t AType; \ + } \ break; \ default: \ LOG(FATAL) << "Unknown type enum " << type; \ diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index 5d6c49ff8882..54db35061c6a 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -72,15 +72,15 @@ void BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, } const int nthread_reduce = kMaxThreadsPerBlock; -template +template __launch_bounds__(nthread_reduce) __global__ void reduce_kernel(const int N, const int M, const bool addto, - const DType* __restrict big, DType *small, + const DType* __restrict big, OType *small, const Shape big_shape0, const Shape small_shape, const Shape big_shape, const Shape big_stride, const int Mnext, const bool do_transpose) { extern __shared__ char 
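The new MXNET_ACC_TYPE_SWITCH pairs each input type DType with a wider accumulation type AType (float16 with float32, float32 with float64, int8 with int32, and so on). A small NumPy illustration of why the widened accumulator matters; the dict below just restates the macro's mapping and is itself illustrative:

import numpy as np

# accumulation type chosen by MXNET_ACC_TYPE_SWITCH for each input type
acc_dtype = {
    np.float16: np.float32,
    np.float32: np.float64,
    np.float64: np.float64,
    np.uint8:   np.uint32,
    np.int8:    np.int32,
    np.int32:   np.int64,
    np.int64:   np.int64,
}

x = np.random.uniform(0, 1, 200000).astype(np.float16)
print(x.sum(dtype=np.float16))             # accumulated in float16: overflows to inf
print(x.sum(dtype=acc_dtype[np.float16]))  # accumulated in the widened type
print(x.astype(np.float64).sum())          # float64 reference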
shTileChar[]; - DType* shTile = (DType*)(shTileChar); + AType* shTile = (AType*)(shTileChar); const int tid = threadIdx.x + threadIdx.y*blockDim.x; const int bx = (do_transpose) ? blockDim.y : blockDim.x; const int by = (do_transpose) ? blockDim.x : blockDim.y; @@ -95,7 +95,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, Shape coord = unravel(idx, small_shape); int idx_big0 = ravel(coord, big_shape0); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); if (idx < N) { for (int k = tidy + Mstart; k < Mend; k += by*unroll) { @@ -113,7 +113,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } #pragma unroll for (int u=0;u < unroll;u++) { - if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual); + if (k + u*by < Mend) Reducer::Reduce(val, AType(tmp[u]), residual); } } } @@ -127,7 +127,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, shTile[it0 * 2 + 1] = residual; __syncthreads(); for (int t=1;t < by;t <<= 1) { - DType tmp, tmp_residual; + AType tmp, tmp_residual; Reducer::SetInitValue(tmp, tmp_residual); if (tidy + t < by) { tmp = shTile[(it0 + t*fbx) * 2]; @@ -139,12 +139,12 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, } if (idx < N && tidy == 0) { Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); - assign(&small[idx + m0*N], addto, shTile[tidx * 2]); + assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2])); } } else { if (idx < N) { Reducer::Finalize(val, residual); - assign(&small[idx + m0*N], addto, val); + assign(&small[idx + m0*N], addto, OType(val)); } } } @@ -261,18 +261,18 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto, } } -template +template __global__ void reduce_kernel_M1(const int N, const bool addto, - const DType* __restrict big, DType *small, const Shape bshape, + const DType* __restrict big, OType *small, const Shape bshape, const Shape sshape) { for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { Shape coord = unravel(idx, sshape); int j = ravel(coord, bshape); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); - Reducer::Reduce(val, OP::Map(big[j]), residual); + Reducer::Reduce(val, AType(OP::Map(big[j])), residual); Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); + assign(&small[idx], addto, OType(val)); } } @@ -491,7 +491,7 @@ ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, const mxn if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes - config.workspace_size += config.N*config.Mnext*sizeof(DType); + config.workspace_size += config.N*config.Mnext*sizeof(double); // Set gridDim.y to Mnext config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); } @@ -516,23 +516,22 @@ ReduceImplConfig ConfigureReduceImpl(const mxnet::TShape& small, const mxn {__VA_ARGS__} \ } -template +template void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, const TBlob& big, const Tensor& workspace, const ReduceImplConfig& config) { if (config.M == 1) { - reduce_kernel_M1 + reduce_kernel_M1 <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>( - config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), + config.N, req == kAddTo, big.dptr(), small.dptr(), big.shape_.get(), small.shape_.get()); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1); } else { - - DType* small_dptr = small.dptr(); + OType* small_dptr = 
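On the GPU side the reduction now accumulates in AType: the shared-memory tile, the per-block partials, and the intermediate workspace all use the accumulation type, which is why the workspace is sized with sizeof(double), the widest possible accumulator. A loose NumPy sketch of the two-stage scheme (reduce_kernel writes Mnext partials per output element, reduce_lines_kernel folds them); the chunking below is illustrative, not the kernel's actual thread mapping:

import numpy as np

def two_stage_reduce(big, m_next):
    # big: (N, M) values reduced over the last axis in two stages, loosely
    # mirroring reduce_kernel (stage 1) + reduce_lines_kernel (stage 2)
    n, m = big.shape
    acc = np.float32                               # widened accumulator (safe_acc)
    workspace = np.zeros((m_next, n), dtype=acc)   # N*Mnext partial sums
    for chunk in range(m_next):                    # each chunk stands in for one gridDim.y slice
        workspace[chunk] = big[:, chunk::m_next].astype(acc).sum(axis=1)
    return workspace.sum(axis=0).astype(big.dtype)  # fold partials, cast to the output type

x = np.random.uniform(0, 1, (4, 20000)).astype(np.float16)
print(two_stage_reduce(x, 8))
print(x.astype(np.float64).sum(axis=1))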
small.dptr(); bool addto = (req == kAddTo); if (config.Mnext > 1) { // small_dptr[] is N*Mnext*sizeof(DType) bytes - small_dptr = reinterpret_cast(workspace.dptr_); + small_dptr = reinterpret_cast(workspace.dptr_); addto = false; // Check that the workspace is contigiuous CHECK_EQ(workspace.CheckContiguous(), true); @@ -544,7 +543,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, config.kernel_1.blockDim.x : config.kernel_1.blockDim.y; const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce ); KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig::unroll_reduce, UNROLL, { - reduce_kernel + reduce_kernel <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>( config.N, config.M, addto, big.dptr(), small_dptr, big.shape_.get(), small.shape_.get(), config.rshape, config.rstride, config.Mnext, @@ -553,9 +552,9 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req, MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel); if (config.Mnext > 1) { - reduce_lines_kernel + reduce_lines_kernel <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>> - (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); + (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr()); MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel); } } @@ -610,14 +609,26 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const #undef KERNEL_UNROLL_SWITCH -template +template void Reduce(Stream *s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; cudaStream_t stream = Stream::GetStream(s); ReduceImplConfig config = ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); - ReduceImpl(stream, small, req, big, workspace, config); + if (safe_acc) { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + config = ConfigureReduceImpl(small.shape_, big.shape_, NULL, NULL); + ReduceImpl( + stream, small, req, big, workspace, config); + }); + }); + } else { + ReduceImpl(stream, small, req, big, workspace, config); + } } template diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 0f6913e6e9df..be589c41168b 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -153,21 +153,21 @@ MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto assign(&out[idx], addto, OP::Map(lhs[j], rhs[k])); } -template +template MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto, - const DType* __restrict big, DType *small, + const DType* __restrict big, OType *small, const Shape& bshape, const Shape& sshape, const Shape& rshape, const Shape& rstride) { Shape coord = unravel(idx, sshape); index_t j = ravel(coord, bshape); - DType val, residual; + AType val, residual; Reducer::SetInitValue(val, residual); for (size_t k = 0; k < M; ++k) { coord = unravel(k, rshape); - Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual); + Reducer::Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual); } Reducer::Finalize(val, residual); - assign(&small[idx], addto, val); + assign(&small[idx], addto, OType(val)); } #ifdef __CUDACC__ @@ -194,15 +194,15 @@ void 
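The CPU path mirrors this in seq_reduce_assign: each output element unravels its coordinate in the small shape, ravels it into the big tensor, walks the reduced axes via rshape/rstride, and accumulates in AType before casting to OType on the final store. A hedged NumPy sketch of that index walk; the shapes and strides in the usage are made up for the example:

import numpy as np

def seq_reduce_assign(idx, big, small, bshape, sshape, rshape, rstride, acc=np.float32):
    # mirrors seq_reduce_assign: reduce every element of `big` that maps onto
    # output element `idx` of `small`, accumulating in a wider type
    coord = np.unravel_index(idx, sshape)
    j = np.ravel_multi_index(coord, bshape)        # base offset into `big`
    val = acc(0)
    for k in range(int(np.prod(rshape))):
        rcoord = np.unravel_index(k, rshape)       # coordinate along the reduced axes
        val += acc(big[j + int(np.dot(rcoord, rstride))])
    small[idx] = val                               # cast back to the output type on store

big = np.arange(12, dtype=np.float16)              # a 3x4 tensor, flattened
bshape, sshape = (3, 4), (3, 1)                    # reduce over the last axis
rshape, rstride = (1, 4), (4, 1)                   # shape/strides of the reduced axes
small = np.zeros(3, dtype=np.float16)
for i in range(small.size):
    seq_reduce_assign(i, big, small, bshape, sshape, rshape, rstride)
print(small, big.reshape(3, 4).sum(axis=1))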
BinaryBroadcastComputeImpl(Stream *s, const OpReqType req, out.shape_.get()); } -template +template void seq_reduce_compute(const size_t N, const size_t M, const bool addto, - const DType *big, DType *small, const Shape bshape, + const DType *big, OType *small, const Shape bshape, const Shape sshape, const Shape rshape, const Shape rstride) { #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t idx = 0; idx < static_cast(N); ++idx) { - seq_reduce_assign(idx, M, addto, big, small, bshape, sshape, rshape, - rstride); + seq_reduce_assign(idx, M, addto, big, small, + bshape, sshape, rshape, rstride); } } @@ -227,16 +227,28 @@ void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool add } } -template +template void Reduce(Stream* s, const TBlob& small, const OpReqType req, const Tensor& workspace, const TBlob& big) { if (req == kNullOp) return; Shape rshape, rstride; diff(small.shape_.get(), big.shape_.get(), &rshape, &rstride); size_t N = small.shape_.Size(), M = rshape.Size(); - seq_reduce_compute( - N, M, req == kAddTo, big.dptr(), small.dptr(), - big.shape_.get(), small.shape_.get(), rshape, rstride); + if (!safe_acc) { + seq_reduce_compute( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride); + } else { + MXNET_ACC_TYPE_SWITCH(mshadow::DataType::kFlag, DataType, AType, { + typedef typename std::conditional::type AccType; + MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + typedef typename std::conditional::type OutType; + seq_reduce_compute( + N, M, req == kAddTo, big.dptr(), small.dptr(), + big.shape_.get(), small.shape_.get(), rshape, rstride); + }); + }); + } } template diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index b13906af6624..069c8ddb04fb 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -67,6 +67,7 @@ struct ReduceAxesParam : public dmlc::Parameter { struct NormParam : public dmlc::Parameter { int ord; dmlc::optional axis; + dmlc::optional out_dtype; bool keepdims; DMLC_DECLARE_PARAMETER(NormParam) { DMLC_DECLARE_FIELD(ord).set_default(2) @@ -78,6 +79,15 @@ struct NormParam : public dmlc::Parameter { If `axis` is int, a reduction is performed on a particular axis. 
If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed.)code"); + DMLC_DECLARE_FIELD(out_dtype) + .add_enum("float16", mshadow::kFloat16) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("int64", mshadow::kInt64) + .add_enum("int32", mshadow::kInt32) + .add_enum("int8", mshadow::kInt8) + .set_default(dmlc::optional()) + .describe(R"code(The data type of the output.)code"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axis is left " "in the result as dimension with size one."); @@ -302,6 +312,23 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool NormType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NormParam& param = nnvm::get(attrs.parsed); + if (param.out_dtype.has_value()) { + CHECK_NE(in_attrs->at(0), -1) + << "input data type should be specified when out_dtype is not null"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.out_dtype.value()); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); + } + return (*out_attrs)[0] != -1; +} + inline bool NormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { @@ -525,7 +552,7 @@ void SearchAxisCompute(const nnvm::NodeAttrs& attrs, }); } -template void ReduceAxesComputeImpl(const OpContext& ctx, const std::vector& inputs, @@ -538,20 +565,22 @@ void ReduceAxesComputeImpl(const OpContext& ctx, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const TBlob in_data = inputs[0].reshape(src_shape); - const TBlob out_data = outputs[0].reshape(dst_shape); - BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { - size_t workspace_size = broadcast::ReduceWorkspaceSize( - s, out_data.shape_, req[0], in_data.shape_); - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_size), s); - broadcast::Reduce( - s, out_data, req[0], workspace, in_data); - if (normalize) { - auto out = out_data.FlatTo2D(s); - out /= scalar(src_shape.Size()/dst_shape.Size()); - } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + const TBlob in_data = inputs[0].reshape(src_shape); + const TBlob out_data = outputs[0].reshape(dst_shape); + BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { + size_t workspace_size = broadcast::ReduceWorkspaceSize( + s, out_data.shape_, req[0], in_data.shape_); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + broadcast::Reduce( + s, out_data, req[0], workspace, in_data); + if (normalize) { + auto out = out_data.FlatTo2D(s); + out /= scalar(src_shape.Size()/dst_shape.Size()); + } + }); }); }); } @@ -571,7 +600,7 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs, small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude); } - ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); + ReduceAxesComputeImpl(ctx, inputs, req, outputs, small); } template @@ -813,6 +842,35 @@ void ReduceAxesOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, } } +template +struct reduce_axes_backward_broadcast { + template + MSHADOW_XINLINE static void Map(index_t i, + DType *data, 
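NormType implements the new output-type inference for norm: when out_dtype is given, the output takes that type and the input type must already be known; otherwise the input and output types are tied together as before. In Python terms, roughly:

def norm_infer_type(in_dtype, out_dtype=None):
    # sketch of NormType's rule; dtype names are passed as strings for illustration
    if out_dtype is not None:
        assert in_dtype is not None, \
            "input data type should be specified when out_dtype is not null"
        return out_dtype
    return in_dtype

print(norm_infer_type('float16', 'float32'))   # -> 'float32'
print(norm_infer_type('float16'))              # -> 'float16'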
+ OType *out, + DType *igrad, + OType *ograd, + mshadow::Shape<5> in_shape, + mshadow::Shape<5> out_shape, + const uint32_t ndim) { + size_t in_stride = 1; + size_t out_stride = 1; + index_t idx = i; + index_t out_idx = i; + for (int iter = ndim - 1; iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + out_idx -= dim_idx * in_stride; + if (out_shape[iter] != 1) { + out_idx += dim_idx * out_stride; + } + idx /= in_shape[iter]; + in_stride *= in_shape[iter]; + out_stride *= out_shape[iter]; + } + KERNEL_ASSIGN(igrad[i], req, DType(ograd[out_idx]) * OP::Map(data[i], DType(out[out_idx]))); + } +}; + template void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, const mxnet::TShape &small, @@ -821,37 +879,58 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (dst_shape.ndim() == 2) { - Tensor igrad = - outputs[0].get_with_shape(src_shape.get<2>(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get<2>(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get<2>(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get<2>(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data, broadcast_to(out, src_shape))); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); - } else { - const int ndim = MXNET_SPECIAL_MAX_NDIM; - Tensor igrad = - outputs[0].get_with_shape(src_shape.get(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data, broadcast_to(out, src_shape))); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); - } + + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Shape<5> in_shape; + mshadow::Shape<5> out_shape; + for (uint32_t i = 0; i < 5; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; + } else { + in_shape[i] = 1; + out_shape[i] = 1; + } + } + if (dst_shape.ndim() == 2) { + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + if (normalize) igrad /= 
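The backward kernels replace the old broadcast_to expressions with explicit index arithmetic: for each element i of the input-sized gradient, the loop over axes recovers the flat index of the corresponding element of the reduced (keepdims) tensor, so ograd and out are gathered without materializing a broadcast copy. A NumPy sketch of just that index mapping, checked against numpy broadcasting (function name is illustrative):

import numpy as np

def broadcast_out_index(i, in_shape, out_shape):
    # mirrors the loop in reduce_axes_backward_broadcast: map a flat index of the
    # full tensor to the flat index of the reduced (keepdims) tensor
    in_stride = out_stride = 1
    idx, out_idx = i, i
    for ax in range(len(in_shape) - 1, -1, -1):
        dim_idx = idx % in_shape[ax]
        out_idx -= dim_idx * in_stride
        if out_shape[ax] != 1:
            out_idx += dim_idx * out_stride
        idx //= in_shape[ax]
        in_stride *= in_shape[ax]
        out_stride *= out_shape[ax]
    return out_idx

in_shape, out_shape = (2, 3, 4), (2, 1, 4)     # a reduction over axis 1 with keepdims
out = np.arange(int(np.prod(out_shape))).reshape(out_shape)
gathered = np.array([out.flat[broadcast_out_index(i, in_shape, out_shape)]
                     for i in range(int(np.prod(in_shape)))]).reshape(in_shape)
print(np.array_equal(gathered, np.broadcast_to(out, in_shape)))   # True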
scalar(src_shape.Size()/dst_shape.Size()); + } + }); }); } @@ -1090,14 +1169,42 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false); } if (param.ord == 1) { - ReduceAxesComputeImpl( - ctx, inputs, req, outputs, small); + ReduceAxesComputeImpl( + ctx, inputs, req, outputs, small); } else if (param.ord == 2) { - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, inputs, req, outputs, small); } } +template +struct norm_backward_broadcast { + template + MSHADOW_XINLINE static void Map(index_t i, + DType *igrad, + OType *ograd, + DType *data, + mshadow::Shape<5> in_shape, + mshadow::Shape<5> out_shape, + const uint32_t ndim) { + size_t in_stride = 1; + size_t out_stride = 1; + index_t idx = i; + index_t out_idx = i; + for (int iter = ndim - 1; iter >= 0; --iter) { + size_t dim_idx = idx % in_shape[iter]; + out_idx -= dim_idx * in_stride; + if (out_shape[iter] != 1) { + out_idx += dim_idx * out_stride; + } + idx /= in_shape[iter]; + in_stride *= in_shape[iter]; + out_stride *= out_shape[iter]; + } + KERNEL_ASSIGN(igrad[i], req, ograd[out_idx] * mshadow_op::sign::Map(data[i])); + } +}; + template void LpNormGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1106,6 +1213,7 @@ void LpNormGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; if (req[0] == kNullOp) return; const NormParam& param = nnvm::get(attrs.parsed); @@ -1119,27 +1227,46 @@ void LpNormGradCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape); Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (dst_shape.ndim() == 2) { - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get<2>(), s); - Tensor igrad = - outputs[0].get_with_shape(src_shape.get<2>(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get<2>(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data)); + mshadow::Shape<5> in_shape; + mshadow::Shape<5> out_shape; + for (uint32_t i = 0; i < 5; ++i) { + if (i < dst_shape.ndim()) { + in_shape[i] = src_shape[i]; + out_shape[i] = dst_shape[i]; } else { - const int ndim = MXNET_SPECIAL_MAX_NDIM; - Tensor igrad = - outputs[0].get_with_shape(src_shape.get(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get(), s); - ASSIGN_DISPATCH(igrad, req[0], - broadcast_to(ograd, src_shape)*F(data)); + in_shape[i] = 1; + out_shape[i] = 1; } + } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, OType, { + if (dst_shape.ndim() == 2) { + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } else { + const int ndim = MXNET_SPECIAL_MAX_NDIM; + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + MXNET_REQ_TYPE_SWITCH(req[0], Req, { + Kernel, xpu>::Launch( + s, igrad.shape_.Size(), igrad.dptr_, ograd.dptr_, 
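For reference, the forward and backward formulas these kernels implement, written out in NumPy: ord=1 uses the sum of absolute values with sign(x) as the gradient, ord=2 uses the nrm2 reducer with x/||x|| as the gradient. This is only a reference sketch, not the operator's code path:

import numpy as np

def lp_norm_forward(x, ord, axis):
    if ord == 1:
        return np.abs(x).sum(axis=axis, keepdims=True)
    return np.sqrt((x * x).sum(axis=axis, keepdims=True))

def lp_norm_backward(x, ograd, out, ord):
    # ord=1: d||x||_1/dx = sign(x); ord=2: d||x||_2/dx = x / ||x||_2
    if ord == 1:
        return ograd * np.sign(x)
    return ograd * x / out

x = np.random.uniform(-1, 1, (3, 4))
out = lp_norm_forward(x, 2, 1)
print(lp_norm_backward(x, np.ones_like(out), out, 2))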
data.dptr_, + in_shape, out_shape, src_shape.ndim()); + }); + } + }); }); } else if (param.ord == 2) { ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc index 52fd61aa110e..f4231917e90d 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cc +++ b/src/operator/tensor/broadcast_reduce_op_value.cc @@ -352,7 +352,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", NormShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferType", NormType) .set_attr("FInferStorageType", LpNormStorageType) .set_attr("FGradient", ReduceGrad{ "_backward_norm" }) .set_attr("FResourceRequest", diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index fa108158b5c9..ba62d0e9def7 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -1732,7 +1732,7 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); } @@ -1914,7 +1914,7 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, newInputs, req, newOutputs, rshapes.first); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ccb351f434da..59d72d4b18b6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3384,41 +3384,46 @@ def l2norm(input_data, axis=0, keepdims=True): ctx = default_context() data = mx.symbol.Variable('data') in_data_dim = random_sample([4,5,6], 1)[0] - in_shape = rand_shape_nd(in_data_dim) + in_shape = rand_shape_nd(in_data_dim, dim=5) epsilon = 1e-3 + acc_type = {np.float16: np.float32, np.float32: np.float32, np.float64: np.float64} for order in [1, 2]: for dtype in [np.float16, np.float32, np.float64]: - in_data = np.random.uniform(-1, 1, in_shape).astype(dtype) - in_data[abs(in_data) < epsilon] = 2 * epsilon for i in range(in_data_dim): - norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, keepdims=True) - npy_out = l1norm(in_data, i) if order is 1 else l2norm(in_data, i) - npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out - check_symbolic_forward(norm_sym, [in_data], [npy_out], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], - [npy_out_backward], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509 - # # check gradient - # if dtype is not np.float16: - # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-1, atol=1e-3) - if i < in_data_dim-1: - norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) - npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) + for out_dtype in ['float32', 'float64']: + backward_dtype = np.float32 if out_dtype == 'float32' else np.float64 + print(order, dtype, i, out_dtype, in_shape) + in_data = np.random.uniform(-1, 1, in_shape).astype(acc_type[dtype]) + in_data[abs(in_data) < epsilon] = 2 * epsilon + norm_sym = 
mx.symbol.norm(data=data, ord=order, axis=i, out_dtype=out_dtype, keepdims=True) + npy_out = l1norm(in_data, i) if order is 1 else l2norm(in_data, i) npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out - check_symbolic_forward(norm_sym, [in_data], [npy_out], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], - [npy_out_backward], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # # check gradient - # if dtype is not np.float16: - # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-1, atol=1e-3) + check_symbolic_forward(norm_sym, [in_data.astype(dtype)], [npy_out.astype(out_dtype)], + rtol=1e-3, atol=1e-5, ctx=ctx) + check_symbolic_backward(norm_sym, [in_data.astype(dtype)], + [np.ones(npy_out.shape).astype(out_dtype)], + [npy_out_backward], rtol=1e-3, atol=1e-5, ctx=ctx, + dtype=backward_dtype) + # Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509 + # check gradient + if dtype is not np.float16: + check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, + rtol=1e-1, atol=1e-3, dtype=backward_dtype) + if i < in_data_dim-1: + norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) + npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) + npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out + check_symbolic_forward(norm_sym, [in_data], [npy_out.astype(dtype)], + rtol=1e-3 if dtype is np.float16 else 1e-3, + atol=1e-5 if dtype is np.float16 else 1e-5, ctx=ctx) + check_symbolic_backward(norm_sym, [in_data], + [np.ones(npy_out.shape).astype(out_dtype)], + [npy_out_backward.astype(out_dtype)], + rtol=1e-3, atol=1e-5, ctx=ctx, dtype=backward_dtype) + # check gradient + if dtype is not np.float16: + check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, + rtol=1e-1, atol=1e-3, dtype=backward_dtype) def test_layer_norm():
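End-to-end, the new out_dtype argument lets a float16 input be reduced with safe accumulation and returned in a wider type. A usage sketch, assuming an MXNet build that includes this patch:

import mxnet as mx
import numpy as np

x = np.random.uniform(-1, 1, (64, 128)).astype(np.float16)
# out_dtype is the argument added by this patch: the reduction accumulates in a
# wider type and returns a float32 result for float16 input
y = mx.nd.norm(mx.nd.array(x, dtype=np.float16), ord=2, axis=1,
               out_dtype='float32', keepdims=True)
ref = np.linalg.norm(x.astype(np.float64), axis=1, keepdims=True)
print(np.max(np.abs(y.asnumpy() - ref)))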