diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
index 9dbb63476..62c56e6f7 100644
--- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
+++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
@@ -2,12 +2,13 @@
  * with minor changes. */
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
-#include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
 #include <torch/extension.h>
+
 #include "scaled_upper_triang_masked_softmax.h"
 #include "type_shim.h"
 
@@ -15,18 +16,15 @@ namespace multihead_attn {
 namespace fused_softmax {
 namespace scaled_upper_triang_masked_softmax {
 
-torch::Tensor fwd_cuda(
-    torch::Tensor const& input,
-    float scale_factor)
-{
+torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
   // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
   const int attn_batches = input.size(0);
   const int seq_len = input.size(1);
   TORCH_INTERNAL_ASSERT(seq_len <= 2048);
 
-  // Output 
+  // Output
   auto act_options = input.options().requires_grad(false);
-  torch::Tensor softmax_results = 
+  torch::Tensor softmax_results =
       torch::empty({attn_batches, seq_len, seq_len}, act_options);
 
   // Softmax Intermediate Result Ptr
@@ -36,50 +34,42 @@ torch::Tensor fwd_cuda(
   DISPATCH_HALF_AND_BFLOAT(
       input.scalar_type(),
       "dispatch_scaled_upper_triang_masked_softmax_forward",
-      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
-          reinterpret_cast<scalar_t*>(softmax_results_ptr),
-          reinterpret_cast<const scalar_t*>(input_ptr),
-          scale_factor,
-          seq_len,
-          seq_len,
-          attn_batches);
-      );
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
+                                                          float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
+          seq_len, attn_batches););
 
   return softmax_results;
 }
-
-torch::Tensor bwd_cuda(
-    torch::Tensor const& output_grads_,
-    torch::Tensor const& softmax_results_,
-    float scale_factor) {
-
+torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
+                       torch::Tensor const& softmax_results_,
+                       float scale_factor) {
   auto output_grads = output_grads_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
 
-  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  // output grads is a 3d tensor with dimensions [attn_batches, seq_len,
+  // seq_len]
   const int attn_batches = output_grads.size(0);
   const int seq_len = output_grads.size(1);
   TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
 
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
 
-  //Softmax Grad
+  // Softmax Grad
   DISPATCH_HALF_AND_BFLOAT(
       output_grads_.scalar_type(),
       "dispatch_scaled_upper_triang_masked_softmax_backward",
-      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
-          reinterpret_cast<scalar_t*>(output_grads_ptr),
-          reinterpret_cast<scalar_t*>(output_grads_ptr),
-          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
-          scale_factor,
-          seq_len,
-          seq_len,
-          attn_batches);
-      );
-
-  //backward pass is completely in-place
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
+                                                           float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor, seq_len, seq_len, attn_batches););
+
+  // backward pass is completely in-place
   return output_grads;
 }
-}
-}
-}
+}  // namespace scaled_upper_triang_masked_softmax
+}  // namespace fused_softmax
+}  // namespace multihead_attn
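Note on the dispatch pattern appearing in both hunks: DISPATCH_HALF_AND_BFLOAT (presumably defined in the included type_shim.h) switches on the tensor's runtime dtype and defines a concrete scalar_t for its body, which is why the kernel-launch expression passed as the macro's third argument can reinterpret_cast the raw pointers to scalar_t* for either fp16 or bf16. The macro below is only an illustrative sketch of that pattern under that assumption, not the repository's exact definition:

// Illustrative sketch of a half/bfloat16 dispatch macro in the style of
// DISPATCH_HALF_AND_BFLOAT (assumed shape; see type_shim.h for the real one).
#include <ATen/ATen.h>

#define SKETCH_DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)               \
  switch (TYPE) {                                                      \
    case at::ScalarType::Half: {                                       \
      using scalar_t = at::Half; /* body is expanded with fp16 */      \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::BFloat16: {                                   \
      using scalar_t = at::BFloat16; /* body is expanded with bf16 */  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
  }

// Invocation shape (mirrors the diff): the third argument is the entire
// kernel-launch expression, evaluated with the selected scalar_t, e.g.
// SKETCH_DISPATCH_HALF_AND_BFLOAT(
//     input.scalar_type(), "softmax_forward",
//     launch_kernel<scalar_t>(reinterpret_cast<scalar_t*>(out_ptr), ...););

Because the whole launch expression is itself a macro argument, the extra ");" after "attn_batches);" in the reformatted lines closes the macro invocation, not a stray statement.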