Commit

Updated tests, fixed bugs
Jose Luis Contreras committed Feb 14, 2019
1 parent 7e69503 commit 5288e9d
Showing 4 changed files with 39 additions and 37 deletions.
src/operator/tensor/digitize_op.cc (2 changes: 1 addition & 1 deletion)
@@ -84,7 +84,7 @@ bin edges. Within this last dimension, bins must be strictly monotonically incre
 [](const NodeAttrs &attrs) {
   return std::vector<std::string>{ "data", "bins" };
 })
-.set_attr<nnvm::FInferShape>("FInferShape", InferShape)
+.set_attr<nnvm::FInferShape>("FInferShape", DigitizeOpShape)
 .set_attr<nnvm::FInferType>("FInferType", DigitizeOpType)
 .set_attr<FCompute>("FCompute", DigitizeOpForward<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Input data ndarray")
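The registration above wires DigitizeOpShape, DigitizeOpType, and the CPU forward kernel to an operator taking "data" and "bins" inputs. A minimal usage sketch, assuming this branch builds and exposes the operator as mx.nd.digitize, the name the updated tests call (it is not an operator in released MXNet):

import mxnet as mx
import numpy as np

# Hypothetical call into the operator registered above.
x = mx.nd.array(np.array([[0.5, 4.2, 11.0]]))
bins = mx.nd.array(np.array([1.0, 4.0, 10.0]))  # strictly increasing bin edges
out = mx.nd.digitize(x, bins)  # expected [[0, 2, 3]], matching np.digitize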
src/operator/tensor/digitize_op.cu (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ namespace op {
 
 template<>
 struct ForwardKernel<gpu> {
-  template<typename DType, typename BType>
+  template<typename DType, typename OType>
   static MSHADOW_XINLINE void Map(int i,
                                   DType *in_data,
                                   OType *out_data,
src/operator/tensor/digitize_op.h (19 changes: 12 additions & 7 deletions)
@@ -55,7 +55,7 @@ struct DigitizeParam : public dmlc::Parameter<DigitizeParam> {
   }
 };
 
-bool InferShape(const nnvm::NodeAttrs &attrs,
+inline bool DigitizeOpShape(const nnvm::NodeAttrs &attrs,
                 std::vector<TShape> *in_attrs,
                 std::vector<TShape> *out_attrs) {
   using namespace mshadow;
@@ -178,10 +178,12 @@ struct ForwardKernel<cpu> {
 
 template<typename DType>
 struct CheckMonotonic {
-  static MSHADOW_XINLINE void Map(int i, int bins_length, DType *bins) {
+  static MSHADOW_XINLINE void Map(int i, int bins_length, DType *bins, bool *mono) {
     if ((i + 1) % bins_length != 0) {
-      CHECK_LT(bins[i], bins[i + 1]) << "Bins vector is not strictly monotonic and increasing";
-    } // TODO: Make sure the next element in bins is actually bins[i+1]
+      if (bins[i] >= bins[i + 1]) {
+        *mono = false;
+      }
+    }
   }
 }
 
 return true;
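The rewritten kernel no longer calls CHECK_LT from inside Map, likely because CHECK macros cannot run inside device code on GPU: each thread instead records any violation through a bool pointer, and DigitizeOpForward inspects the flag after the launch (next hunk). In NumPy terms the check amounts to the sketch below; check_strictly_increasing is an illustrative helper, not code from this commit.

import numpy as np

def check_strictly_increasing(bins):
    # Per-row equivalent of the CheckMonotonic kernel: every row of bins
    # must be strictly increasing along its last axis.
    b = np.atleast_2d(bins)
    return bool(np.all(b[:, :-1] < b[:, 1:]))

assert check_strictly_increasing(np.array([0.0, 1.0, 2.5]))
assert not check_strictly_increasing(np.ones(10))  # the invalid-bins case tested below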
@@ -224,17 +226,20 @@ void DigitizeOpForward(const nnvm::NodeAttrs &attrs,
   MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
 
     // Verify bins is strictly monotonic
+    bool mono = true;
     auto bins_length = bins.shape_[bins.ndim() - 1];
     mxnet_op::Kernel<CheckMonotonic<DType>, xpu>::Launch(s, bins.Size(), bins_length,
-                                                         bins.dptr<DType>());
+                                                         bins.dptr<DType>(), &mono);
+    CHECK(mono) << "Bins vector is not strictly monotonic and increasing";
 
 
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
       auto batch_size = data.shape_.ProdShape(bins.ndim() - 1, data.ndim());
 
       mxnet_op::Kernel<ForwardKernel<xpu>, xpu>::Launch(s,
-                                                        outputs[ 0 ].Size(),
+                                                        outputs[0].Size(),
                                                         data.dptr<DType>(),
-                                                        outputs[ 0 ].dptr<OType>(),
+                                                        outputs[0].dptr<OType>(),
                                                         bins.dptr<DType>(),
                                                         batch_size,
                                                         bins_length,
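The launch above computes, for every output element, which bin its input value falls into. The contract mirrors the test's reference function: a 1-D bins vector applies to all of data, while a 2-D bins matrix pairs each row of edges with the matching leading slice of data. A NumPy sketch of that behaviour (digitize_forward_reference is an illustrative helper, not code from this commit):

import numpy as np

def digitize_forward_reference(data, bins, right=False):
    # 1-D bins: plain np.digitize over all of data.
    if bins.ndim == 1:
        return np.digitize(data, bins, right=right)
    # 2-D bins: one row of edges per leading batch index of data.
    out = np.empty(data.shape, dtype=np.int64)
    for idx in range(bins.shape[0]):
        out[idx] = np.digitize(data[idx], bins[idx], right=right)
    return out

x = np.array([[0.5, 4.2, 11.0], [2.0, 0.1, 9.9]])
b = np.array([[1.0, 4.0, 10.0], [0.0, 5.0, 8.0]])
print(digitize_forward_reference(x, b))  # [[0 2 3], [1 1 3]]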
tests/python/unittest/test_operator.py (53 changes: 25 additions & 28 deletions)
@@ -7329,59 +7329,56 @@ def test_invalid_max_pooling_pad_type_same():
 @with_seed()
 def test_digitize():
     def f(x, bins, right):
-        a = np.zeros_like(x)
-        if bins.ndim == 1:
-            a = np.digitize(x, bins, right=right)
-        elif bins.ndim == 2:
-            # verify x.shape[0] == bins.shape[0]
-            for idx, batch in enumerate(zip(x, bins)):
-                a[idx] = np.digitize(batch[0], batch[1], right=right)
-
-        return a
+        x = x.asnumpy()
+        bins = bins.asnumpy()
+
+        N = bins.ndim
+        x1 = x.reshape((-1,) + x.shape[N-1:])
+        x2 = np.atleast_2d(x1.reshape(x1.shape[0], -1))
+        b1 = np.atleast_2d(bins.reshape(-1, bins.shape[-1]))
+        out = np.zeros_like(x2)
+
+        for idx, batch in enumerate(zip(x2, b1)):
+            out[idx] = np.digitize(batch[0], batch[1], right=right)
+
+        return out.reshape(x.shape)
 
     data = mx.symbol.Variable('data')
-    dig_sym = mx.sym.tensor.digitize(data=data)
-    x = np.random.randn(100, 200) * 100
+    dig_sym = mx.sym.digitize(data=data)
+    x = mx.nd.array(np.random.randn(100, 200) * 100)
 
     # Test with 1D bin vector
-    bins = np.linspace(0, np.random.randint(90)+10, np.random.randint(20))
+    bins = mx.nd.array(np.linspace(0, np.random.randint(90)+10, np.random.randint(20)))
 
     # right = False (default)
-    output = mx.nd.tensor.digitize([mx.nd.array(x), bins])
+    output = mx.nd.digitize(x, bins)
     expected = f(x, bins, False)
     assert_equal(output.asnumpy(), expected)
-    check_symbolic_forward(dig_sym, [x, bins], [expected])  # TODO: Should the symbolic op be checked in all cases?
+    check_symbolic_forward(dig_sym, [x, bins], [expected])
 
     # right = True
-    output = mx.nd.tensor.digitize([mx.nd.array(x), bins], right=True)
+    output = mx.nd.digitize(x, bins, right=True)
     expected = f(x, bins, True)
     assert_equal(output.asnumpy(), expected)
 
     # Test with 2D bins vector
     x = np.random.randn(2, 200) * 100
     bins = np.vstack(np.linspace(np.linspace(0, np.random.randint(90)+10, np.random.randint(20)),
                                  np.linspace(0, np.random.randint(90) + 10, np.random.randint(20))))
-    output = mx.nd.tensor.digitize([mx.nd.array(x), bins])
+    output = mx.nd.digitize(x, bins)
     assert_equal(output.asnumpy(), f(x, bins, False))
 
     # Test exception raising
     def test_invalid_inputs():
-        # For all the assert_exception below:
-        # TODO: Check exception type
-        # TODO 2: Verify the arguments are right
-
         # bins not monotonic and ascending
-        bins = np.ones(10)
-        assert_exception(mx.nd.tensor.digitize, MXNetError, [x, bins])
+        bins = mx.nd.ones(10)
+        assert_exception(mx.nd.digitize, MXNetError, [x, bins])
 
-        # bins.ndim > 2
-        bins = np.tile(np.linspace(0, 10, 10), (4, 1))
-        assert_exception(mx.nd.tensor.digitize, MXNetError, [x, bins])
-
-        # bins.ndim > data.ndim (only applies when bins.ndim=2, data.ndim=1)
-        bins = np.tile(np.linspace(0, 10, 10), (2, 1))
-        x = np.random.randn(100) * 100
-        assert_exception(mx.nd.tensor.digitize, MXNetError, [x, bins])
+        # bins.ndim > data.ndim
+        bins = mx.nd.array(np.tile(np.linspace(0, 10, 10), (2, 1)))
+        x = mx.nd.array(np.random.randn(100) * 100)
+        assert_exception(mx.nd.digitize, MXNetError, [x, bins])
 
     test_invalid_inputs()
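For reference, the semantics the test encodes come straight from NumPy: np.digitize(v, bins) returns, for each value, the index i such that bins[i-1] <= v < bins[i], and right=True makes the upper edge inclusive instead.

import numpy as np

bins = np.array([1.0, 4.0, 10.0])
vals = np.array([0.5, 4.0, 11.0])
print(np.digitize(vals, bins))              # [0 2 3]
print(np.digitize(vals, bins, right=True))  # [0 1 3]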
