[NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu code style (#949)

Co-authored-by: Jiatong <jiatong.han@u.nus.edu>
pull/997/head
JT.Han 3 years ago committed by binmakeswell
parent 72c71b67ec
commit c3e423c8be

@ -2,12 +2,13 @@
* with minor changes. */ * with minor changes. */
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h> #include <cuda_runtime.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "scaled_masked_softmax.h" #include "scaled_masked_softmax.h"
#include "type_shim.h" #include "type_shim.h"
@ -15,17 +16,15 @@ namespace multihead_attn {
namespace fused_softmax { namespace fused_softmax {
namespace scaled_masked_softmax { namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
} }
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
torch::Tensor fwd_cuda( float scale_factor) {
torch::Tensor const& input, // input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
torch::Tensor const& mask, // seq_len]
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0); const int batches = input.size(0);
const int pad_batches = mask.size(0); const int pad_batches = mask.size(0);
const int attn_heads = input.size(1); const int attn_heads = input.size(1);
@ -38,10 +37,10 @@ torch::Tensor fwd_cuda(
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output // Output
auto act_options = input.options().requires_grad(false); auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results = torch::Tensor softmax_results = torch::empty(
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); {batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr // Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr()); void* input_ptr = static_cast<void*>(input.data_ptr());
@ -49,31 +48,23 @@ torch::Tensor fwd_cuda(
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT( DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(), input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>( dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr), reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr), reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr), reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
scale_factor, query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results; return softmax_results;
} }
torch::Tensor bwd_cuda( torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_,
torch::Tensor const& softmax_results_, float scale_factor) {
float scale_factor) {
auto output_grads = output_grads_.contiguous(); auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous(); auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = output_grads.size(0); const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1); const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2); const int query_seq_len = output_grads.size(2);
@ -81,24 +72,18 @@ torch::Tensor bwd_cuda(
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr()); void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad // Softmax Grad
DISPATCH_HALF_AND_BFLOAT( DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(), output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>( dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr), reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr), reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()), reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
query_seq_len,
key_seq_len, // backward pass is completely in-place
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads; return output_grads;
} }
} } // namespace scaled_masked_softmax
} } // namespace fused_softmax
} } // namespace multihead_attn

Loading…
Cancel
Save