From c3e423c8be70c2aec620e6789ac1030cf7dee358 Mon Sep 17 00:00:00 2001 From: "JT.Han" <59948448+JThh@users.noreply.github.com> Date: Fri, 13 May 2022 21:52:06 +0800 Subject: [PATCH] [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu code style (#949) Co-authored-by: Jiatong --- .../csrc/scaled_masked_softmax_cuda.cu | 81 ++++++++----------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu index d2370e9f3..41781ebc7 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu @@ -2,12 +2,13 @@ * with minor changes. */ #include +#include #include -#include #include #include -#include +#include #include + #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -15,17 +16,15 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); } - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); const int attn_heads = input.size(1); @@ -38,10 +37,10 @@ torch::Tensor fwd_cuda( TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + torch::Tensor softmax_results = torch::empty( + {batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -49,31 +48,23 @@ torch::Tensor fwd_cuda( void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", + input.scalar_type(), "dispatch_scaled_masked_softmax_forward", dispatch_scaled_masked_softmax_forward( reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches); - ); + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); return softmax_results; } -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] const int batches = output_grads.size(0); const int attn_heads = output_grads.size(1); const int query_seq_len = output_grads.size(2); @@ -81,24 +72,18 @@ torch::Tensor bwd_cuda( void* output_grads_ptr = static_cast(output_grads.data_ptr()); - //Softmax Grad + // Softmax Grad DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", + output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads); - ); - - //backward pass is completely in-place + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); + + // backward pass is completely in-place return output_grads; } -} -} -} +} // namespace scaled_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn