diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 15aadd2980..4adf40d876 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc. | CornerPool | | √ | | | | Correlation | | √ | | | | Deformable Convolution v1/v2 | √ | √ | | | -| Deformable RoIPool | | √ | | | +| Deformable RoIPool | | √ | √ | | | DiffIoURotated | | √ | | | | DynamicScatter | | √ | | | | FurthestPointSample | | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index fdcc8bab5c..44a60d119f 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | CornerPool | | √ | | | | Correlation | | √ | | | | Deformable Convolution v1/v2 | √ | √ | | | -| Deformable RoIPool | | √ | | | +| Deformable RoIPool | | √ | √ | | | DiffIoURotated | | √ | | | | DynamicScatter | | √ | | | | FurthestPointSample | | √ | | | diff --git a/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu new file mode 100644 index 0000000000..6c765e3eaa --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu @@ -0,0 +1,712 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include + +#include "common_mlu_helper.hpp" + +#define ROI_OFFSET 5 +#define FOURSPLIT 4 +#define FIVESPLIT 5 +#define NINESPLIT 9 +#define THIRTEENSPLIT 13 + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +template +static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x, + T *w1, T *w2, T *w3, T *w4, + int *x_low, int *x_high, + const int y_low, bool *is_empty) { + if (x < -1.0 || x > input_width) { + *is_empty = true; + return; + } + + if (x <= 0) x = 0; + + *x_low = int(x); + + if (*x_low >= input_width - 1) { + *x_high = *x_low = input_width - 1; + x = T(*x_low); + } else { + *x_high = *x_low + 1; + } + + T ly = y - y_low; + T lx = x - *x_low; + T hy = 1.0 - ly; + T hx = 1.0 - lx; + *w1 = hy * hx; + *w2 = hy * lx; + *w3 = ly * hx; + *w4 = ly * lx; +} + +template +__mlu_func__ void MLUUnion1DeformRoIPoolForward( + const T *input, const T *rois, const T *offset, T *output, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int sampling_ratio, const T gamma) { + for (int bin_index = taskId; + bin_index < num_rois * pooled_width * pooled_height; + bin_index += taskDim) { + int out_batch = bin_index / pooled_width / pooled_height; + int out_height = bin_index / pooled_width % pooled_height; + int out_width = bin_index % pooled_width; + const T *cur_roi = rois + out_batch * ROI_OFFSET; + T *nram_rois = (T *)nram_buffer; + __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), + GDRAM2NRAM); + const int roi_batch = nram_rois[0]; + T roi_x_min = nram_rois[1] * spatial_scale - 0.5; + T roi_y_min = nram_rois[2] * spatial_scale - 0.5; + const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; + const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; + const T roi_width = roi_x_max - roi_x_min; + const T roi_height = roi_y_max - roi_y_min; + const T bin_width = roi_width / static_cast(pooled_width); + const T bin_height = roi_height / static_cast(pooled_height); + const T *offset_input = input + roi_batch * height * width * channels; + int roi_bin_grid_height = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_width = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + if (offset != NULL) { + const T *offset_cur = offset + + out_batch * pooled_width * pooled_height * 2 + + out_height * pooled_width + out_width; + roi_x_min += gamma * roi_width * offset_cur[0]; + roi_y_min += + gamma * roi_height * offset_cur[pooled_width * pooled_height]; + } + int type_align = NFU_ALIGN_SIZE / sizeof(T); + int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); + int channels_nram_split = + channels_max_num_nram / NINESPLIT / type_align * type_align; + int channel_rem = channels % channels_nram_split; + int channel_loops = + channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); + for (int channel_loop_index = 0; channel_loop_index < channel_loops; + ++channel_loop_index) { + int channels_num = + channels_nram_split >= channels ? channels : channels_nram_split; + const int channel_offset = channel_loop_index * channels_num; + if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { + channels_num = channel_rem; + } + int channels_align = CEIL_ALIGN(channels_num, type_align); + int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1; + int c_slice = nram_limit / FOURSPLIT / type_align * type_align; + int c_slice_align = 0; + + /* NRAM partition + * + * | | ping | pong | + * |----------|-------------------|-------------------| + * | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 | + * + */ + + T *nram_out = (T *)nram_buffer; + T *nram_ping = nram_out + channels_align; + T *nram_pong = nram_ping + nram_limit; + __bang_write_value((T *)nram_out, channels_align, (T)0); + __bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0); + __bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0); + const T num_bins = + static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); + const T value_div = 1.0f / num_bins; + bool is_ping_empty = true; + for (int iy = 0; iy < roi_bin_grid_height; ++iy) { + T y = roi_y_min + out_height * bin_height + + static_cast(iy + .5f) * bin_height / + static_cast(roi_bin_grid_height); + if (y < -1.0 || y > height) { + is_ping_empty = true; + continue; + } + if (y <= 0) { + y = 0; + } + int y_low = 0, y_high = 0; + y_low = int(y); + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = T(y_low); + } else { + y_high = y_low + 1; + } + for (int ix = 0; ix < roi_bin_grid_width; ++ix) { + T x = roi_x_min + out_width * bin_width + + static_cast(ix + .5f) * bin_width / + static_cast(roi_bin_grid_width); + const int sample_index = iy * roi_bin_grid_width + ix; + int c_rem = channels_num; + c_slice = nram_limit / FOURSPLIT / type_align * type_align; + c_slice_align = 0; + bool is_empty = false; + T w1, w2, w3, w4; + int x_low = 0, x_high = 0; + bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high, + y_low, &is_empty); + if (is_empty) { + is_ping_empty = true; + continue; + } + if (is_ping_empty) { + c_slice = c_slice > c_rem ? c_rem : c_slice; + c_slice_align = CEIL_ALIGN(c_slice, type_align); + __bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0); + __asm__ volatile("sync;"); + __memcpy(nram_ping, + offset_input + y_low * width * channels + + x_low * channels + channel_offset, + c_slice * sizeof(T), GDRAM2NRAM); + __memcpy(nram_ping + c_slice_align, + offset_input + y_low * width * channels + + x_high * channels + channel_offset, + c_slice * sizeof(T), GDRAM2NRAM); + __memcpy(nram_ping + 2 * c_slice_align, + offset_input + y_high * width * channels + + x_low * channels + channel_offset, + c_slice * sizeof(T), GDRAM2NRAM); + __memcpy(nram_ping + 3 * c_slice_align, + offset_input + y_high * width * channels + + x_high * channels + channel_offset, + c_slice * sizeof(T), GDRAM2NRAM); + is_ping_empty = false; + } + int c_offset = 0; + int pongc_slice = 0; + int pongc_slice_align = 0; + while (c_rem > 0) { + c_slice = c_slice > c_rem ? c_rem : c_slice; + c_slice_align = CEIL_ALIGN(c_slice, type_align); + if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { + int iy_tmp = (sample_index + 1) / roi_bin_grid_width; + int ix_tmp = (sample_index + 1) % roi_bin_grid_width; + y = roi_y_min + out_height * bin_height + + static_cast(iy_tmp + .5f) * bin_height / + static_cast(roi_bin_grid_height); + x = roi_x_min + out_width * bin_width + + static_cast(ix_tmp + .5f) * bin_width / + static_cast(roi_bin_grid_width); + if (y < -1.0 || y > height) { + is_empty = true; + } else { + T w1_tmp, w2_tmp, w3_tmp, w4_tmp; + if (y <= 0) { + y = 0; + } + y_low = int(y); + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = T(y_low); + } else { + y_high = y_low + 1; + } + bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp, + &w4_tmp, &x_low, &x_high, y_low, &is_empty); + } + pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; + pongc_slice = + pongc_slice > channels_num ? channels_num : pongc_slice; + pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); + __bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align, + (T)0); + __asm__ volatile("sync;"); + if (!is_empty) { + __memcpy_async(nram_pong, + offset_input + y_low * width * channels + + x_low * channels + channel_offset, + pongc_slice * sizeof(T), GDRAM2NRAM); + __memcpy_async(nram_pong + pongc_slice_align, + offset_input + y_low * width * channels + + x_high * channels + channel_offset, + pongc_slice * sizeof(T), GDRAM2NRAM); + __memcpy_async(nram_pong + 2 * pongc_slice_align, + offset_input + y_high * width * channels + + x_low * channels + channel_offset, + pongc_slice * sizeof(T), GDRAM2NRAM); + __memcpy_async(nram_pong + 3 * pongc_slice_align, + offset_input + y_high * width * channels + + x_high * channels + channel_offset, + pongc_slice * sizeof(T), GDRAM2NRAM); + } + } + __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align); + __bang_mul_scalar(nram_ping + c_slice_align, + nram_ping + c_slice_align, w2, c_slice_align); + __bang_add(nram_ping, nram_ping, nram_ping + c_slice_align, + c_slice_align); + __bang_mul_scalar(nram_ping + 2 * c_slice_align, + nram_ping + 2 * c_slice_align, w3, c_slice_align); + __bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align, + c_slice_align); + __bang_mul_scalar(nram_ping + 3 * c_slice_align, + nram_ping + 3 * c_slice_align, w4, c_slice_align); + __bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align, + c_slice_align); + __bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping, + c_slice_align); + T *nram_tmp = nram_ping; + nram_ping = nram_pong; + nram_pong = nram_tmp; + c_rem -= c_slice; + c_offset += c_slice; + __asm__ volatile("sync;"); + } + } + } + __bang_mul_scalar(nram_out, nram_out, value_div, channels_align); + __memcpy(output + channels * bin_index + channel_offset, nram_out, + channels_num * sizeof(T), NRAM2GDRAM); + } + } +} + +__mlu_global__ void MLUKernelDeformRoIPoolForward( + cnrtDataType_t data_type, const void *input, const void *rois, + const void *offset, void *output, const int channels, const int height, + const int width, const int num_rois, const int pooled_height, + const int pooled_width, const float spatial_scale, const int sampling_ratio, + const float gamma) { + switch (data_type) { + case CNRT_FLOAT16: { + MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset, + (half *)output, channels, height, width, + num_rois, pooled_height, pooled_width, + static_cast(spatial_scale), + sampling_ratio, static_cast(gamma)); + }; break; + case CNRT_FLOAT32: { + MLUUnion1DeformRoIPoolForward( + (float *)input, (float *)rois, (float *)offset, (float *)output, + channels, height, width, num_rois, pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma)); + }; break; + default: { + break; + } + } +} + +void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, cnrtDataType_t data_type, + const void *input, const void *rois, + const void *offset, void *output, + const int channels, const int height, + const int width, const int num_rois, + const int pooled_height, const int pooled_width, + const float spatial_scale, + const int sampling_ratio, const float gamma) { + MLUKernelDeformRoIPoolForward<<>>( + data_type, input, rois, offset, output, channels, height, width, num_rois, + pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma); +} + +template +__mlu_func__ void MLUUnion1DeformRoIPoolBackward( + const T *grad_output, const T *input, const T *rois, const T *offset, + T *grad_input, T *grad_offset, const int channels, const int height, + const int width, const int num_rois, const int pooled_height, + const int pooled_width, const T spatial_scale, const int sampling_ratio, + const T gamma) { + for (int bin_index = taskId; + bin_index < num_rois * pooled_width * pooled_height; + bin_index += taskDim) { + int out_batch = bin_index / pooled_width / pooled_height; + int out_height = bin_index / pooled_width % pooled_height; + int out_width = bin_index % pooled_width; + const T *cur_roi = rois + out_batch * ROI_OFFSET; + T *nram_rois = (T *)nram_buffer; + __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), + GDRAM2NRAM); + const int roi_batch = nram_rois[0]; + T roi_x_min = nram_rois[1] * spatial_scale - 0.5; + T roi_y_min = nram_rois[2] * spatial_scale - 0.5; + const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; + const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; + const T roi_width = roi_x_max - roi_x_min; + const T roi_height = roi_y_max - roi_y_min; + const T bin_width = roi_width / static_cast(pooled_width); + const T bin_height = roi_height / static_cast(pooled_height); + const T *offset_input = input + roi_batch * height * width * channels; + T *offset_grad_input = grad_input + roi_batch * height * width * channels; + int roi_bin_grid_height = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_width = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + if (offset != NULL) { + const T *offset_cur = offset + + out_batch * pooled_width * pooled_height * 2 + + out_height * pooled_width + out_width; + roi_x_min += gamma * roi_width * offset_cur[0]; + roi_y_min += + gamma * roi_height * offset_cur[pooled_width * pooled_height]; + } + + /* NRAM partition + * + * If offset != NULL, NRAM partition belows. + * | | + * ping | pong | + * |---------------------------------------------------------------------|-----------|-----------| + * |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4| + * + * If offset == NULL, ping and pang will not be needed. + * | | + * |----------------------------------------------------------------------------------| + * | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output | + * + */ + + int type_align = NFU_ALIGN_SIZE / sizeof(T); + int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); + int channels_nram_split = + channels_max_num_nram / FIVESPLIT / type_align * type_align; + int channel_rem = channels % channels_nram_split; + int channel_loops = + channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); + if (offset != NULL) { + channels_nram_split = + channels_max_num_nram / THIRTEENSPLIT / type_align * type_align; + channel_rem = channels % channels_nram_split; + channel_loops = + channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); + } + + for (int channel_loop_index = 0; channel_loop_index < channel_loops; + ++channel_loop_index) { + int channels_num = + channels_nram_split >= channels ? channels : channels_nram_split; + const int channel_offset = channel_loop_index * channels_num; + if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { + channels_num = channel_rem; + } + int channels_align = CEIL_ALIGN(channels_num, type_align); + const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T); + int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align - + nram_sum_tmp_channel) >> + 1; + int c_slice = 0; + int c_slice_align = 0; + T *nram_tmp1 = (T *)nram_buffer; + T *nram_tmp2 = (T *)nram_buffer + channels_align; + T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align; + T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align; + T *nram_grad_output = nram_tmp4 + channels_align; + T *nram_sum_tmp = NULL; + T *nram_ping_input = NULL; + T *nram_pong_input = NULL; + __bang_write_value((T *)nram_grad_output, channels_align, (T)0); + __asm__ volatile("sync;"); + + if (offset != NULL) { + c_slice = nram_limit / FOURSPLIT / type_align * type_align; + nram_sum_tmp = nram_grad_output + channels_align; + nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel; + nram_pong_input = nram_ping_input + FOURSPLIT * c_slice; + __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); + __bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0); + __bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0); + __asm__ volatile("sync;"); + } + const T num_bins = + static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); + const T value_div = 1.0f / num_bins; + bool is_ping_empty = true; + __memcpy(nram_grad_output, + grad_output + channels * bin_index + channel_offset, + channels_num * sizeof(T), GDRAM2NRAM); + __bang_mul_scalar(nram_grad_output, nram_grad_output, value_div, + channels_align); + for (int iy = 0; iy < roi_bin_grid_height; ++iy) { + T y = roi_y_min + out_height * bin_height + + static_cast(iy + .5f) * bin_height / + static_cast(roi_bin_grid_height); + T y_tmp = y; + if (y_tmp < -1.0 || y_tmp > height) { + is_ping_empty = true; + continue; + } + if (y_tmp <= 0) { + y_tmp = 0; + } + int y_low = 0, y_high = 0; + y_low = int(y_tmp); + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y_tmp = T(y_low); + } else { + y_high = y_low + 1; + } + for (int ix = 0; ix < roi_bin_grid_width; ++ix) { + T x = roi_x_min + out_width * bin_width + + static_cast(ix + .5f) * bin_width / + static_cast(roi_bin_grid_width); + const int sample_index = iy * roi_bin_grid_width + ix; + int c_rem = channels_num; + bool is_empty = false; + T w1, w2, w3, w4; + int x_low = 0, x_high = 0; + bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low, + &x_high, y_low, &is_empty); + if (is_empty) { + is_ping_empty = true; + continue; + } + __bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1, + channels_align); + __bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2, + channels_align); + __bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3, + channels_align); + __bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4, + channels_align); + __asm__ volatile("sync;"); + __bang_atomic_add( + (T *)nram_tmp1, + (T *)(offset_grad_input + (y_low * width + x_low) * channels + + channel_offset), + (T *)nram_tmp1, channels_num); + __bang_atomic_add( + (T *)nram_tmp2, + (T *)(offset_grad_input + (y_low * width + x_high) * channels + + channel_offset), + (T *)nram_tmp2, channels_num); + __bang_atomic_add( + (T *)nram_tmp3, + (T *)(offset_grad_input + (y_high * width + x_low) * channels + + channel_offset), + (T *)nram_tmp3, channels_num); + __bang_atomic_add( + (T *)nram_tmp4, + (T *)(offset_grad_input + (y_high * width + x_high) * channels + + channel_offset), + (T *)nram_tmp4, channels_num); + if (offset != NULL) { + c_slice = nram_limit / FOURSPLIT / type_align * type_align; + c_slice_align = 0; + if (is_ping_empty) { + c_slice = c_slice > c_rem ? c_rem : c_slice; + c_slice_align = CEIL_ALIGN(c_slice, type_align); + __bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align, + (T)0); + __asm__ volatile("sync;"); + const T *src_offset1 = offset_input + y_low * width * channels + + x_low * channels + channel_offset; + const T *src_offset2 = offset_input + y_low * width * channels + + x_high * channels + channel_offset; + const T *src_offset3 = offset_input + y_high * width * channels + + x_low * channels + channel_offset; + const T *src_offset4 = offset_input + y_high * width * channels + + x_high * channels + channel_offset; + __memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T), + GDRAM2NRAM); + __memcpy(nram_ping_input + c_slice_align, src_offset2, + c_slice * sizeof(T), GDRAM2NRAM); + __memcpy(nram_ping_input + 2 * c_slice_align, src_offset3, + c_slice * sizeof(T), GDRAM2NRAM); + __memcpy(nram_ping_input + 3 * c_slice_align, src_offset4, + c_slice * sizeof(T), GDRAM2NRAM); + is_ping_empty = false; + } + int c_offset = 0; + int pongc_slice = 0; + int pongc_slice_align = 0; + while (c_rem > 0) { + c_slice = c_slice > c_rem ? c_rem : c_slice; + c_slice_align = CEIL_ALIGN(c_slice, type_align); + if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { + int iy_tmp = (sample_index + 1) / roi_bin_grid_width; + int ix_tmp = (sample_index + 1) % roi_bin_grid_width; + T y_tmp = roi_y_min + out_height * bin_height + + static_cast(iy_tmp + .5f) * bin_height / + static_cast(roi_bin_grid_height); + T x_tmp = roi_x_min + out_width * bin_width + + static_cast(ix_tmp + .5f) * bin_width / + static_cast(roi_bin_grid_width); + int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0, + y_high_tmp = 0; + if (y_tmp < -1.0 || y_tmp > height) { + is_empty = true; + } else { + T w1_tmp, w2_tmp, w3_tmp, w4_tmp; + if (y_tmp <= 0) { + y_tmp = 0; + } + y_low_tmp = int(y_tmp); + if (y_low_tmp >= height - 1) { + y_high_tmp = y_low_tmp = height - 1; + y_tmp = T(y_low_tmp); + } else { + y_high_tmp = y_low_tmp + 1; + } + bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp, + &w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp, + y_low_tmp, &is_empty); + } + pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; + pongc_slice = + pongc_slice > channels_num ? channels_num : pongc_slice; + pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); + __bang_write_value(nram_pong_input, + FOURSPLIT * pongc_slice_align, (T)0); + __asm__ volatile("sync;"); + if (!is_empty) { + const T *src_offset1 = offset_input + + y_low_tmp * width * channels + + x_low_tmp * channels + channel_offset; + const T *src_offset2 = offset_input + + y_low_tmp * width * channels + + x_high_tmp * channels + channel_offset; + const T *src_offset3 = offset_input + + y_high_tmp * width * channels + + x_low_tmp * channels + channel_offset; + const T *src_offset4 = offset_input + + y_high_tmp * width * channels + + x_high_tmp * channels + channel_offset; + __memcpy_async(nram_pong_input, src_offset1, + pongc_slice * sizeof(T), GDRAM2NRAM); + __memcpy_async(nram_pong_input + pongc_slice_align, + src_offset2, pongc_slice * sizeof(T), + GDRAM2NRAM); + __memcpy_async(nram_pong_input + 2 * pongc_slice_align, + src_offset3, pongc_slice * sizeof(T), + GDRAM2NRAM); + __memcpy_async(nram_pong_input + 3 * pongc_slice_align, + src_offset4, pongc_slice * sizeof(T), + GDRAM2NRAM); + } + } + + __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, + y - y_low, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, + y_high - y, c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, + y_low - y, c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high, + c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width, + c_slice_align); + __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); + const int32_t kernel_width = + c_slice_align / nram_sum_tmp_channel + + (int32_t)(c_slice_align % nram_sum_tmp_channel > 0); + __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, + kernel_width, 1, kernel_width, kernel_width, 1); + __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, + nram_sum_tmp_channel); + __bang_atomic_add( + (T *)nram_sum_tmp, + (T *)(grad_offset + + out_batch * pooled_width * pooled_height * 2 + + out_height * pooled_width + out_width), + (T *)nram_sum_tmp, 1); + __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); + __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, + x - x_low, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, + x_high - x, c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, + x_low - x, c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high, + c_slice_align); + __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); + __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height, + c_slice_align); + __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); + __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, + kernel_width, 1, kernel_width, kernel_width, 1); + __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, + NFU_ALIGN_SIZE / sizeof(T)); + __bang_atomic_add( + (T *)nram_sum_tmp, + (T *)(grad_offset + + out_batch * pooled_width * pooled_height * 2 + + pooled_width * pooled_height + + out_height * pooled_width + out_width), + (T *)nram_sum_tmp, 1); + + T *nram_tmp = nram_ping_input; + nram_ping_input = nram_pong_input; + nram_pong_input = nram_tmp; + c_rem -= c_slice; + c_offset += c_slice; + __asm__ volatile("sync;"); + } + } + } + } + } + } +} + +__mlu_global__ void MLUKernelDeformRoIPoolBackward( + cnrtDataType_t data_type, const void *grad_output, const void *input, + const void *rois, const void *offset, void *grad_input, void *grad_offset, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const float spatial_scale, + const int sampling_ratio, const float gamma) { + switch (data_type) { + case CNRT_FLOAT16: { + MLUUnion1DeformRoIPoolBackward( + (half *)grad_output, (half *)input, (half *)rois, (half *)offset, + (half *)grad_input, (half *)grad_offset, channels, height, width, + num_rois, pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma)); + }; break; + case CNRT_FLOAT32: { + MLUUnion1DeformRoIPoolBackward( + (float *)grad_output, (float *)input, (float *)rois, (float *)offset, + (float *)grad_input, (float *)grad_offset, channels, height, width, + num_rois, pooled_height, pooled_width, + static_cast(spatial_scale), sampling_ratio, + static_cast(gamma)); + }; break; + default: { + break; + } + } +} + +void KernelDeformRoIPoolBackward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + cnrtDataType_t data_type, const void *grad_output, const void *input, + const void *rois, const void *offset, void *grad_input, void *grad_offset, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const float spatial_scale, + const int sampling_ratio, const float gamma) { + MLUKernelDeformRoIPoolBackward<<>>( + data_type, grad_output, input, rois, offset, grad_input, grad_offset, + channels, height, width, num_rois, pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp new file mode 100644 index 0000000000..4d73cbbe59 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp @@ -0,0 +1,343 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" + +void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, cnrtDataType_t data_type, + const void *input, const void *rois, + const void *offset, void *output, + const int channels, const int height, + const int width, const int num_rois, + const int pooled_height, const int pooled_width, + const float spatial_scale, + const int sampling_ratio, const float gamma); + +void KernelDeformRoIPoolBackward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + cnrtDataType_t data_type, const void *grad_output, const void *input, + const void *rois, const void *offset, void *grad_input, void *grad_offset, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const float spatial_scale, + const int sampling_ratio, const float gamma); + +// policy function for forward and backward +static void policyFunc(const int bin_num, cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type) { + const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + ; + const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit); + k_dim->x = core_limit; + k_dim->y = (bin_num_align / core_limit) > cluster_limit + ? cluster_limit + : (bin_num_align / core_limit); + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_UNION1; +} + +void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois, + Tensor offset, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale, + int sampling_ratio, float gamma) { + // Check dtype. + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "input type should be Float or Half, got ", input.scalar_type()); + TORCH_CHECK(input.scalar_type() == rois.scalar_type(), + "rois should have the same type as input"); + + // Check shape. + TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), + "D."); + TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), + "D."); + if (offset.defined() && offset.numel() > 0) { + TORCH_CHECK(input.scalar_type() == offset.scalar_type(), + "offset should have the same type as input"); + TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", + offset.dim(), "D."); + TORCH_CHECK( + (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), + "while rois.size(0)) = ", rois.size(0), ". They should be the same."); + TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", + "but now offset.size(1) = ", offset.size(1), "."); + TORCH_CHECK((offset.size(2) == output.size(2)), + "offset.size(2) = ", offset.size(2), + "while output.size(2)) = ", output.size(2), + ". They should be the same."); + TORCH_CHECK((offset.size(3) == output.size(3)), + "offset.size(3) = ", offset.size(3), + "while output.size(3)) = ", output.size(3), + ". They should be the same."); + } + + TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, + "spatial_scale should be within (0, 1], got ", spatial_scale, + "."); + + // compute kernel params + auto height = input.size(2); + auto width = input.size(3); + auto channels = input.size(1); + auto num_rois = output.size(0); + + if (output.numel() == 0) { + output = at::zeros({num_rois, channels, pooled_height, pooled_width}, + input.options()); + return; + } + + // zero element check + TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", + input.size(0)); + TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", + rois.numel()); + if (input.numel() == 0 || output.numel() == 0) { + return; + } + + // large tensor check + const size_t max_input_num = 2147483648; // 2^31, 2G num + TORCH_CHECK(input.numel() < max_input_num, + "input.numel() should be less than 2147483648, got ", + input.numel()); + TORCH_CHECK(rois.numel() < max_input_num, + "rois.numel() should be less than 2147483648, got ", + rois.numel()); + TORCH_CHECK(output.numel() < max_input_num, + "output.numel() should be less than 2147483648, got ", + output.numel()); + TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, + "offset.numel() should be less than 2147483648, got ", + offset.numel()); + + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); + auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); + + at::Tensor output_ = + at::empty({num_rois, channels, pooled_height, pooled_width}, + input.options(), memory_format); + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto input_impl = torch_mlu::getMluTensorImpl(input_); + auto input_ptr = input_impl->cnnlMalloc(); + auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_ptr = rois_impl->cnnlMalloc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset); + auto offset_ptr = offset_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output_); + auto output_ptr = output_impl->cnnlMalloc(); + + // get comput dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype()); + + // launch kernel + CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x + << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + + KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr, + rois_ptr, offset_ptr, output_ptr, channels, height, + width, num_rois, pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); + + output.copy_(output_); +} + +void DeformRoIPoolBackwardMLUKernelLauncher( + Tensor grad_output, Tensor input, Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, float gamma) { + // Check dtype. + TORCH_CHECK( + input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, + "input type should be Float or Half, got ", input.scalar_type()); + TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(), + "grad_output should have the same type as input"); + TORCH_CHECK(input.scalar_type() == rois.scalar_type(), + "rois should have the same type as input"); + TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(), + "grad_input should have the same type as input"); + + // Check shape. + TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ", + grad_output.dim(), "D."); + TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), + "D."); + TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), + "D."); + if (offset.defined() && offset.numel() > 0) { + TORCH_CHECK(input.scalar_type() == offset.scalar_type(), + "offset should have the same type as input"); + TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", + offset.dim(), "D."); + TORCH_CHECK( + (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), + "while rois.size(0)) = ", rois.size(0), ". They should be the same."); + TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", + "but now offset.size(1) = ", offset.size(1), "."); + TORCH_CHECK((offset.size(2) == grad_output.size(2)), + "offset.size(2) = ", offset.size(2), + "while grad_output.size(2)) = ", grad_output.size(2), + ". They should be the same."); + TORCH_CHECK((offset.size(3) == grad_output.size(3)), + "offset.size(3) = ", offset.size(3), + "while grad_output.size(3)) = ", grad_output.size(3), + ". They should be the same."); + } + + TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, + "spatial_scale should be within (0, 1], got ", spatial_scale); + + // Check relationship between tensor. + TORCH_CHECK((grad_output.size(0) == rois.size(0)), + "grad_output.size(0) = ", grad_output.size(0), + "while rois.size(0)) = ", rois.size(0), + ". They should be the same."); + TORCH_CHECK((grad_output.size(1) == input.size(1)), + "grad_output.size(1) = ", grad_output.size(1), + "while input.size(1)) = ", input.size(1), + ". They should be the same."); + TORCH_CHECK((grad_output.size(2) == pooled_height), + "grad_output.size(2) = ", grad_output.size(2), + "while pooled_height = ", pooled_height, + ". They should be the same."); + TORCH_CHECK((grad_output.size(3) == pooled_width), + "grad_output.size(3) = ", grad_output.size(3), + "while pooled_width = ", pooled_width, + ". They should be the same."); + + // compute kernel params + auto batch = input.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + auto num_rois = grad_output.size(0); + + // zero element check + TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", + input.size(0)); + TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", + rois.numel()); + if (input.numel() == 0 || grad_output.numel() == 0) { + return; + } + + // large tensor check + const size_t max_input_num = 2147483648; // 2^31, 2G num + TORCH_CHECK(input.numel() < max_input_num, + "input.numel() should be less than 2147483648, got ", + input.numel()); + TORCH_CHECK(rois.numel() < max_input_num, + "rois.numel() should be less than 2147483648, got ", + rois.numel()); + TORCH_CHECK(grad_output.numel() < max_input_num, + "grad_output.numel() should be less than 2147483648, got ", + grad_output.numel()); + TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, + "offset.numel() should be less than 2147483648, got ", + offset.numel()); + + auto memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim()); + auto grad_output_ = + torch_mlu::cnnl::ops::cnnl_contiguous(grad_output, memory_format); + memory_format = + torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); + auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); + at::Tensor grad_input_ = at::empty({batch, channels, height, width}, + input.options(), memory_format) + .zero_(); + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_); + auto grad_output_ptr = grad_output_impl->cnnlMalloc(); + auto input_impl = torch_mlu::getMluTensorImpl(input_); + auto input_ptr = input_impl->cnnlMalloc(); + auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_ptr = rois_impl->cnnlMalloc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset); + auto offset_ptr = offset_impl->cnnlMalloc(); + auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_); + auto grad_input_ptr = grad_input_impl->cnnlMalloc(); + auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset); + auto grad_offset_ptr = grad_offset_impl->cnnlMalloc(); + + // get comput dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype()); + + // launch kernel + CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x + << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + + KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr, + input_ptr, rois_ptr, offset_ptr, grad_input_ptr, + grad_offset_ptr, channels, height, width, + num_rois, pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); + + grad_input.copy_(grad_input_); +} + +void deform_roi_pool_forward_mlu(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + DeformRoIPoolForwardMLUKernelLauncher(input, rois, offset, output, + pooled_height, pooled_width, + spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_backward_mlu(Tensor grad_output, Tensor input, Tensor rois, + Tensor offset, Tensor grad_input, + Tensor grad_offset, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + DeformRoIPoolBackwardMLUKernelLauncher( + grad_output, input, rois, offset, grad_input, grad_offset, pooled_height, + pooled_width, spatial_scale, sampling_ratio, gamma); +} + +void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma); + +void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma); + +REGISTER_DEVICE_IMPL(deform_roi_pool_forward_impl, MLU, + deform_roi_pool_forward_mlu); +REGISTER_DEVICE_IMPL(deform_roi_pool_backward_impl, MLU, + deform_roi_pool_backward_mlu); diff --git a/tests/test_ops/test_deform_roi_pool.py b/tests/test_ops/test_deform_roi_pool.py index 37a279ec9b..5c48e6f777 100644 --- a/tests/test_ops/test_deform_roi_pool.py +++ b/tests/test_ops/test_deform_roi_pool.py @@ -2,8 +2,11 @@ import os import numpy as np +import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -93,3 +96,53 @@ def test_modulated_deform_roi_pool_gradcheck(self): gradcheck(droipool, (x, rois), no_grads=[rois]) else: gradcheck(droipool, (x, rois), eps=1e-2, atol=1e-2) + + def _test_deform_roi_pool_allclose(self, device, dtype=torch.float): + from mmcv.ops import DeformRoIPoolPack + pool_h = 2 + pool_w = 2 + spatial_scale = 1.0 + sampling_ratio = 2 + + for case, output in zip(inputs, outputs): + np_input = np.array(case[0]) + np_rois = np.array(case[1]) + np_output = np.array(output[0]) + np_grad = np.array(output[1]) + + x = torch.tensor( + np_input, device=device, dtype=torch.float, requires_grad=True) + rois = torch.tensor(np_rois, device=device, dtype=torch.float) + output_c = x.size(1) + droipool = DeformRoIPoolPack( + (pool_h, pool_w), + output_c, + spatial_scale=spatial_scale, + sampling_ratio=sampling_ratio).to(device) + + output = droipool(x, rois) + output.backward(torch.ones_like(output)) + assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) + assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) + ]) + @pytest.mark.parametrize('dtype', [ + torch.float, + pytest.param( + torch.double, + marks=pytest.mark.skipif( + IS_MLU_AVAILABLE, + reason='MLU does not support for 64-bit floating point')), + torch.half + ]) + def test_deform_roi_pool_allclose(self, device, dtype): + self._test_deform_roi_pool_allclose(device, dtype)