ms_opt
Wickyzheng committed Dec 23, 2022
1 parent d205e5d commit b70f205
Showing 2 changed files with 36 additions and 42 deletions.
14 changes: 5 additions & 9 deletions mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu
@@ -1371,8 +1371,7 @@ void __mlu_func__ computeData(
     const T &w3, const T &w4) {
 #if __BANG_ARCH__ > 322
   __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real);
-  if (h_low >= 0 && w_low >= 0)
-  {
+  if (h_low >= 0 && w_low >= 0) {
     __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tl, (float)(-hw),
                   grad_h_weight, deal_num_real, deal_num_real);
     __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tl, (float)(-hh),
@@ -1382,8 +1381,7 @@ void __mlu_func__ computeData(
     __bang_mul_scalar(grad_output_nram_tl, grad_output_nram_tl, w1,
                       deal_num_real);
   }
-  if (h_low >= 0 && w_high <= width - 1)
-  {
+  if (h_low >= 0 && w_high <= width - 1) {
     __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tr, (float)(-lw),
                   grad_h_weight, deal_num_real, deal_num_real);
     __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tr, (float)(hh),
@@ -1395,8 +1393,7 @@ void __mlu_func__ computeData(
     __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_tr,
                deal_num_real);
   }
-  if (h_high <= height - 1 && w_low >= 0)
-  {
+  if (h_high <= height - 1 && w_low >= 0) {
     __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_bl, (float)(hw),
                   grad_h_weight, deal_num_real, deal_num_real);
     __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_bl, (float)(-lh),
@@ -1408,15 +1405,14 @@ void __mlu_func__ computeData(
     __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_bl,
                deal_num_real);
   }
-  if (h_high <= height - 1 && w_high <= width - 1)
-  {
+  if (h_high <= height - 1 && w_high <= width - 1) {
     __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_br, (float)(lw),
                   grad_h_weight, deal_num_real, deal_num_real);
     __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_br, (float)(lh),
                   grad_w_weight, deal_num_real, deal_num_real);
     __bang_mul_scalar(grad_output_nram_br_temp, top_grad_temp, w4,
                       deal_num_real);
-  // for calc grad_attn_weight
+    // for calc grad_attn_weight
     __bang_mul_scalar(grad_output_nram_br, grad_output_nram_br, w4,
                       deal_num_real);
     __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_br,
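The four guarded blocks in these hunks accumulate, corner by corner (tl, tr, bl, br), the terms that later become the gradients with respect to the sampling location and the attention weight. The scalar sketch below shows the same bilinear bookkeeping for a single sample; it is an illustration only, with hypothetical names (v_tl, lh, lw, the struct) that follow the usual deformable-attention convention rather than the kernel's tiled NRAM buffers.

// Scalar sketch of the per-corner terms that the vectorized __bang_* calls
// above compute for whole tiles at once (illustrative names, assumed layout).
struct BilinearGrads {
  float grad_h_weight;  // d(sampled value)/dh, scaled by top_grad later
  float grad_w_weight;  // d(sampled value)/dw
  float sampled;        // forward bilinear value, reused for grad_attn_weight
};

// v_* are the four neighbouring values; lh/lw are the fractional offsets of
// the sampling point inside the cell, hh = 1 - lh, hw = 1 - lw.
static BilinearGrads bilinear_backward_terms(float v_tl, float v_tr, float v_bl,
                                             float v_br, float lh, float lw) {
  const float hh = 1.f - lh, hw = 1.f - lw;
  const float w1 = hh * hw, w2 = hh * lw;  // corner weights: tl, tr
  const float w3 = lh * hw, w4 = lh * lw;  // bl, br
  BilinearGrads g;
  // Same signs as the FUSION_FMA calls above: tl uses (-hw, -hh), tr (-lw, hh),
  // bl (hw, -lh), br (lw, lh).
  g.grad_h_weight = -hw * v_tl - lw * v_tr + hw * v_bl + lw * v_br;
  g.grad_w_weight = -hh * v_tl + hh * v_tr - lh * v_bl + lh * v_br;
  g.sampled = w1 * v_tl + w2 * v_tr + w3 * v_bl + w4 * v_br;
  return g;
}

Corners that fall outside the feature map contribute nothing, which is what the h_low / h_high / w_low / w_high guards above implement.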
64 changes: 31 additions & 33 deletions mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
@@ -41,24 +41,22 @@ void KernelMsDeformAttnForwardSmallChannel(
     const int32_t channels, const int32_t num_levels, const int32_t num_queries,
     const int32_t num_points, char* data_col_gdram);
 
-typedef enum
-{
+typedef enum {
   MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
   MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
 } MsDeformAttnBackwardKernelPolicy;
 
-MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(const int32_t channels,
-                                                                const int32_t num_levels,
-                                                                const int32_t num_points){
+MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
+    const int32_t channels, const int32_t num_levels,
+    const int32_t num_points) {
   const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const uint64_t max_num= nram_size /sizeof(float);
-  const uint64_t deal_num = 12 * PAD_UP(channels,32)+3 * PAD_UP(
-      num_levels,32)+3 * num_points;
+  const uint64_t max_num = nram_size / sizeof(float);
+  const uint64_t deal_num =
+      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
 
-  if (max_num >= deal_num){
-    return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
-
-  }
+  if (max_num >= deal_num) {
+    return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
+  }
 
   return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
 }
@@ -469,28 +467,28 @@ void ms_deform_attn_mlu_backward(
   // launch kernel
   CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
               << ", " << k_dim.y << ", " << k_dim.z << ">>>";
-  MsDeformAttnBackwardKernelPolicy kernelPolicy=
-      msDeformAttnBackwardPolicyFunc(channels,num_levels,num_points);
+  MsDeformAttnBackwardKernelPolicy kernelPolicy =
+      msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points);
   switch (kernelPolicy) {
-  case MS_DEFORM_ATTN_BACKWARD_DEFAULT : {
-    KernelMsDeformAttnBackwardDefaultKernel(
-        k_dim, k_type, queue, data_type, (float*)value_ptr,
-        (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
-        (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
-        (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
-        num_levels, num_queries, num_points, (float*)grad_value_ptr,
-        (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
-  }break;
-  case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL :{
-    KernelMsDeformAttnBackwardSmallChannelsKernel(
-        k_dim, k_type, queue, data_type, (float*)value_ptr,
-        (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
-        (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
-        (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
-        num_levels, num_queries, num_points, (float*)grad_value_ptr,
-        (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
-  }break;
-  default:{VLOG(5)<<"NotImplemented.";}
+    case MS_DEFORM_ATTN_BACKWARD_DEFAULT: {
+      KernelMsDeformAttnBackwardDefaultKernel(
+          k_dim, k_type, queue, data_type, (float*)value_ptr,
+          (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
+          (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
+          (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
+          num_levels, num_queries, num_points, (float*)grad_value_ptr,
+          (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
+    } break;
+    case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: {
+      KernelMsDeformAttnBackwardSmallChannelsKernel(
+          k_dim, k_type, queue, data_type, (float*)value_ptr,
+          (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
+          (float*)sampling_loc_ptr, (float*)attn_weight_ptr,
+          (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
+          num_levels, num_queries, num_points, (float*)grad_value_ptr,
+          (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
+    } break;
+    default: { VLOG(5) << "NotImplemented."; }
   }
 }
 
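The dispatch introduced above is a capacity check: the small-channel kernel is chosen only when one core's NRAM can hold the per-iteration working set. Below is a minimal standalone sketch of the same rule, with a hypothetical nram_size_bytes argument standing in for torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) and a pad_up helper standing in for the PAD_UP macro.

#include <cstdint>

// Round x up to the nearest multiple of align (mirrors the PAD_UP macro).
static inline uint64_t pad_up(uint64_t x, uint64_t align) {
  return (x + align - 1) / align * align;
}

enum MsDeformAttnBackwardKernelPolicy {
  MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
};

// Pick the small-channel kernel only when NRAM (counted in floats) can hold
// 12 padded channel buffers, 3 padded level buffers and 3 * num_points scalars.
MsDeformAttnBackwardKernelPolicy pickBackwardPolicy(uint64_t nram_size_bytes,
                                                    int32_t channels,
                                                    int32_t num_levels,
                                                    int32_t num_points) {
  const uint64_t max_num = nram_size_bytes / sizeof(float);
  const uint64_t deal_num =
      12 * pad_up(channels, 32) + 3 * pad_up(num_levels, 32) + 3 * num_points;
  return max_num >= deal_num ? MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL
                             : MS_DEFORM_ATTN_BACKWARD_DEFAULT;
}

For example, with 512 KiB of NRAM (131072 floats) and channels = 32, num_levels = 4, num_points = 4, deal_num is 12 * 32 + 3 * 32 + 3 * 4 = 492, so the small-channel kernel would be selected; the figures are illustrative, not a statement about any particular MLU device.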
