Add GPU version of boolean_mask op
HyperZealot committed Feb 8, 2019
1 parent 26ca37c commit ba9ec22
Showing 4 changed files with 275 additions and 70 deletions.
68 changes: 2 additions & 66 deletions src/operator/contrib/boolean_mask-inl.h
@@ -55,78 +55,14 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
// TODO(@junrushao1994): This implementation is a proof-of-concept and
// hence very slow. Performance should be improved in the future.
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
const int axis = param.axis;
const NDArray &data = inputs[0];
const NDArray &idx = inputs[1];
const NDArray &out = outputs[0];
CHECK_EQ(axis, 0) << "Not supported yet";
CHECK_EQ(data.shape()[axis], idx.shape()[0]);
CHECK_EQ(idx.shape().ndim(), 1U);
// count the number of 1s in `idx`, so that we know the output dimension
size_t valid_num = 0;
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
for (int i = 0; i < length; i++) {
if (idx_dptr[i]) {
++valid_num;
}
}
});
// set the output shape forcefully
TShape s = data.shape();
s[axis] = valid_num;
const_cast<NDArray &>(out).Init(s);
// do the copy
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
for (int i = 0, j = 0; i < length; ++i) {
if (idx_dptr[i]) {
NDArray src = data.At(i);
NDArray dst = out.At(j++);
CHECK(src.shape() == dst.shape());
mxnet_op::copy(stream, dst.data(), src.data());
}
}
});
}
const std::vector<NDArray> &outputs);

template<typename xpu>
inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);
// inputs: {ograd, data, idx}
// outputs: {igrad_data, igrad_idx}
const NDArray& ograd = inputs[0];
const NDArray& idx = inputs[2];
const NDArray& igrad_data = outputs[0];
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
Fill<false>(stream, igrad_data.data(), req[0], 0);
for (int i = 0, j = 0; i < length; ++i) {
if (idx_dptr[i]) {
NDArray src = ograd.At(j++);
NDArray dst = igrad_data.At(i);
CHECK(src.shape() == dst.shape());
mxnet_op::copy(stream, dst.data(), src.data());
}
}
});
}
const std::vector<NDArray> &outputs);

} // namespace op
} // namespace mxnet
84 changes: 82 additions & 2 deletions src/operator/contrib/boolean_mask.cc
@@ -28,7 +28,6 @@ namespace op {

DMLC_REGISTER_PARAMETER(BooleanMaskParam);


bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
@@ -75,9 +74,86 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

template<>
inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
// TODO(@junrushao1994): This implementation is a proof-of-concept and
// hence very slow. Performance should be improved in the future.
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
const int axis = param.axis;
const NDArray &data = inputs[0];
const NDArray &idx = inputs[1];
const NDArray &out = outputs[0];
CHECK_EQ(axis, 0) << "Not supported yet";
CHECK_EQ(data.shape()[axis], idx.shape()[0]);
CHECK_EQ(idx.shape().ndim(), 1U);
// count the number of 1s in `idx`, so that we know the output dimension
size_t valid_num = 0;
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
for (int i = 0; i < length; i++) {
if (idx_dptr[i]) {
++valid_num;
}
}
});
// set the output shape forcefully
TShape s = data.shape();
s[axis] = valid_num;
const_cast<NDArray &>(out).Init(s);
// do the copy
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
for (int i = 0, j = 0; i < length; ++i) {
if (idx_dptr[i]) {
NDArray src = data.At(i);
NDArray dst = out.At(j++);
CHECK(src.shape() == dst.shape());
mxnet_op::copy(stream, dst.data(), src.data());
}
}
});
}
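
For reference, a minimal NumPy sketch of the same two-pass logic (count the kept rows, then pack them); this is an illustration only, not part of the commit:

import numpy as np

def boolean_mask_ref(data, index):
    # Pass 1: count the non-zeros to fix the output shape.
    valid_num = int(np.count_nonzero(index))
    out = np.empty((valid_num,) + data.shape[1:], dtype=data.dtype)
    # Pass 2: copy each kept row to its packed position.
    j = 0
    for i in range(index.shape[0]):
        if index[i]:
            out[j] = data[i]
            j += 1
    return out

# boolean_mask_ref(np.arange(1, 10).reshape(3, 3), np.array([0, 1, 0]))
# -> array([[4, 5, 6]])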

template<>
inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);
// inputs: {ograd, data, idx}
// outputs: {igrad_data, igrad_idx}
const NDArray& ograd = inputs[0];
const NDArray& idx = inputs[2];
const NDArray& igrad_data = outputs[0];
MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
DType* idx_dptr = idx.data().dptr<DType>();
int length = idx.shape()[0];
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
Fill<false>(stream, igrad_data.data(), req[0], 0);
for (int i = 0, j = 0; i < length; ++i) {
if (idx_dptr[i]) {
NDArray src = ograd.At(j++);
NDArray dst = igrad_data.At(i);
CHECK(src.shape() == dst.shape());
mxnet_op::copy(stream, dst.data(), src.data());
}
}
});
}
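
The backward pass is the mirror image: kept rows receive the matching row of ograd, masked-out rows get zero. A sketch of the same logic in NumPy, again for illustration only:

import numpy as np

def boolean_mask_grad_ref(ograd, index, data_shape):
    # Masked-out rows keep a zero gradient (the Fill<false> above).
    igrad = np.zeros(data_shape, dtype=ograd.dtype)
    j = 0
    for i in range(index.shape[0]):
        if index[i]:
            igrad[i] = ograd[j]
            j += 1
    return igrad

# boolean_mask_grad_ref(np.ones((1, 3)), np.array([0, 1, 0]), (3, 3))
# -> row 1 is all ones; rows 0 and 2 stay zero.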

NNVM_REGISTER_OP(_contrib_boolean_mask)
.describe(R"code(
Experimental support for boolean masking.
Given an n-d NDArray data and a 1-d NDArray index,
the operator produces an n-d NDArray out whose shape is not predeterminable,
holding the rows of data where the corresponding element in index is non-zero.
@@ -94,6 +170,10 @@
.set_attr_parser(ParamParser<BooleanMaskParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "index"};
})
.set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
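
For context, a short front-end sketch of the operator (same data as the unit test below; this assumes the op is exposed under mx.nd.contrib, as contrib operators normally are):

import mxnet as mx

data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = mx.nd.array([0, 1, 0])
out = mx.nd.contrib.boolean_mask(data, index)  # keeps rows where index != 0
print(out.asnumpy())  # [[4. 5. 6.]]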
191 changes: 191 additions & 0 deletions src/operator/contrib/boolean_mask.cu
@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file boolean_mask.cu
*/

#include "./boolean_mask-inl.h"
#include <cub/cub.cuh>

namespace mxnet {
namespace op {

struct BooleanMaskForwardKernel {
template<typename DType>
static void MSHADOW_XINLINE Map(int i,
DType* out,
const DType* data,
const int32_t* idx,
const size_t col_size) {
int row_id = i / col_size;
int col_id = i % col_size;
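// `idx` holds the inclusive prefix sum of the boolean mask, so prev != curr
// exactly when row `row_id` is kept, and `prev` is its packed output row.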
int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
int32_t curr = idx[row_id];
if (prev != curr) {
out[prev * col_size + col_id] = data[i];
}
}
};

struct BooleanMaskBackwardKernel {
template<typename DType>
static void MSHADOW_XINLINE Map(int i,
DType* igrad,
const DType* ograd,
const int32_t* idx,
const size_t col_size) {
int row_id = i / col_size;
int col_id = i % col_size;
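// `idx` is again the inclusive prefix sum of the mask: kept input row
// `row_id` reads its gradient from packed row `prev` of `ograd`.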
int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
int32_t curr = idx[row_id];
if (prev != curr) {
igrad[i] = ograd[prev * col_size + col_id];
}
}
};
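
Both kernels rely on the same prefix-sum trick: after an inclusive scan of the mask, a row is kept exactly where the running sum changes, and the previous value gives its packed position. A small Python trace (illustration only):

# mask = [0, 1, 0, 1]  ->  inclusive prefix sum = [0, 1, 1, 2]
prefix = [0, 1, 1, 2]
for row in range(len(prefix)):
    prev = 0 if row == 0 else prefix[row - 1]
    if prev != prefix[row]:          # the running sum changed: row is kept
        print(f"input row {row} -> packed row {prev}")
# input row 1 -> packed row 0
# input row 3 -> packed row 1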

template<>
inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
const int axis = param.axis;
const NDArray &data = inputs[0];
const NDArray &idx = inputs[1];
const NDArray &out = outputs[0];
CHECK_EQ(axis, 0) << "Not supported yet";
CHECK_EQ(data.shape()[axis], idx.shape()[0]);
CHECK_EQ(idx.shape().ndim(), 1U);
Stream<gpu>* s = ctx.get_stream<gpu>();
// compute the inclusive prefix sum of `idx`; its last element is the
// number of kept rows, i.e. the output dimension
size_t idx_size = idx.shape()[0];
int32_t valid_num = 0;
int32_t* prefix_sum = nullptr;
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
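// First call with d_temp_storage == nullptr only computes
// temp_storage_bytes; no scan is performed yet.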
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
idx_size,
Stream<gpu>::GetStream(s));
size_t buffer_size = idx_size * sizeof(int32_t);
temp_storage_bytes += buffer_size;
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + buffer_size;
MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
});
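// Second call performs the actual in-place inclusive scan over the cast mask.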
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
idx_size,
Stream<gpu>::GetStream(s));
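// The last element of the inclusive scan is the total number of kept rows.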
CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
cudaMemcpyDeviceToHost));
CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
// Set the output shape forcefully
TShape data_shape = data.shape();
data_shape[axis] = valid_num;
const_cast<NDArray &>(out).Init(data_shape);
size_t input_size = data.shape().Size();
size_t col_size = input_size / idx.shape()[0];
// do the copy
MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
});
}

template<>
inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);
// inputs: {ograd, data, idx}
// outputs: {igrad_data, igrad_idx}
const NDArray& ograd = inputs[0];
const NDArray& idx = inputs[2];
const NDArray& igrad_data = outputs[0];
Stream<gpu>* s = ctx.get_stream<gpu>();
// compute the inclusive prefix sum of `idx` to map packed `ograd` rows
// back to their input rows
size_t idx_size = idx.shape()[0];
int32_t* prefix_sum = nullptr;
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
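// Same two-call CUB pattern as the forward pass: query the temp-storage
// size first, then run the in-place inclusive scan.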
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
idx_size,
Stream<gpu>::GetStream(s));
size_t buffer_size = idx_size * sizeof(int32_t);
temp_storage_bytes += buffer_size;
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + buffer_size;
MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch(
s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>());
});
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
idx_size,
Stream<gpu>::GetStream(s));
size_t input_size = igrad_data.shape().Size();
size_t col_size = input_size / idx_size;
// Zero-fill the gradient first so masked-out rows stay zero (the kernel
// only writes kept rows); this mirrors the Fill<false> call in the CPU path.
Fill<false>(s, igrad_data.data(), req[0], 0);
MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(), prefix_sum, col_size);
});
}

NNVM_REGISTER_OP(_contrib_boolean_mask)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskForward<gpu>);

NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FComputeEx>("FComputeEx<gpu>", BooleanMaskBackward<gpu>);

} // namespace op
} // namespace mxnet
2 changes: 0 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -4861,8 +4861,6 @@ def test_index_copy():

@with_seed()
def test_boolean_mask():
if default_context().device_type != 'cpu':
return
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 1, 0])
data.attach_grad()