
grad use in-out only when dtype override
szha committed Feb 12, 2019
1 parent fd37040 commit bdc0d6d
Showing 2 changed files with 72 additions and 37 deletions.
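Summary of the change: previously every softmax backward op was registered with ElemwiseGradUseInOut, so the gradient node always carried the forward input as well as the forward output. With this patch, that three-input form is used only when the operator's dtype parameter actually overrides the I/O type; otherwise the gradient is built from ograd and the forward output alone via ElemwiseGradUseOut. The predicate driving the choice is softmax_has_dtype_override in the diff below. A standalone sketch of the same decision, using std::optional in place of dmlc::optional so it compiles on its own (illustrative only, not part of the patch):

    #include <iostream>
    #include <optional>

    // Mirrors softmax_has_dtype_override(): an unset dtype, or an explicit -1,
    // means "no override", so the cheaper two-input gradient path is taken.
    static bool has_dtype_override(const std::optional<int>& dtype) {
      return dtype.has_value() && dtype.value() != -1;
    }

    int main() {
      std::cout << has_dtype_override(std::nullopt) << '\n';  // 0 -> ElemwiseGradUseOut  (ograd, output)
      std::cout << has_dtype_override(-1) << '\n';            // 0 -> ElemwiseGradUseOut  (ograd, output)
      std::cout << has_dtype_override(2) << '\n';             // 1 -> ElemwiseGradUseInOut (ograd, data, output)
      return 0;
    }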
82 changes: 63 additions & 19 deletions src/operator/nn/softmax-inl.h
@@ -308,16 +308,16 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };
 
-inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int>* in_attrs,
-                          std::vector<int>* out_attrs) {
+static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
+                                 std::vector<int>* in_attrs,
+                                 std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
 
-  int arg_dtype = param.dtype.has_value()?param.dtype.value():-1,
-      in_dtype = (*in_attrs)[0],
-      out_dtype = (*out_attrs)[0];
+  int arg_dtype = param.dtype.has_value() ? param.dtype.value() : -1;
+  int in_dtype = (*in_attrs)[0];
+  int out_dtype = (*out_attrs)[0];
 
   if (out_dtype != -1 && in_dtype != -1) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
@@ -342,20 +342,61 @@ inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
   }
 }
 
-inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
-                              std::vector<int>* in_attrs,
-                              std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 3);
-  CHECK_EQ(out_attrs->size(), 1);
-
-  int in_dtype = (*in_attrs)[1],
-      out_dtype = (*in_attrs)[2];
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
-
-  return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
-}
+static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  return param.dtype.has_value() && param.dtype.value() != -1;
+}
+
+static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+                                      std::vector<TShape> *in_attrs,
+                                      std::vector<TShape> *out_attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+  } else {
+    return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
+  }
+}
+
+static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    int in_dtype = (*in_attrs)[1];
+    int out_dtype = (*in_attrs)[2];
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  } else {
+    int out_dtype = (*in_attrs)[1];
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  }
+}
+
+static inline std::vector<std::pair<int, int> >
+SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  } else {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+  }
+}
+
+struct SoftmaxFGradient {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    if (softmax_has_dtype_override(n->attrs)) {
+      return ElemwiseGradUseInOut {op_name}(n, ograds);
+    } else {
+      return ElemwiseGradUseOut {op_name}(n, ograds);
+    }
+  }
+};
 
 template<typename xpu, typename OP, bool negate = false>
 void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -401,17 +442,20 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
                              param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[2].type_flag_, OType, AType, {
+
+  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (shape.ndim() == 2) {
           SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+            ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
             inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
             shape.get<2>(), axis, static_cast<DType>(temperature));
         } else {
           SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+            ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
             inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
             shape.get<3>(), axis, static_cast<DType>(temperature));
         }
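Note on the compute-side change above: the index juggling in SoftmaxGradCompute follows directly from the gradient builders. Assuming the usual ElemwiseGradUseInOut / ElemwiseGradUseOut input ordering, the backward op sees {ograd, data, output} when the dtype override is active and {ograd, output} otherwise, so the forward output moves from index 2 to index 1; that is all out_idx selects. A minimal sketch of that mapping (the helper name here is illustrative, not from the patch):

    // Illustrative helper, not part of the patch: where SoftmaxGradCompute
    // finds the forward output in its `inputs` vector.
    //   dtype override set:  inputs = {ograd, data, output} -> index 2
    //   no dtype override:   inputs = {ograd, output}       -> index 1
    inline int SoftmaxGradOutputIndex(bool has_dtype_override) {
      return has_dtype_override ? 2 : 1;
    }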
27 changes: 9 additions & 18 deletions src/operator/nn/softmax.cc
@@ -102,7 +102,7 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -121,12 +121,9 @@ NNVM_REGISTER_OP(_backward_softmax)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")
@@ -165,7 +162,7 @@ Example::
     return std::vector<std::string>{"output"};
   })
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmin"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmin"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -184,12 +181,9 @@ NNVM_REGISTER_OP(_backward_softmin)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")
@@ -216,7 +210,7 @@ Examples::
 )code")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_log_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -235,12 +229,9 @@ NNVM_REGISTER_OP(_backward_log_softmax)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")
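One registration detail worth calling out: the inline FInplaceOption lambdas above are replaced by SoftmaxGradOpInplaceOption, which returns one {input, output} pair per backward input. Under the usual NNVM reading of FInplaceOption (input i may reuse the buffer of output o), any of the three inputs may share storage with the single gradient output when the dtype override is active, and only ograd and the forward output may when it is not. A small sketch of the two returned sets (pair values copied from the diff; the surrounding function is illustrative):

    #include <utility>
    #include <vector>

    // The pairs returned by SoftmaxGradOpInplaceOption, spelled out.
    // Each {in, out} pair marks backward input `in` as allowed to reuse
    // the buffer of output `out` (the single dx output).
    std::vector<std::pair<int, int>> InplacePairs(bool has_dtype_override) {
      if (has_dtype_override)
        return {{0, 0}, {1, 0}, {2, 0}};  // ograd, data, output may all alias dx
      return {{0, 0}, {1, 0}};            // only ograd and output may alias dx
    }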
