mirror of https://github.com/hpcaitech/ColossalAI
Recover kernal files
parent
e83b2ce853
commit
7696cead8d
|
@ -1,10 +1,11 @@
|
|||
#include <cooperative_groups.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
curandStatePhilox4_32_10_t *curandstate;
|
||||
|
|
|
@ -3,11 +3,10 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdexcept>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#define MAX_THREADS 1024
|
||||
#define WARP_SIZE 32
|
||||
|
||||
|
@ -133,9 +132,8 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
|
|||
}
|
||||
|
||||
/* Convert 4-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
|
||||
int id4, int dim2, int dim3,
|
||||
int dim4) {
|
||||
__forceinline__ __host__ __device__ int
|
||||
flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) {
|
||||
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
|
||||
int res = id4;
|
||||
|
||||
|
@ -203,9 +201,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
|
|||
}
|
||||
|
||||
/* Convert vector index to 6-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_6dim(
|
||||
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
|
||||
int *id1, int *id2, int *id3, int *id4, int *id5) {
|
||||
__forceinline__ __host__ __device__ void
|
||||
decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
|
||||
int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) {
|
||||
*id5 = src % dim5;
|
||||
src /= dim5;
|
||||
|
||||
|
@ -223,11 +221,9 @@ __forceinline__ __host__ __device__ void decompose_6dim(
|
|||
}
|
||||
|
||||
/* Convert vector index to 5-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
|
||||
int dim2, int dim3,
|
||||
int dim4, int *id0,
|
||||
int *id1, int *id2,
|
||||
int *id3, int *id4) {
|
||||
__forceinline__ __host__ __device__ void
|
||||
decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0,
|
||||
int *id1, int *id2, int *id3, int *id4) {
|
||||
*id4 = src % dim4;
|
||||
src /= dim4;
|
||||
|
||||
|
@ -257,9 +253,8 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
|
|||
}
|
||||
|
||||
/* Convert vector index to 3-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
|
||||
int dim2, int *id0,
|
||||
int *id1, int *id2) {
|
||||
__forceinline__ __host__ __device__ void
|
||||
decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) {
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
|
|
|
@ -135,10 +135,9 @@ __global__ void bias_add_transform_20314(T *output, const T *input,
|
|||
const T *bias, int dim_3, int dim_4);
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<float>(float *output,
|
||||
const float *input,
|
||||
const float *bias, int dim_3,
|
||||
int dim_4) {
|
||||
__global__ void
|
||||
bias_add_transform_20314<float>(float *output, const float *input,
|
||||
const float *bias, int dim_3, int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
|
@ -174,10 +173,9 @@ __global__ void bias_add_transform_20314<float>(float *output,
|
|||
}
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<__half>(__half *output,
|
||||
const __half *input,
|
||||
const __half *bias, int dim_3,
|
||||
int dim_4) {
|
||||
__global__ void
|
||||
bias_add_transform_20314<__half>(__half *output, const __half *input,
|
||||
const __half *bias, int dim_3, int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <assert.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// #include <iostream>
|
||||
|
||||
// This header is the one-stop shop for all your multi-tensor apply needs.
|
||||
|
@ -18,108 +17,117 @@ constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
|||
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
|
||||
template <int n>
|
||||
struct TensorListMetadata {
|
||||
void *addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
|
||||
// full int.
|
||||
int start_tensor_this_launch;
|
||||
struct TensorListMetadata
|
||||
{
|
||||
void *addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
__global__ void multi_tensor_apply_kernel(int chunk_size,
|
||||
volatile int *noop_flag, T tl,
|
||||
U callable, ArgTypes... args) {
|
||||
// Hand the chunk information to the user-supplied functor to process however
|
||||
// it likes.
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
__global__ void multi_tensor_apply_kernel(
|
||||
int chunk_size,
|
||||
volatile int *noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
// Hand the chunk information to the user-supplied functor to process however it likes.
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
}
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
int block_size, int chunk_size, const at::Tensor &noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
|
||||
ArgTypes... args) {
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size();
|
||||
l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0,
|
||||
"Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++) {
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
int block_size,
|
||||
int chunk_size,
|
||||
const at::Tensor &noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>> &tensor_lists,
|
||||
T callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++)
|
||||
{
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
#ifdef VERSION_GE_1_5
|
||||
contiguous_memory =
|
||||
(contiguous_memory ||
|
||||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
#endif
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
|
||||
"A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
|
||||
"Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk) {
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1) {
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
} else {
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++)
|
||||
{
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
|
||||
{
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk)
|
||||
{
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
|
||||
chunk_size,
|
||||
noop_flag.DATA_PTR<int>(),
|
||||
tl,
|
||||
callable,
|
||||
args...);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1)
|
||||
{
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,68 +3,82 @@
|
|||
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
||||
float scale_factor);
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
|
||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads);
|
||||
int get_batch_per_block_cuda(
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads);
|
||||
|
||||
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
|
||||
float scale_factor) {
|
||||
torch::Tensor fwd(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor) {
|
||||
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
|
||||
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
|
||||
|
||||
return fwd_cuda(input, mask, scale_factor);
|
||||
}
|
||||
|
||||
torch::Tensor bwd(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results, float scale_factor) {
|
||||
torch::Tensor bwd(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor) {
|
||||
|
||||
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
|
||||
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
|
||||
|
||||
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||
}
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
|
||||
attn_heads);
|
||||
int get_batch_per_block(
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads) {
|
||||
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
|
||||
}
|
||||
|
||||
} // end namespace scaled_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
} // end namespace scaled_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
m.def("forward",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
|
||||
m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
m.def("backward",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
|
||||
m.def("get_batch_per_block",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::
|
||||
get_batch_per_block,
|
||||
"Return Batch per block size.");
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
|
||||
"Return Batch per block size."
|
||||
);
|
||||
}
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -17,53 +17,37 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
|
|||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*((half2 *)dst) = *((half2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template<typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
|
@ -71,468 +55,438 @@ struct Max {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T
|
||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||
unsigned int mask = 0xffffffff) {
|
||||
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
||||
{
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||
template <typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional
|
||||
* features 1) input scaling 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
* Extended softmax (from native aten pytorch) with following additional features
|
||||
* 1) input scaling
|
||||
* 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_forward(
|
||||
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
||||
int micro_batch_size, int element_count, int pad_batches) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const acc_t scale,
|
||||
int micro_batch_size,
|
||||
int element_count,
|
||||
int pad_batches)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch =
|
||||
(blockDim.y *
|
||||
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
||||
threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch =
|
||||
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i * element_count + it * WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i*element_count+it*WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] =
|
||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH]{0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
acc_t sum[WARP_BATCH] { 0.0f };
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_backward(
|
||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||
int micro_batch_size, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
output_t *gradInput,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
acc_t scale,
|
||||
int micro_batch_size,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset =
|
||||
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_output, output + i * element_count + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
grad_reg[i][it + element] =
|
||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] =
|
||||
(output_t)(scale * (grad_reg[i][it + element] -
|
||||
output_reg[i][it + element] * sum[i]));
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
gradInput + i * element_count + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end of anonymous namespace
|
||||
} // end of anonymous namespace
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
constexpr int threads_per_block = 128;
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
|
||||
return batches_per_block;
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const input_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads,
|
||||
int pad_batches) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
|
||||
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return batches_per_block;
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const input_t scale,
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads,
|
||||
int pad_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = batch_count / batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
|
||||
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_backward(
|
||||
output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = batch_count/batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -17,78 +16,53 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
|
|||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*((half2 *)dst) = *((half2 *)src);
|
||||
}
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_zero_vector(Datatype *dst);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst) {
|
||||
*dst = 0.0;
|
||||
}
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst) {
|
||||
*((float2 *)dst) = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) {
|
||||
*dst = 0.0;
|
||||
}
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
|
||||
*((float2 *)dst) = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
|
||||
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template<typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template<typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
|
@ -96,505 +70,431 @@ struct Max {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T
|
||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||
unsigned int mask = 0xffffffff) {
|
||||
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
||||
{
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||
template <typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional
|
||||
* features 1) input scaling 2) Implicit time (diagonal masking)
|
||||
* Extended softmax (from native aten pytorch) with following additional features
|
||||
* 1) input scaling
|
||||
* 2) Implicit time (diagonal masking)
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
|
||||
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
|
||||
int stride, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const acc_t scale,
|
||||
int micro_batch_size,
|
||||
int stride,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch =
|
||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||
blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
int warp_iteration_limit =
|
||||
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
|
||||
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_data, src + i * element_count * stride + it * WARP_SIZE);
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if ((element_index + element) < batch_element_count) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if ((element_index + element) < batch_element_count) {
|
||||
elements[i][it+element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] =
|
||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH]{0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
if (it < warp_iteration_limit) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < local_seq) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < local_seq) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
} else {
|
||||
out[element] = 0;
|
||||
}
|
||||
acc_t sum[WARP_BATCH] { 0.0f };
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
if (it < warp_iteration_limit) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < local_seq) {
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < local_seq) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
} else {
|
||||
out[element] = 0;
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
|
||||
} else if (element_index < element_count) {
|
||||
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count * stride + it * WARP_SIZE, out);
|
||||
} else if (element_index < element_count) {
|
||||
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count * stride + it * WARP_SIZE);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
|
||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||
int micro_batch_size, int stride, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
output_t *gradInput,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
acc_t scale,
|
||||
int micro_batch_size,
|
||||
int stride,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch =
|
||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||
blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
grad_reg[i][it + element] =
|
||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] =
|
||||
(output_t)(scale * (grad_reg[i][it + element] -
|
||||
output_reg[i][it + element] * sum[i]));
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end of anonymous namespace
|
||||
} // end of anonymous namespace
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_forward(
|
||||
output_t *dst, const input_t *src, const input_t scale,
|
||||
int softmax_elements, int softmax_elements_stride, int attn_batches) {
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t scale,
|
||||
int softmax_elements,
|
||||
int softmax_elements_stride,
|
||||
int attn_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_backward(
|
||||
output_t *grad_input, input_t *grad, const input_t *output,
|
||||
const acc_t scale, int softmax_elements, int softmax_elements_stride,
|
||||
int attn_batches) {
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int softmax_elements,
|
||||
int softmax_elements_stride,
|
||||
int attn_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
// use 128 threads per block to maximimize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,63 +1,76 @@
|
|||
#include <ATen/ATen.h>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||
switch(TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
switch (TYPEIN) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_in = float; \
|
||||
switch (TYPEOUT) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_in = at::Half; \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t_in = at::BFloat16; \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||
}
|
||||
switch(TYPEIN) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_in = float; \
|
||||
switch(TYPEOUT) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_in = at::Half; \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_in = at::BFloat16; \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||
}
|
||||
|
||||
// Forward/backward compatiblity hack around
|
||||
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
|
||||
|
@ -68,191 +81,222 @@
|
|||
// TypeShim(const at::Type& type) : payload(type) {}
|
||||
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
|
||||
// operator const at::Type&(){ return payload; };
|
||||
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
||||
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
||||
// //operator at::ScalarType(){ return payload.; };
|
||||
// };
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Byte: { \
|
||||
using scalar_t_##LEVEL = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Byte: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Double: { \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Double: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Double: { \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Double: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
||||
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Float && \
|
||||
PTYPE == at::ScalarType::Half) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Half && \
|
||||
PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||
"'"); \
|
||||
}
|
||||
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
||||
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
|
||||
} \
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes(
|
||||
T *x, T val, int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
|
||||
T val,
|
||||
int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize =
|
||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64) {
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32) {
|
||||
if (blockSize >= 64)
|
||||
final = x[tid] + x[tid + 32];
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
{
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||
}
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
||||
{
|
||||
if (tid < i)
|
||||
x[tid] = x[tid] + x[tid + i];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (share_result) {
|
||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
T final;
|
||||
|
||||
return final;
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
final = x[tid] + x[tid + 32];
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||
}
|
||||
|
||||
if (share_result)
|
||||
{
|
||||
if (tid < lanes)
|
||||
x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
||||
T *x, T val, int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
|
||||
T val,
|
||||
int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize =
|
||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64) {
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32) {
|
||||
if (blockSize >= 64)
|
||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
{
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final =
|
||||
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||
}
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
||||
{
|
||||
if (tid < i)
|
||||
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (share_result) {
|
||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
T final;
|
||||
|
||||
return final;
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||
}
|
||||
|
||||
if (share_result)
|
||||
{
|
||||
if (tid < lanes)
|
||||
x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
Loading…
Reference in New Issue