Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix Cached_op with static_shape=true (#15298)
Browse files Browse the repository at this point in the history
* Fix

* run ci
  • Loading branch information
ZhennanQin authored and pengzhao-intel committed Jun 27, 2019
1 parent ba30644 commit 582489c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 30 deletions.
7 changes: 5 additions & 2 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct CachedOp::CachedOpState {

std::vector<NDArray> buff;
std::vector<NDArray*> arrays;
std::vector<NDArray*> arrays_with_in_out;
std::vector<OpReqType> array_reqs;

std::vector<OpStatePtr> op_states;
Expand Down Expand Up @@ -762,7 +763,8 @@ OpStatePtr CachedOp::StaticForward(
// We are going to add input and output arrays to the array list.
// The input and output arrays should only be valid for this run,
// so we shouldn't modify the state's array list.
auto arrays = state.arrays;
state.arrays_with_in_out = state.arrays;
auto& arrays = state.arrays_with_in_out;
if (config_.static_shape) {
for (auto i : config_.param_indices) {
auto nid = idx.input_nodes()[i];
Expand Down Expand Up @@ -1063,7 +1065,8 @@ void CachedOp::StaticBackward(
// We are going to add input and output arrays to the array list.
// The input and output arrays should only be valid for this run,
// so we shouldn't modify the state's array list.
auto arrays = state.arrays;
state.arrays_with_in_out = state.arrays;
auto& arrays = state.arrays_with_in_out;
for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
auto eid = state.info.bwd_input_eid[i];
if (eid == kEidNotExist) {
Expand Down
47 changes: 19 additions & 28 deletions src/nnvm/legacy_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class OperatorState {
public:
OperatorState(Operator *opr, const OperatorProperty *prop) {
opr_ = opr;
fwd_init_ = bwd_init_ = false;

in_data_fwd_.resize(prop->ListArguments().size());
in_data_bwd_.resize(prop->ListArguments().size());
Expand Down Expand Up @@ -110,47 +109,39 @@ class OperatorState {
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (!fwd_init_) {
CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
CHECK_EQ(outputs.size(), out_data_.size());
// in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
// referred by arg_data_ptr_ will be overriden
for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
for (size_t i = 0; i < aux_data_.size(); ++i) {
aux_data_[i] = inputs[i + in_data_fwd_.size()];
}
for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
fwd_init_ = true;
CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
CHECK_EQ(outputs.size(), out_data_.size());
// in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
// referred by arg_data_ptr_ will be overriden
for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
for (size_t i = 0; i < aux_data_.size(); ++i) {
aux_data_[i] = inputs[i + in_data_fwd_.size()];
}
for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
}

void Backward(const OpContext &ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (!bwd_init_) {
CHECK(fwd_init_);
CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
// override tblobs pointed by arg_data_ptr_ since they might not contain
// initialized data during forward pass.
for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
*arg_data_ptr_[i] = inputs[i];
}
for (size_t i = 0; i < aux_data_.size(); ++i) {
aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
}
CHECK_EQ(outputs.size(), in_grad_.size());
for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
bwd_init_ = true;
CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
// override tblobs pointed by arg_data_ptr_ since they might not contain
// initialized data during forward pass.
for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
*arg_data_ptr_[i] = inputs[i];
}
for (size_t i = 0; i < aux_data_.size(); ++i) {
aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
}
CHECK_EQ(outputs.size(), in_grad_.size());
for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
}

private:
Operator *opr_;
bool fwd_init_, bwd_init_;
// input data blobs for forward and backward
// in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
// performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is
Expand Down

0 comments on commit 582489c

Please sign in to comment.