From 3b616c2fd20ab6d4e7b76c698480aae592df6b1e Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Thu, 5 Jul 2018 00:54:35 +0000
Subject: [PATCH] extend _CachedOp to a regular operator.

---
 src/imperative/cached_op.cc            | 342 ++++++++++++++++++++++++-
 src/imperative/cached_op.h             |  31 +++
 src/operator/operator_common.h         |  12 +-
 tests/python/unittest/test_subgraph.py | 149 +++++++++++
 4 files changed, 518 insertions(+), 16 deletions(-)
 create mode 100644 tests/python/unittest/test_subgraph.py

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 0c4c1e60208f..defe9df22c7d 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -874,7 +874,6 @@ OpStatePtr CachedOp::Forward(
   return op_state;
 }
 
-
 void CachedOp::DynamicBackward(
     const bool retain_graph,
     const OpStatePtr& op_state,
@@ -1067,6 +1066,130 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+struct CachedOpActualState {
+  std::shared_ptr<CachedOp> op;
+  OpStatePtr forward_state;
+
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+    this->op = op;
+  }
+};
+
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs(in_bufs.size());
+  std::vector<NDArray *> out_ptrs(out_bufs.size());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  // Set is_recording correct for the imperative executor.
+  bool orig_is_record;
+  if (ctx.need_grad)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+  Imperative::Get()->set_is_recording(orig_is_record);
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If they are, we need to copy data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  size_t bwd_in_dep = s.op->num_inputs();
+  size_t bwd_out_dep = s.op->num_outputs();
+  CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
+  size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+  // Find inputs, outputs and ograds
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + bwd_in_dep;
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
+    auto i = it - in_begin;
+    if (save_inputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  for (auto it = out_begin; it != out_end; it++) {
+    auto i = it - out_begin;
+    if (save_outputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    out_ptrs.push_back(&out_bufs[i]);
+  CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  // TODO(zhengda) is it right to use false here?
+  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+
+  // Clean up what we recorded.
+  s.forward_state.reset();
+
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If they are, we need to copy data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
+                               Context ctx,
+                               const std::vector<TShape>& in_shapes,
+                               const std::vector<int>& in_types) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return OpStatePtr::Create<CachedOpActualState>(op);
+}
+
 bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
                                   const int dev_mask,
                                   DispatchMode* dispatch_mode,
@@ -1143,6 +1266,155 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool CachedOp::ForwardInferShape(const nnvm::NodeAttrs& attrs,
+                                 std::vector<TShape> *in_shapes,
+                                 std::vector<TShape> *out_shapes) {
+  using namespace exec;
+  nnvm::Graph g(fwd_graph_);
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+  // TODO(zhengda) we can cache the shape vector.
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx_g.num_node_entries());
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shapes->size());
+  for (size_t i = 0; i < in_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shapes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shapes->size());
+  for (size_t i = 0; i < out_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    shapes[eid] = out_shapes->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  // Copy the inferred shape back to the input shapes and the output shapes.
+  shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+  // assign to in_shapes
+  for (size_t i = 0; i < in_shapes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+  }
+  // assign to out_shapes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+  }
+  // Check if we have inferred the shapes correctly.
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+bool CachedOp::ForwardInferType(const nnvm::NodeAttrs& attrs,
+                                std::vector<int> *in_types,
+                                std::vector<int> *out_types) {
+  nnvm::Graph g(fwd_graph_);
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // TODO(zhengda) we can cache the dtype vector.
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_types->size());
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_types->size());
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  types = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+std::vector<uint32_t> CachedOp::MutableInputs() const {
+  nnvm::Symbol sym = GetForwardSym();
+  const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll);
+  const std::vector<std::string> immutable_input_names =
+      sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> mutable_input_names =
+      sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
+  std::vector<uint32_t> ret;
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
+      ++i1;
+    } else {
+      CHECK(i2 < mutable_input_names.size());
+      CHECK_EQ(input_names[i], mutable_input_names[i2]);
+      ++i2;
+      ret.push_back(i);
+    }
+  }
+  return ret;
+}
+
+std::vector<ResourceRequest> CachedOp::GetResourceRequest() const {
+  nnvm::Symbol sym = GetForwardSym();
+  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+  std::set<ResourceRequest::Type> resource_types;
+  DFSVisit(sym.outputs, [&](const nnvm::NodePtr& node) {
+    if (!node->is_variable() && fresource.count(node->op())) {
+      for (ResourceRequest& r : fresource[node->op()](node->attrs)) {
+        resource_types.insert(r.type);
+      }
+    }
+  });
+  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
+}
+
+void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
+  CachedOpConfig param;
+  try {
+    param.Init(attrs->dict);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+  if (!param.subgraph.empty()) {
+    nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
+    CHECK(!g.outputs.empty());
+    nnvm::Symbol sym;
+    sym.outputs = g.outputs;
+    std::vector<std::pair<std::string, std::string> > flags;
+    for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
+      flags.emplace_back(it->first, it->second);
+    attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
+  }
+}
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1425,63 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
-                                                     const int dev_mask,
-                                                     DispatchMode* dispatch_mode,
-                                                     std::vector<int> *in_attrs,
-                                                     std::vector<int> *out_attrs) {
-    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
-    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
-  })
+.set_attr_parser(CachedOpParamParser)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
     return op->Gradient(n, ograds);
-  });
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardInputNames();
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardOutputNames();
+  })
+.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_shapes,
+     std::vector<TShape> *out_shapes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardInferShape(attrs, in_shapes, out_shapes);
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_types,
+     std::vector<int> *out_types) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardInferType(attrs, in_types, out_types);
+  })
+.set_attr<FInferStorageType>("FInferStorageType",
+  [](const nnvm::NodeAttrs& attrs,
+     const int dev_mask,
+     DispatchMode* dispatch_mode,
+     std::vector<int>* in_stypes,
+     std::vector<int>* out_stypes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes);
+  })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->MutableInputs();
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->GetResourceRequest();
+  })
+.set_attr<FExecType>("FExecType",
+  [](const nnvm::NodeAttrs& attrs) {
+    return ExecType::kSubgraphExec;
+  })
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
 
 NNVM_REGISTER_OP(_backward_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1500,12 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
   })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
+.set_attr<FExecType>("FExecType",
+  [](const nnvm::NodeAttrs& attrs) {
+    return ExecType::kSubgraphExec;
+  })
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true);

diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 4f4dfdcc14dd..138e0a38a017 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   bool static_shape;
   nnvm::Tuple<uint32_t> data_indices;
   nnvm::Tuple<uint32_t> param_indices;
+  std::string subgraph;
   DMLC_DECLARE_PARAMETER(CachedOpConfig) {
     DMLC_DECLARE_FIELD(static_alloc)
     .set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     DMLC_DECLARE_FIELD(param_indices)
     .set_default(nnvm::Tuple<uint32_t>())
     .describe("Position of parameters.");
+    DMLC_DECLARE_FIELD(subgraph)
+    .set_default(std::string(""))
+    .describe("JSON string of a subgraph.");
   }
 };
 
@@ -80,6 +84,10 @@ class CachedOp {
   uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
+  uint32_t num_backward_outputs() const {
+    auto &idx = fwd_graph_.indexed_graph();
+    return idx.input_nodes().size() - idx.mutable_input_nodes().size();
+  }
   std::vector<bool>& save_inputs() {
     return save_inputs_;
   }
@@ -116,6 +124,24 @@ class CachedOp {
       DispatchMode* dispatch_mode,
       std::vector<int> *in_attrs,
       std::vector<int> *out_attrs);
+  bool ForwardInferShape(
+      const nnvm::NodeAttrs& attrs,
+      std::vector<TShape> *in_shapes,
+      std::vector<TShape> *out_shapes);
+  bool ForwardInferType(
+      const nnvm::NodeAttrs& attrs,
+      std::vector<int> *in_types,
+      std::vector<int> *out_types);
+  std::vector<std::string> ListForwardInputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListInputNames(nnvm::Symbol::kAll);
+  }
+  std::vector<std::string> ListForwardOutputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListOutputNames();
+  }
+  std::vector<uint32_t> MutableInputs() const;
+  std::vector<ResourceRequest> GetResourceRequest() const;
 
  private:
   struct GraphInfo;
@@ -167,6 +193,11 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& outputs);
+  nnvm::Symbol GetForwardSym() const {
+    nnvm::Symbol sym;
+    sym.outputs = fwd_graph_.outputs;
+    return sym;
+  }
 
   CachedOpConfig config_;
   nnvm::Graph fwd_graph_;

diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 29112939a22f..6a4c3d027075 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                       \
   {                                                                         \
-    if (!shape_assign(&(shape_array)[index], TShape(shape))) {              \
+    if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \
       std::ostringstream os;                                                \
       os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\
          << " inferred shape=" << shape;                                    \
@@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define TYPE_ASSIGN_CHECK(type_array, index, type)                          \
   {                                                                         \
-    if (!type_assign(&(type_array)[index], type)) {                         \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {            \
       std::ostringstream os;                                                \
       os << "Type inconsistent, Provided = "                                \
-         << type_string((type_array)[index]) << ','                         \
-         << " inferred type = " << type_string(type);                       \
+         << ::mxnet::op::type_string((type_array)[index]) << ','            \
+         << " inferred type = " << ::mxnet::op::type_string(type);          \
       throw ::mxnet::op::InferTypeError(os.str(), index);                   \
     }                                                                       \
   }
@@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
 #define UNIFORM_TYPE_CHECK(type, expected, arg)                            \
   {                                                                        \
     CHECK_EQ(type, expected) << "This layer requires uniform type. "       \
-                             << "Expected '" << type_string(expected)      \
-                             << "' v.s. given '" << type_string(type)      \
+                             << "Expected '" << ::mxnet::op::type_string(expected) \
+                             << "' v.s. given '" << ::mxnet::op::type_string(type) \
                              << "' at '" << arg << "'";                    \
   }
 
diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py
new file mode 100644
index 000000000000..338d3ae781f4
--- /dev/null
+++ b/tests/python/unittest/test_subgraph.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+import copy
+import math
+import ctypes
+import random
+import itertools
+from numpy.testing import assert_allclose, assert_array_equal
+from mxnet.test_utils import *
+from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint
+from common import setup_module, with_seed, teardown
+import unittest
+from mxnet.gluon.model_zoo.vision import get_model
+
+def make_subgraph(subg, *args):
+    js = subg.tojson()
+    return mx.sym._internal._CachedOp(*args, subgraph=js)
+
+@with_seed()
+def test_make_subgraph():
+    def make_subgraph1(stype):
+        a = mx.symbol.Variable(name='a', stype=stype)
+        b = mx.symbol.Variable(name='b', stype=stype)
+        c = a * b
+        d = c * 2
+
+        a1 = mx.symbol.Variable(name='a', stype=stype)
+        b1 = mx.symbol.Variable(name='b', stype=stype)
+        y = make_subgraph(c, a1, b1)
+        y = y * 2
+
+        s = (10, 10)
+        a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+                            ctx=default_context()).tostype(stype)
+        b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+                            ctx=default_context()).tostype(stype)
+        return (d, y, {'a': a_arr, 'b': b_arr}, {})
+
+    def create_weights(shapes, names):
+        nd_dict = {}
+        sym_dict = {}
+        assert len(shapes) == len(names)
+        for i in range(len(shapes)):
+            sym_dict[names[i]] = mx.symbol.Variable(names[i])
+            nd_dict[names[i]] = mx.nd.array(np.ones(shapes[i]), ctx=default_context())
+        return (nd_dict, sym_dict)
+
+    def make_subgraph_weight(orig, shape, stype):
+        arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape)
+        weight_shapes = arg_shapes[1:]
+        weight_names = orig.list_arguments()[1:]
+        weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names)
+        aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states())
+
+        input_dict = copy.deepcopy(weight_sym_dict)
+        input_dict.update(aux_sym_dict)
+        input_dict['data'] = mx.symbol.Variable('data', stype=stype)
+        input_list = []
+        for name in orig.list_inputs():
+            assert name in input_dict.keys()
+            input_list.append(input_dict[name])
+        subg = make_subgraph(orig, *input_list)
+
+        arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype)
+        arg_dict = weight_dict
+        arg_dict['data'] = arr
+        return (orig, subg, arg_dict, aux_dict)
+
+    def make_subgraph2(stype, out_mean_var):
+        data = mx.symbol.Variable('data', stype=stype)
+        orig = mx.symbol.BatchNorm(data, fix_gamma=False,
+                                   output_mean_var=out_mean_var, name="batchnorm")
+        s = (10, 10)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph3(stype):
+        data = mx.symbol.Variable('data', stype=stype)
+        conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False)
+        conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False)
+        orig = bn1 + bn2
+        s = (1, 3, 32, 32)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph4(stype):
+        model = get_model('resnet18_v1')
+        model.hybridize()
+        model.initialize()
+        s = (1, 3, 32, 32)
+        data = mx.nd.random.normal(shape=s)
+        out = model(data)
+        model.export('resnet18')
+        orig = mx.sym.load('resnet18-symbol.json')
+        return make_subgraph_weight(orig, s, stype)
+
+    make_subgraphs = [make_subgraph1,
+                      lambda stype: make_subgraph2(stype, False),
+                      lambda stype: make_subgraph2(stype, True),
+                      make_subgraph3]
+    stypes = ['default', 'row_sparse']
+    for make_subg in make_subgraphs:
+        for stype in stypes:
+            orig, subg, inputs, aux_states = make_subg(stype)
+            all_inputs = copy.deepcopy(inputs)
+            all_inputs.update(aux_states)
+            args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+            e1 = orig.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+                           aux_states=all_inputs)
+            args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+            e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+                           aux_states=all_inputs)
+            e1.forward()
+            e2.forward()
+            for i in range(len(e1.outputs)):
+                assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(),
+                                    rtol=0.001, atol=0.0001)
+
+            out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context())
+                         for out in e1.outputs]
+            e1.backward(out_grads)
+            e2.backward(out_grads)
+            for i in range(len(e1.grad_arrays)):
+                assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy(),
+                                    rtol=0.001, atol=0.0001)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
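
--
Usage sketch (illustrative only, not part of the patch): with this change a
serialized graph can be wrapped into a single _CachedOp node from Python via
the new `subgraph` parameter, exactly as make_subgraph does in the test above.
The variable names and shapes below are made up for the example:

    import mxnet as mx

    a = mx.symbol.Variable('a')
    b = mx.symbol.Variable('b')
    c = a * b + 2

    # Wrap the graph of c into one _CachedOp node. CachedOpParamParser
    # deserializes the JSON and matches the positional inputs against the
    # subgraph's inputs by order.
    y = mx.sym._internal._CachedOp(mx.symbol.Variable('a'),
                                   mx.symbol.Variable('b'),
                                   subgraph=c.tojson())

    # The composite node binds and executes like any other operator; shape,
    # dtype and storage-type inference are delegated to the wrapped CachedOp.
    exe = y.bind(ctx=mx.cpu(), args={'a': mx.nd.ones((2, 2)),
                                     'b': mx.nd.ones((2, 2)) * 3})
    print(exe.forward()[0].asnumpy())  # each element should be 1 * 3 + 2 = 5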