Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
simplify infer type
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 13, 2019
1 parent 704b0e7 commit 811a37a
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,45 +309,27 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
}
};

static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
return param.dtype.has_value() && param.dtype.value() != -1;
}

static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);

int arg_dtype = param.dtype.has_value() ? param.dtype.value() : -1;
int in_dtype = (*in_attrs)[0];
int out_dtype = (*out_attrs)[0];

if (out_dtype != -1 && in_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
return true;
} else if (in_dtype != -1) {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
}
return true;
} else if (out_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
if (softmax_has_dtype_override(attrs)) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
return true;
} else {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
}
return false;
return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
}
}

static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
return param.dtype.has_value() && param.dtype.value() != -1;
}

static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
Expand Down

0 comments on commit 811a37a

Please sign in to comment.