[Softmax] Add online softmax according to the NVIDIA paper (#60) #61

Merged · 1 commit · Oct 2, 2024
107 changes: 107 additions & 0 deletions softmax/softmax.cu
@@ -16,8 +16,35 @@
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// Data structure required for online softmax: the running maximum and the
// running denominator accumulated under that maximum.
struct __align__(8) MD
{
float m; // running max over the elements seen so far
float d; // running sum of exp(x_i - m)
};

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce for Online Softmax
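// Each lane carries a partial (m, d) over the elements it has seen. Merging
// two partials keeps the larger maximum and rescales the other denominator:
//   m = max(m_a, m_b);  d = d_big + d_small * exp(m_small - m_big)
// The merge is associative, which is what makes the XOR-butterfly shuffle valid.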

template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ MD warp_reduce_md_op(MD value) {
unsigned int mask = 0xffffffff;
#pragma unroll
for(int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
MD other;
other.m = __shfl_xor_sync(mask, value.m, stride);
other.d = __shfl_xor_sync(mask, value.d, stride);

bool value_bigger = (value.m > other.m);
MD bigger_m = value_bigger ? value : other;
MD smaller_m = value_bigger ? other : value;

value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
value.m = bigger_m.m;
}
return value;
}

// Warp Reduce Sum
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
@@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, i
// TODO: support non 8-multiple K here
}

template<const int NUM_THREADS = 256>
__global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {

int local_tid = threadIdx.x;
int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
const int WARP_NUM = NUM_THREADS / WARP_SIZE; // warps per block
int warp_id = local_tid / WARP_SIZE;
int lane_id = local_tid % WARP_SIZE;
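// Each in-range element enters the reduction as the pair (x_i, 1), since
// exp(x_i - x_i) = 1; out-of-range threads contribute the identity (-FLT_MAX, 0).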
MD val;
val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
val.d = global_tid < N ? 1.0f : 0.0f;

__shared__ MD shared[WARP_NUM]; // one (m, d) partial per warp
MD res = warp_reduce_md_op<WARP_SIZE>(val); // stage 1: reduce within each warp

if (lane_id == 0) shared[warp_id] = res; // warp leaders publish their partials
__syncthreads();

if (local_tid < WARP_SIZE) {
// Stage 2: the first warp combines the per-warp partials. Lanes beyond
// WARP_NUM feed in the identity (-FLT_MAX, 0) instead of reading
// uninitialized shared memory; their results are discarded anyway.
MD block_res = (local_tid < WARP_NUM) ? shared[local_tid] : MD{-FLT_MAX, 0.0f};
block_res = warp_reduce_md_op<WARP_NUM>(block_res);
if (local_tid == 0) {
shared[0] = block_res;
}
}
__syncthreads();

MD final_res = shared[0];
float d_total_inverse = __fdividef(1.0f, final_res.d);
if (global_tid < N) {
y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
}
}
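
// For intuition, a minimal sequential sketch of the same update (illustrative
// only; this helper is added for exposition and is not called by any kernel
// in this PR): folding one element at a time into an (m, d) pair uses the
// same rescaling rule that the warp reduction applies pairwise.
static __host__ __device__ __forceinline__ MD online_softmax_fold(MD acc, float x) {
float m_new = fmaxf(acc.m, x);
float d_new = acc.d * expf(acc.m - m_new) + expf(x - m_new);
return {m_new, d_new};
}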

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
break; \
}

// online softmax per token
#define LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
reinterpret_cast<float*>(x.data_ptr()), \
reinterpret_cast<float*>(y.data_ptr()), \
N);

#define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
dim3 block((H)); \
dim3 grid((S)); \
switch ((H)) \
{ \
case 32: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
break; \
case 64: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
break; \
case 128: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128) \
break; \
case 256: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256) \
break; \
case 512: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512) \
break; \
case 1024: \
LAUNCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024) \
break; \
default: \
throw std::runtime_error( \
"only support H: 32/64/128/256/512/1024"); \
break; \
}
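
// Note: the dispatcher above launches one block per token/row (grid(S) blocks
// of H threads), so each thread loads exactly one element of its row.
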
#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
grid, block>>>( \
@@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
}

void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
}
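
// Note: x and y are assumed to be contiguous fp32 tensors of shape (S, H);
// H must match one of the sizes handled by the dispatch switch (32..1024).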

// grid memory fence fp32
TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
@@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
}
3 changes: 3 additions & 0 deletions softmax/softmax.py
@@ -77,6 +77,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
@@ -99,6 +100,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
@@ -121,6 +123,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)