2021-12-21 04:19:52 +00:00
|
|
|
/* This code is adapted from NVIDIA Megatron:
|
|
|
|
* with minor changes. */
|
|
|
|
|
|
|
|
#include <ATen/ATen.h>
|
2022-05-13 13:52:06 +00:00
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
2021-12-21 04:19:52 +00:00
|
|
|
#include <cuda.h>
|
|
|
|
#include <cuda_fp16.h>
|
|
|
|
#include <cuda_profiler_api.h>
|
2022-05-13 13:52:06 +00:00
|
|
|
#include <cuda_runtime.h>
|
2021-12-21 04:19:52 +00:00
|
|
|
#include <torch/extension.h>
|
2022-05-13 13:52:06 +00:00
|
|
|
|
2021-12-21 04:19:52 +00:00
|
|
|
#include "scaled_masked_softmax.h"
|
|
|
|
#include "type_shim.h"
|
|
|
|
|
|
|
|
namespace multihead_attn {
|
|
|
|
namespace fused_softmax {
|
|
|
|
namespace scaled_masked_softmax {
|
|
|
|
|
2022-05-13 13:52:06 +00:00
|
|
|
// Host-visible wrapper that forwards its arguments, in the same order, to the
// get_batch_per_block helper (presumably from scaled_masked_softmax.h) and
// returns that helper's result unchanged.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads) {
  const int per_block =
      get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
  return per_block;
}
|
|
|
|
|
2022-05-13 13:52:06 +00:00
|
|
|
/*
 * Forward pass of the fused scale + mask + softmax operator.
 *
 * input:        4D attention-score tensor laid out as
 *               [batches, attn_heads, query_seq_len, key_seq_len]; the element
 *               type is selected via DISPATCH_HALF_AND_BFLOAT.
 * mask:         4D mask of shape [pad_batches, 1, query_seq_len, key_seq_len]
 *               where pad_batches is 1 or `batches`. Its storage is read as
 *               uint8_t, so it must have a 1-byte element type.
 * scale_factor: scaling factor forwarded unchanged to the kernel dispatcher.
 *
 * Returns a newly allocated tensor of shape
 * [batches, attn_heads, query_seq_len, key_seq_len] with the options of
 * `input` (requires_grad disabled) holding the softmax results.
 *
 * Throws c10::Error (RuntimeError in Python) when a shape/dtype precondition
 * is violated.
 */
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor) {
  TORCH_CHECK(input.dim() == 4,
              "scaled_masked_softmax forward: expected a 4D input, got ",
              input.dim(), "D");
  TORCH_CHECK(mask.dim() == 4,
              "scaled_masked_softmax forward: expected a 4D mask, got ",
              mask.dim(), "D");
  // The mask pointer is reinterpreted as uint8_t below; a wider element type
  // would be silently misread byte by byte.
  TORCH_CHECK(mask.element_size() == 1,
              "scaled_masked_softmax forward: mask must have a 1-byte dtype "
              "(e.g. bool or uint8), got ",
              mask.scalar_type());

  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);

  TORCH_CHECK(key_seq_len <= 2048,
              "scaled_masked_softmax forward: key_seq_len must be <= 2048, "
              "got ",
              key_seq_len);
  TORCH_CHECK(query_seq_len > 1,
              "scaled_masked_softmax forward: query_seq_len must be > 1, got ",
              query_seq_len);
  TORCH_CHECK(pad_batches == 1 || pad_batches == batches,
              "scaled_masked_softmax forward: mask.size(0) must be 1 or ",
              batches, ", got ", pad_batches);
  TORCH_CHECK(mask.size(1) == 1,
              "scaled_masked_softmax forward: mask.size(1) must be 1, got ",
              mask.size(1));
  TORCH_CHECK(mask.size(2) == query_seq_len,
              "scaled_masked_softmax forward: mask.size(2) must equal "
              "query_seq_len (",
              query_seq_len, "), got ", mask.size(2));
  TORCH_CHECK(mask.size(3) == key_seq_len,
              "scaled_masked_softmax forward: mask.size(3) must equal "
              "key_seq_len (",
              key_seq_len, "), got ", mask.size(3));

  // Only raw pointers and sizes (no strides) are handed to the kernel, so both
  // operands must be densely packed. contiguous() returns the tensor itself
  // when it already is, so this costs nothing in the common case.
  const torch::Tensor input_c = input.contiguous();
  const torch::Tensor mask_c = mask.contiguous();

  // Output
  auto act_options = input_c.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input_c.data_ptr());
  void* mask_ptr = static_cast<void*>(mask_c.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input_c.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
  return softmax_results;
}
|
|
|
|
|
2022-05-13 13:52:06 +00:00
|
|
|
/*
 * Backward pass of the fused scale + mask + softmax operator.
 *
 * output_grads_:    4D gradient tensor laid out as
 *                   [batches, attn_heads, query_seq_len, key_seq_len].
 * softmax_results_: softmax output to differentiate through; it must have the
 *                   same shape and dtype as output_grads_, because its data is
 *                   reinterpreted with output_grads_'s element type and sizes.
 * scale_factor:     scaling factor forwarded unchanged to the kernel
 *                   dispatcher (presumably the forward's scale — confirm at
 *                   the call site).
 *
 * Returns the gradient tensor. The kernel writes its result into the
 * contiguous version of output_grads_. Because contiguous() returns the same
 * tensor when output_grads_ is already contiguous, the caller's gradient
 * tensor may be overwritten in place.
 *
 * Throws c10::Error (RuntimeError in Python) on a shape/dtype mismatch.
 */
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  TORCH_CHECK(output_grads_.dim() == 4,
              "scaled_masked_softmax backward: expected 4D output_grads, got ",
              output_grads_.dim(), "D");
  // The kernel reads softmax_results using output_grads' sizes and element
  // type; any mismatch would read out of bounds or misinterpret the data.
  TORCH_CHECK(softmax_results_.sizes() == output_grads_.sizes(),
              "scaled_masked_softmax backward: softmax_results shape ",
              softmax_results_.sizes(), " does not match output_grads shape ",
              output_grads_.sizes());
  TORCH_CHECK(softmax_results_.scalar_type() == output_grads_.scalar_type(),
              "scaled_masked_softmax backward: softmax_results dtype ",
              softmax_results_.scalar_type(),
              " does not match output_grads dtype ",
              output_grads_.scalar_type());

  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions
  // [batches, attn_heads, query_seq_len, key_seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad: output_grads_ptr is passed as both destination and
  // incoming gradient, so the result overwrites output_grads.
  DISPATCH_HALF_AND_BFLOAT(
      output_grads.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););

  // backward pass is completely in-place
  return output_grads;
}
|
2022-05-13 13:52:06 +00:00
|
|
|
} // namespace scaled_masked_softmax
|
|
|
|
} // namespace fused_softmax
|
|
|
|
} // namespace multihead_attn
|