mirror of https://github.com/hpcaitech/ColossalAI
add colossalai kernel module (#55)
parent 648f806315
commit 5c3843dc98
@@ -0,0 +1,8 @@
from .jit.bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .jit.bias_gelu import bias_gelu_impl
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention

__all__ = [
    "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
    "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
]
@@ -0,0 +1,17 @@
from .builder import _build_cuda_native_kernel

CUDA_NATIVE_KERNEL_BUILD = False


def build_cuda_native_kernel():
    global CUDA_NATIVE_KERNEL_BUILD
    if not CUDA_NATIVE_KERNEL_BUILD:
        _build_cuda_native_kernel()
        CUDA_NATIVE_KERNEL_BUILD = True


build_cuda_native_kernel()

from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
@@ -0,0 +1,114 @@
|
|||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
# Setting this param to a list has a problem of generating different
# compilation commands (with a different order of architectures) and
# leading to recompilation of fused kernels. Set it to an empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below.
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
||||
|
||||
|
||||
def _build_cuda_native_kernel():
|
||||
|
||||
# Check if cuda 11 is installed for compute capability 8.0
|
||||
cc_flag = []
|
||||
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11:
|
||||
cc_flag.append('-gencode')
|
||||
cc_flag.append('arch=compute_80,code=sm_80')
|
||||
|
||||
# Build path
|
||||
basepath = pathlib.Path(__file__).parent.absolute()
|
||||
srcpath = basepath / 'csrc'
|
||||
buildpath = basepath / 'build'
|
||||
_create_build_dir(buildpath)
|
||||
|
||||
# Helper function to build the kernels.
|
||||
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
||||
return cpp_extension.load(
|
||||
name=name,
|
||||
sources=sources,
|
||||
build_directory=buildpath,
|
||||
extra_cflags=[
|
||||
'-O3',
|
||||
],
|
||||
extra_include_paths=[str(srcpath / 'kernels' / 'include')],
|
||||
extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] +
|
||||
extra_cuda_flags + cc_flag,
|
||||
verbose=False)
|
||||
|
||||
# ==============
|
||||
# Fused softmax.
|
||||
# ==============
|
||||
|
||||
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'--expt-relaxed-constexpr',
|
||||
'--expt-extended-lambda']
|
||||
|
||||
# Upper triangular softmax.
|
||||
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
|
||||
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
|
||||
colossal_scaled_upper_triang_masked_softmax = _cpp_extention_load_helper(
|
||||
"colossal_scaled_upper_triang_masked_softmax",
|
||||
sources, extra_cuda_flags)
|
||||
|
||||
# Masked softmax.
|
||||
sources=[srcpath / 'scaled_masked_softmax.cpp',
|
||||
srcpath / 'scaled_masked_softmax_cuda.cu']
|
||||
colossal_scaled_masked_softmax = _cpp_extention_load_helper(
|
||||
"colossal_scaled_masked_softmax", sources, extra_cuda_flags)
|
||||
|
||||
# =================================
|
||||
# Mixed precision fused layer norm.
|
||||
# =================================
|
||||
|
||||
extra_cuda_flags = ['-maxrregcount=50']
|
||||
sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
|
||||
colossal_layer_norm_cuda = _cpp_extention_load_helper("colossal_layer_norm_cuda", sources,
|
||||
extra_cuda_flags)
|
||||
|
||||
# ==========================================
|
||||
# Mixed precision Transformer Encoder Layer.
|
||||
# ==========================================
|
||||
|
||||
extra_cuda_flags = ['-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__',
|
||||
'-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
'-DTHRUST_IGNORE_CUB_VERSION_CHECK']
|
||||
|
||||
sources = [srcpath / 'multihead_attention_1d.cpp']
|
||||
kernel_sources = ["cublas_wrappers.cu",
|
||||
"transform_kernels.cu",
|
||||
"dropout_kernels.cu",
|
||||
"normalize_kernels.cu",
|
||||
"softmax_kernels.cu",
|
||||
"general_kernels.cu",
|
||||
"cuda_util.cu"]
|
||||
sources += [(srcpath / 'kernels' / cu_file) for cu_file in kernel_sources]
|
||||
colossal_multihead_attention = _cpp_extention_load_helper("colossal_multihead_attention", sources,
|
||||
extra_cuda_flags)
|
||||
|
||||
|
||||
def _get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def _create_build_dir(buildpath):
|
||||
try:
|
||||
os.mkdir(buildpath)
|
||||
except OSError:
|
||||
if not os.path.isdir(buildpath):
|
||||
print(f"Creation of the build directory {buildpath} failed")
|
|
@@ -0,0 +1,13 @@
/* This code is from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
@@ -0,0 +1,191 @@
|
|||
#include "block_reduce.h"
|
||||
#include "cuda_util.h"
|
||||
#include "kernels.h"
|
||||
#include "ls_cub.cuh"
|
||||
|
||||
ls::cub::CachingDeviceAllocator g_allocator(true);
|
||||
|
||||
template <typename T>
|
||||
__global__ void ls_cross_entropy_fw_kernel(
|
||||
const T *__restrict__ inputs, const int *__restrict__ targets,
|
||||
float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
|
||||
const int padding_idx, const float epsilon, const int vocab_size) {
|
||||
/* step1: compute each thread's max_logit and sum_exp_logit, store in
|
||||
* max_input, sum_exp_logit */
|
||||
const int block_start = blockIdx.x * vocab_size;
|
||||
const int left_idx = block_start + threadIdx.x;
|
||||
const int right_idx = (blockIdx.x + 1) * vocab_size;
|
||||
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
|
||||
float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
|
||||
int target_tid = targets[blockIdx.x];
|
||||
|
||||
if (target_tid == padding_idx) {
|
||||
if (threadIdx.x == 0) {
|
||||
nll_loss_outputs[blockIdx.x] = 0.f;
|
||||
outputs[blockIdx.x] = 0.f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
|
||||
}
|
||||
blockReduce<ReduceType::kMax, 1>(max_input);
|
||||
__shared__ float s_max_input;
|
||||
if (threadIdx.x == 0) {
|
||||
s_max_input = max_input[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float logit = static_cast<float>(inputs[i]) - s_max_input;
|
||||
sum_logits[0] += logit;
|
||||
sum_logits[1] += expf(logit);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 2>(sum_logits);
|
||||
__shared__ float s_sum_logit;
|
||||
__shared__ float s_sum_exp;
|
||||
if (threadIdx.x == 0) {
|
||||
s_sum_logit = sum_logits[0];
|
||||
s_sum_exp = sum_logits[1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float eps_i = epsilon / (vocab_size - 1);
|
||||
if (threadIdx.x == 0) {
|
||||
// neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
|
||||
float nll_loss = logf(s_sum_exp) -
|
||||
static_cast<float>(inputs[block_start + target_tid]) +
|
||||
s_max_input;
|
||||
nll_loss_outputs[blockIdx.x] = nll_loss;
|
||||
float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
|
||||
outputs[blockIdx.x] =
|
||||
(1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ls_cross_entropy_bw_kernel(
|
||||
const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
|
||||
const int *__restrict__ targets, T *__restrict__ grad_inputs,
|
||||
const int padding_idx, const float epsilon, const int vocab_size) {
|
||||
/* step1: compute each thread's max_logit and sum_exp_logit, store in
|
||||
* max_input, sum_exp_logit */
|
||||
const int block_start = blockIdx.x * vocab_size;
|
||||
const int left_idx = block_start + threadIdx.x;
|
||||
const int right_idx = (blockIdx.x + 1) * vocab_size;
|
||||
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
|
||||
float sum_logits[1] = {0.f};
|
||||
const float grad_out = static_cast<float>(grad_outputs[0]);
|
||||
int target_tid = targets[blockIdx.x];
|
||||
|
||||
if (target_tid == padding_idx) {
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
grad_inputs[i] = 0.f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
|
||||
}
|
||||
blockReduce<ReduceType::kMax, 1>(max_input);
|
||||
__shared__ float s_max_input;
|
||||
if (threadIdx.x == 0) {
|
||||
s_max_input = max_input[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float logit = static_cast<float>(inputs[i]) - s_max_input;
|
||||
sum_logits[0] += expf(logit);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 1>(sum_logits);
|
||||
__shared__ float s_sum_exp;
|
||||
if (threadIdx.x == 0) {
|
||||
s_sum_exp = sum_logits[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float eps_i = epsilon / (vocab_size - 1);
|
||||
float nll_weight = 1.0 - epsilon - eps_i;
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
|
||||
float grad = 0;
|
||||
grad += (vocab_size * prob - 1) * eps_i;
|
||||
grad += prob * nll_weight;
|
||||
if ((i - block_start) == target_tid) {
|
||||
grad -= nll_weight;
|
||||
}
|
||||
grad_inputs[i] = grad_out * grad;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
|
||||
float *outputs_ptr, float *nll_loss_ptr,
|
||||
float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size,
|
||||
const int seq_len, const int vocab_size,
|
||||
cudaStream_t stream) {
|
||||
int grid_dim = batch_size * seq_len;
|
||||
float *nll_loss_buffer = loss_buffer + grid_dim;
|
||||
ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
|
||||
inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
|
||||
epsilon, vocab_size);
|
||||
|
||||
int num_items = grid_dim;
|
||||
void *d_temp_storage = NULL;
|
||||
size_t temp_storage_bytes = 0;
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
loss_buffer, outputs_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(
|
||||
g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
loss_buffer, outputs_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
nll_loss_buffer, nll_loss_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
|
||||
}
|
||||
|
||||
template void launch_cross_entropy_fw<float>(
|
||||
const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template void launch_cross_entropy_fw<__half>(
|
||||
const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr,
|
||||
const int padding_idx, const float epsilon,
|
||||
const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream) {
|
||||
int grid_dim = batch_size * seq_len;
|
||||
ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
|
||||
grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
|
||||
epsilon, vocab_size);
|
||||
}
|
||||
|
||||
template void launch_cross_entropy_bw<float>(
|
||||
const float *grad_outputs_ptr, const float *inputs_ptr,
|
||||
const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template void launch_cross_entropy_bw<__half>(
|
||||
const float *grad_outputs_ptr, const __half *inputs_ptr,
|
||||
const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
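A minimal host-side sketch of driving launch_cross_entropy_fw above; the device pointers, padding_idx and epsilon values are hypothetical. The one detail worth noting is that loss_buffer must hold 2 * batch_size * seq_len floats, since the kernel writes the per-token smoothed loss into the first half and the per-token NLL loss into the second half (nll_loss_buffer = loss_buffer + grid_dim).

// Hypothetical usage sketch, not part of the diff.
#include "cuda_util.h"
#include "kernels.h"

void run_cross_entropy_fw(const float *d_logits, const int *d_targets,
                          float *d_loss, float *d_nll_loss, int batch_size,
                          int seq_len, int vocab_size, cudaStream_t stream) {
  const int tokens = batch_size * seq_len;
  // First half: per-token smoothed loss; second half: per-token NLL loss.
  float *loss_buffer = cuda_malloc<float>(2 * tokens);
  launch_cross_entropy_fw(d_logits, d_targets, d_loss, d_nll_loss, loss_buffer,
                          /*padding_idx=*/0, /*epsilon=*/0.1f, batch_size,
                          seq_len, vocab_size, stream);
  CHECK_GPU_ERROR(cudaStreamSynchronize(stream));
  cuda_free(loss_buffer);
}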
|
@@ -0,0 +1,87 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
*/
|
||||
#include "cublas_wrappers.h"
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const float *A,
|
||||
const float *B, float *C, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status =
|
||||
cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
|
||||
(const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
|
||||
(const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
|
||||
(const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const __half *A,
|
||||
const __half *B, __half *C, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmEx(
|
||||
handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
|
||||
CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
|
||||
(transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
|
||||
CUDA_R_16F, m, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const float *A, const float *B, float *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(
|
||||
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
|
||||
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
|
||||
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
|
||||
batch, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
|
||||
"error: %d) \n",
|
||||
batch, m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const __half *A, const __half *B, __half *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(
|
||||
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
|
||||
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
|
||||
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
|
||||
batch, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@@ -0,0 +1,169 @@
|
|||
#include <thrust/device_vector.h>
|
||||
#include <thrust/reduce.h>
|
||||
|
||||
#include "cuda_util.h"
|
||||
|
||||
/* GPU function guard */
|
||||
std::string _cudaGetErrorString(cudaError_t error) {
|
||||
return cudaGetErrorString(error);
|
||||
}
|
||||
|
||||
std::string _cudaGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "CUBLAS_STATUS_SUCCESS";
|
||||
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "CUBLAS_UNKNOW";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check_gpu_error(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
if (result) {
|
||||
    throw std::runtime_error(std::string("[CUDA][ERROR] ") + file + "(" +
|
||||
std::to_string(line) +
|
||||
"): " + (_cudaGetErrorString(result)) + "\n");
|
||||
}
|
||||
}
|
||||
|
||||
template void check_gpu_error<cudaError_t>(cudaError_t result,
|
||||
char const *const func,
|
||||
const char *const file,
|
||||
int const line);
|
||||
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
|
||||
char const *const func,
|
||||
const char *const file,
|
||||
int const line);
|
||||
|
||||
template <typename T>
|
||||
void print_vec(const T *outv, std::string outn, int num_output_ele) {
|
||||
std::cout << outn << ": ";
|
||||
std::vector<T> hout(num_output_ele, (T)0);
|
||||
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
|
||||
cudaMemcpyDeviceToHost);
|
||||
for (int i = 0; i < num_output_ele; i++) {
|
||||
std::cout << hout[i] << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <>
|
||||
void print_vec<__half>(const __half *outv, std::string outn,
|
||||
int num_output_ele) {
|
||||
std::cout << outn << ": ";
|
||||
std::vector<__half> hout(num_output_ele, (__half)0.f);
|
||||
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
|
||||
cudaMemcpyDeviceToHost);
|
||||
for (int i = 0; i < num_output_ele; i++) {
|
||||
std::cout << __half2float(hout[i]) << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template void print_vec<float>(const float *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template void print_vec<int>(const int *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template void print_vec<__half>(const __half *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template <typename T>
|
||||
T *cuda_malloc(size_t ele_num) {
|
||||
size_t byte_size = ele_num * sizeof(T);
|
||||
T *pdata = nullptr;
|
||||
CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
|
||||
return pdata;
|
||||
}
|
||||
|
||||
template float *cuda_malloc<float>(size_t ele_num);
|
||||
|
||||
template __half *cuda_malloc<__half>(size_t ele_num);
|
||||
|
||||
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
|
||||
|
||||
void cuda_free(void *pdata) {
|
||||
if (pdata != nullptr) {
|
||||
cudaFree(pdata);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct _isnan {
|
||||
__device__ bool operator()(T a) const { return isnan(a); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _isnan<__half> {
|
||||
__device__ bool operator()(const __half a) const { return __hisnan(a); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct _isinf {
|
||||
__device__ bool operator()(T a) const { return isinf(a); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _isinf<__half> {
|
||||
__device__ bool operator()(const __half a) const { return __hisinf(a); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||
std::string file, int line, cudaStream_t stream) {
|
||||
  // check_nan_inf = true (1): check for nan
  // check_nan_inf = false (0): check for inf
|
||||
bool res = false;
|
||||
std::string msg = file + "(" + std::to_string(line) + "): ";
|
||||
if (check_nan_inf) {
|
||||
msg += "nan.";
|
||||
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
|
||||
data_ptr + dsize, _isnan<T>(), false,
|
||||
thrust::logical_or<bool>());
|
||||
} else {
|
||||
msg += "inf.";
|
||||
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
|
||||
data_ptr + dsize, _isinf<T>(), false,
|
||||
thrust::logical_or<bool>());
|
||||
}
|
||||
if (res) {
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
std::cout << msg << " [check pass]." << std::endl;
|
||||
}
|
||||
|
||||
template void check_nan_inf<float>(const float *data_ptr, int dsize,
|
||||
bool check_nan_inf, std::string file,
|
||||
int line, cudaStream_t stream);
|
||||
|
||||
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
|
||||
bool check_nan_inf, std::string file,
|
||||
int line, cudaStream_t stream);
|
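A short sketch of how the cuda_util helpers above compose: allocate with cuda_malloc, wrap CUDA calls in CHECK_GPU_ERROR, optionally dump values with print_vec, and release with cuda_free. The element count and stream here are illustrative assumptions.

// Hypothetical usage of the cuda_util helpers, illustration only.
#include "cuda_util.h"

void cuda_util_demo(cudaStream_t stream) {
  const int n = 1024;                  // assumed element count
  float *buf = cuda_malloc<float>(n);  // throws (via CHECK_GPU_ERROR) on failure
  CHECK_GPU_ERROR(cudaMemsetAsync(buf, 0, n * sizeof(float), stream));
  CHECK_GPU_ERROR(cudaStreamSynchronize(stream));
  print_vec(buf, "buf (first 8 elements)", 8);  // device-to-host copy + print
  CHECK_NAN_INF(buf, n, stream);                // nan check, then inf check
  cuda_free(buf);
}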
File diff suppressed because it is too large
|
@@ -0,0 +1,232 @@
|
|||
#include "kernels.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
/**
|
||||
@brief: fuse_transpose_bias
|
||||
Calculate the sum of elements in each column of the matrix.
|
||||
|
||||
@thread
|
||||
gridDim.x = ceil(cols / WARP_SIZE)
|
||||
blockDim.x = WARP_SIZE
|
||||
blockDim.y = WARP_SIZE
|
||||
|
||||
@param
|
||||
inp: [rows, cols]
|
||||
out: [cols]
|
||||
rows: the number of rows in the matrix
|
||||
cols: the number of cols in the matrix
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void column_sum_reduce(const T *__restrict__ inp,
|
||||
T *__restrict__ out, int rows, int cols) {
|
||||
__shared__ float tile[WARP_SIZE][WARP_SIZE];
|
||||
|
||||
cg::thread_block b = cg::this_thread_block();
|
||||
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
|
||||
|
||||
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
|
||||
int y_stride = cols * WARP_SIZE;
|
||||
float localSum = 0;
|
||||
|
||||
// Loop across matrix row
|
||||
// TODO: optimize to log complexity
|
||||
if (idx < cols) {
|
||||
int offset = flat_2dim(threadIdx.y, idx, cols);
|
||||
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
|
||||
localSum += (float)inp[offset];
|
||||
offset += y_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// The sum of a row in tile is equal to the sum of a col in original matrix
|
||||
tile[threadIdx.x][threadIdx.y] = localSum;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
  // Sum the shared buffer.
  // Consecutive threadIdx.x values read consecutive shared-memory addresses,
  // so the accesses are contiguous within a warp.
  float sum = tile[threadIdx.y][threadIdx.x];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate the sum of a row in tile
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
|
||||
if (pos < cols) out[pos] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
// [r, c] -> [c]
|
||||
template <>
|
||||
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
|
||||
int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
|
||||
dim3 block_dim(WARP_SIZE, WARP_SIZE);
|
||||
|
||||
column_sum_reduce<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
|
||||
int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
|
||||
dim3 block_dim(WARP_SIZE, WARP_SIZE);
|
||||
|
||||
column_sum_reduce<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: fused_add2
|
||||
Add two matrix inp1 and inp2 to out.
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size * seq_len
|
||||
blockDim.x = min(hidden_dim, MAX_THREADS)
|
||||
|
||||
@param
|
||||
inp1: [batch_size, seq_len, hidden_dim]
|
||||
inp2: [batch_size, seq_len, hidden_dim]
|
||||
out: [batch_size, seq_len, hidden_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
|
||||
int hidden_dim);
|
||||
|
||||
template <>
|
||||
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
|
||||
const float *inp2, int hidden_dim) {
|
||||
int row_id = blockIdx.x;
|
||||
int offset = flat_2dim(row_id, 0, hidden_dim);
|
||||
|
||||
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
|
||||
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
|
||||
float4 *out_4 = reinterpret_cast<float4 *>(out);
|
||||
float4 vinp1;
|
||||
float4 vinp2;
|
||||
float4 val;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinp1 = inp1_4[offset + i];
|
||||
vinp2 = inp2_4[offset + i];
|
||||
val.x = vinp1.x + vinp2.x;
|
||||
val.y = vinp1.y + vinp2.y;
|
||||
val.z = vinp1.z + vinp2.z;
|
||||
val.w = vinp1.w + vinp2.w;
|
||||
out_4[offset + i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
|
||||
const __half *inp2, int hidden_dim) {
|
||||
int row_id = blockIdx.x;
|
||||
int offset = flat_2dim(row_id, 0, hidden_dim);
|
||||
|
||||
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
|
||||
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
|
||||
float4 *out_4 = reinterpret_cast<float4 *>(out);
|
||||
float4 vinp1;
|
||||
float4 vinp2;
|
||||
float4 val;
|
||||
__half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
|
||||
__half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
|
||||
__half2 *h2_val = reinterpret_cast<__half2 *>(&val);
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinp1 = inp1_4[offset + i];
|
||||
vinp2 = inp2_4[offset + i];
|
||||
h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
|
||||
h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
|
||||
h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
|
||||
h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
|
||||
out_4[offset + i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
//[b, s, h] -> [b, s, h]
|
||||
template <>
|
||||
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
cudaStream_t &stream) {
|
||||
hidden_dim >>= 2;
|
||||
|
||||
dim3 grid_dim(batch_size * seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
|
||||
hidden_dim);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_fused_add2<__half>(__half *out, const __half *inp1,
|
||||
const __half *inp2, int batch_size, int seq_len,
|
||||
int hidden_dim, cudaStream_t &stream) {
|
||||
hidden_dim >>= 3;
|
||||
|
||||
dim3 grid_dim(batch_size * seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
|
||||
hidden_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
|
||||
int sz0, int sz2, int sz1_1, int sz1_2) {
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
|
||||
if (idx >= nele) {
|
||||
return;
|
||||
}
|
||||
float4 *dst_ptr = (float4 *)output + idx;
|
||||
int idx2 = idx % sz2;
|
||||
idx = idx / sz2;
|
||||
int idx1 = idx % (sz1_1 + sz1_2);
|
||||
int idx0 = idx / (sz1_1 + sz1_2);
|
||||
float4 *src_ptr = nullptr;
|
||||
int sz1 = 0;
|
||||
if (idx1 < sz1_1) {
|
||||
sz1 = sz1_1;
|
||||
src_ptr = (float4 *)inp1;
|
||||
} else {
|
||||
idx1 -= sz1_1;
|
||||
sz1 = sz1_2;
|
||||
src_ptr = (float4 *)inp2;
|
||||
}
|
||||
src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
|
||||
dst_ptr[0] = src_ptr[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
|
||||
float *output, int sz0, int sz2, int sz1_1,
|
||||
int sz1_2, cudaStream_t stream) {
|
||||
sz2 >>= 2;
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
|
||||
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
|
||||
__half *output, int sz0, int sz2, int sz1_1,
|
||||
int sz1_2, cudaStream_t stream) {
|
||||
sz2 >>= 3;
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
|
||||
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
|
||||
}
|
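launch_fused_add2 and launch_concat3_dim1 above use float4 vectorized loads, so hidden_dim (resp. sz2) must be divisible by 4 for float and by 8 for __half; the launchers shift the dimension right by 2 or 3 before launching. A hedged sketch with hypothetical device buffers:

// Hypothetical call sites, illustration only.
#include "kernels.h"

void general_kernels_demo(float *d_out, const float *d_a, const float *d_b,
                          float *d_colsum, int batch_size, int seq_len,
                          int hidden_dim, cudaStream_t stream) {
  // out = a + b over [batch_size, seq_len, hidden_dim];
  // hidden_dim must be a multiple of 4 here because of the float4 loads.
  launch_fused_add2<float>(d_out, d_a, d_b, batch_size, seq_len, hidden_dim,
                           stream);

  // Column-wise sum of a [rows, cols] matrix, e.g. for a bias gradient.
  launch_fuse_transpose_bias_kernel<float>(d_a, d_colsum,
                                           batch_size * seq_len, hidden_dim,
                                           stream);
}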
|
@@ -0,0 +1,312 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Tencent/TurboTransformers
|
||||
This block_reduce_n is adapted from Tencent/TurboTransformers
|
||||
*/
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
enum class ReduceType { kMax = 0, kSum };
|
||||
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
|
||||
const float REDUCE_FLOAT_INF_NEG = -100000000.f;
|
||||
const float REDUCE_FLOAT_INF_POS = 100000000.f;
|
||||
const unsigned int WARP_REDUCE_SIZE = 32;
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T warpReduceSum(T val) {
|
||||
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0) shared[wid] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
|
||||
val = warpReduceSum<T>(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <ReduceType Rtype, int Num>
|
||||
__inline__ __device__ void blockReduce(float *pval);
|
||||
|
||||
// use template to make code more concise
|
||||
template <ReduceType Rtype, int Num>
|
||||
__inline__ __device__ void warpReduce(float *pval);
|
||||
|
||||
// static
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceMaxOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval) = max(val0_tmp, *(pval)); \
|
||||
*(pval + 1) = max(val1_tmp, *(pval + 1));
|
||||
|
||||
WarpReduceMaxOneStep(16, 32);
|
||||
WarpReduceMaxOneStep(8, 32);
|
||||
WarpReduceMaxOneStep(4, 32);
|
||||
WarpReduceMaxOneStep(2, 32);
|
||||
WarpReduceMaxOneStep(1, 32);
|
||||
#undef WarpReduceMaxOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
|
||||
}
|
||||
|
||||
/*
 * Unroll the warpReduce loop to
 * improve instruction issue efficiency.
 * ElemX means there are X numbers to be summed.
 */
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
WarpReduceSumOneStep(8, 32);
|
||||
WarpReduceSumOneStep(4, 32);
|
||||
WarpReduceSumOneStep(2, 32);
|
||||
WarpReduceSumOneStep(1, 32);
|
||||
|
||||
#undef WarpReduceSumOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
|
||||
float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
|
||||
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp; \
|
||||
*(pval + 2) += val2_tmp; \
|
||||
*(pval + 3) += val3_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
WarpReduceSumOneStep(8, 32);
|
||||
WarpReduceSumOneStep(4, 32);
|
||||
WarpReduceSumOneStep(2, 32);
|
||||
WarpReduceSumOneStep(1, 32);
|
||||
#undef WarpReduceSumOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
|
||||
const int num = 1;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
|
||||
const int num = 2;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
|
||||
const int num = 4;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
|
||||
const int num = 1;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
|
||||
  const int num = 2;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
|
||||
const int num = 1;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
|
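A minimal kernel sketch of the intended blockReduce usage, mirroring the pattern in the cross-entropy kernels above: each thread accumulates partial values in a small local array, the block-wide reduction overwrites them in place, and thread 0 publishes the result. The row-sum computation itself is illustrative.

// Illustration only: one block per row, blockDim.x threads per block.
#include "block_reduce.h"

__global__ void row_sum_kernel(const float *__restrict__ inp,
                               float *__restrict__ out, int cols) {
  float val[1] = {0.f};
  const int row_start = blockIdx.x * cols;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    val[0] += inp[row_start + i];
  }
  // In-place block-wide sum; thread 0 (warp 0) ends up with the full total.
  blockReduce<ReduceType::kSum, 1>(val);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val[0];
  }
}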
|
@@ -0,0 +1,36 @@
|
|||
#pragma once
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "cuda_util.h"
|
||||
|
||||
class Context {
|
||||
public:
|
||||
Context() : _stream(nullptr) {
|
||||
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
|
||||
}
|
||||
|
||||
virtual ~Context() {}
|
||||
|
||||
static Context &Instance() {
|
||||
static Context _ctx;
|
||||
return _ctx;
|
||||
}
|
||||
|
||||
void set_stream(cudaStream_t stream) {
|
||||
_stream = stream;
|
||||
CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream));
|
||||
}
|
||||
|
||||
cudaStream_t get_stream() { return _stream; }
|
||||
|
||||
cublasHandle_t get_cublashandle() { return _cublasHandle; }
|
||||
|
||||
private:
|
||||
cudaStream_t _stream;
|
||||
cublasHandle_t _cublasHandle;
|
||||
};
|
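Context is a lazily constructed singleton that owns a single cuBLAS handle; callers bind the current stream before issuing cuBLAS work through the wrappers. A hedged sketch (the header name "context.h" and the stream source are assumptions):

// Hypothetical usage of the Context singleton, illustration only.
#include "context.h"  // assumed header name for the class above

void bind_stream(cudaStream_t stream) {
  Context::Instance().set_stream(stream);  // also routes the cuBLAS handle onto `stream`
  cublasHandle_t handle = Context::Instance().get_cublashandle();
  (void)handle;  // pass to the cublas_* wrappers
}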
|
@@ -0,0 +1,46 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cuda_util.h"
|
||||
|
||||
template <typename T>
|
||||
class CrossEntropyLayer {
|
||||
public:
|
||||
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
|
||||
|
||||
virtual ~CrossEntropyLayer();
|
||||
|
||||
void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr);
|
||||
|
||||
void Backward(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr);
|
||||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
|
||||
|
||||
private:
|
||||
void allocate_mem_buffer() {
|
||||
// allocate local gpu memory
|
||||
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
|
||||
}
|
||||
|
||||
void free_mem_buffer() {
|
||||
// free local gpu memory
|
||||
cuda_free(_loss_buffer);
|
||||
}
|
||||
|
||||
const int _padding_idx;
|
||||
const float _epsilon;
|
||||
const int _max_batch_tokens;
|
||||
|
||||
size_t _batch_size;
|
||||
size_t _seq_len;
|
||||
size_t _vocab_size;
|
||||
|
||||
float *_loss_buffer;
|
||||
};
|
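A hedged sketch of driving CrossEntropyLayer; the Forward/Backward implementations are not shown in this excerpt, so only the interface above is relied on, and the header name, buffer pointers and hyper-parameters are assumptions.

// Illustration only, based on the interface above.
#include "cross_entropy_layer.h"  // assumed header name

void cross_entropy_layer_demo(const float *d_logits, const int *d_targets,
                              float *d_loss, float *d_nll_loss,
                              const float *d_grad_loss, float *d_grad_logits,
                              int batch_size, int seq_len, int vocab_size,
                              int max_batch_tokens) {
  CrossEntropyLayer<float> layer(/*epsilon=*/0.1f, /*padding_idx=*/0,
                                 max_batch_tokens);
  layer.set_cur_batch_shape(batch_size, seq_len, vocab_size);
  layer.Forward(d_logits, d_targets, d_loss, d_nll_loss);
  layer.Backward(d_grad_loss, d_logits, d_targets, d_grad_logits);
}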
|
@@ -0,0 +1,40 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mma.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const float *A,
|
||||
const float *B, float *C,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const __half *A,
|
||||
const __half *B, __half *C,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const float *A, const float *B, float *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
||||
int cublas_strided_batched_gemm(
|
||||
cublasHandle_t handle, int m, int n, int k, const float *alpha,
|
||||
const float *beta, const __half *A, const __half *B, __half *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B,
|
||||
int stride_C, int batch,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
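A column-major GEMM sketch using the FP32 overload above: with CUBLAS_OP_N / CUBLAS_OP_N the wrapper computes C (m x n) = A (m x k) * B (k x n) with leading dimensions m, k and m. The handle and device pointers are hypothetical (the handle would normally come from Context::Instance()).

// Illustration of the cublas_gemm_ex wrapper, not part of the diff.
#include "cublas_wrappers.h"

int gemm_demo(cublasHandle_t handle, const float *d_A, const float *d_B,
              float *d_C, int m, int n, int k) {
  float alpha = 1.f;
  float beta = 0.f;
  // Column-major: C[m x n] = A[m x k] * B[k x n]
  return cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                        &beta, d_A, d_B, d_C);
}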
|
@@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <math_constants.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
void check_gpu_error(T result, char const *const func, const char *const file,
|
||||
int const line);
|
||||
|
||||
#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__)
|
||||
|
||||
template <typename T>
|
||||
void print_vec(const T *outv, std::string outn, int num_output_ele);
|
||||
|
||||
template <typename T>
|
||||
T *cuda_malloc(size_t ele_num);
|
||||
|
||||
void cuda_free(void *pdata);
|
||||
|
||||
template <typename T>
|
||||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||
std::string file, int line, cudaStream_t stream);
|
||||
|
||||
#define CHECK_NAN_INF(ptr, size, stream) \
|
||||
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
||||
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
|
|
@@ -0,0 +1,95 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class Dropout {
|
||||
public:
|
||||
struct Config {
|
||||
float ratio;
|
||||
bool training;
|
||||
|
||||
Config(float r) : ratio(r), training(true) {}
|
||||
float RATIO() const { return training ? ratio : 0.0; }
|
||||
};
|
||||
|
||||
Dropout(const Config &config, size_t max_ele_num)
|
||||
: _config(config), _mask(nullptr) {
|
||||
_mask = cuda_malloc<uint8_t>(max_ele_num);
|
||||
}
|
||||
|
||||
virtual ~Dropout() { cuda_free(_mask); }
|
||||
|
||||
// after attention softmax
|
||||
void dropout(T *output, const T *input, int count, cudaStream_t stream,
|
||||
bool bwd = false) {
|
||||
launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream,
|
||||
bwd);
|
||||
}
|
||||
|
||||
void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
|
||||
launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(),
|
||||
stream, true);
|
||||
}
|
||||
|
||||
// transformer layer's postprocessing dropout, after attn or ffn module,
|
||||
// before residual add.
|
||||
void bias_dropout_residual(T *output, const T *input, const T *residual,
|
||||
const T *bias, int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual,
|
||||
rows * cols, cols, _config.RATIO(), stream);
|
||||
}
|
||||
|
||||
void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
|
||||
int rows, int cols, cudaStream_t stream) {
|
||||
launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
}
|
||||
|
||||
// dropout inside ffn.
|
||||
void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
|
||||
int cols, std::string activation_fn,
|
||||
cudaStream_t stream) {
|
||||
if (activation_fn == "relu") {
|
||||
launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
|
||||
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
|
||||
stream);
|
||||
} else if (activation_fn == "gelu") {
|
||||
launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
|
||||
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("not supported activation: " + activation_fn);
|
||||
}
|
||||
}
|
||||
|
||||
void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
|
||||
const T *bias, int rows, int cols,
|
||||
std::string activation_fn, cudaStream_t stream) {
|
||||
if (activation_fn == "relu") {
|
||||
launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
|
||||
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
} else if (activation_fn == "gelu") {
|
||||
launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
|
||||
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
} else {
|
||||
throw std::runtime_error("not supported activation: " + activation_fn);
|
||||
}
|
||||
}
|
||||
|
||||
bool HasDropout() const { return _config.RATIO() > 0.0; }
|
||||
|
||||
void SetTrainingMode(bool training) { _config.training = training; }
|
||||
|
||||
private:
|
||||
uint8_t *_mask;
|
||||
Config _config;
|
||||
};
|
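A hedged sketch of the Dropout wrapper above: the configured ratio is honoured only in training mode, the mask buffer sized at construction must cover the largest activation the layer will see, and all sizes, pointers and the header name are assumptions.

// Illustration only, based on the Dropout<T> interface above.
#include "dropout.h"  // assumed header name

void dropout_demo(float *d_out, const float *d_in, int max_ele_num, int count,
                  cudaStream_t stream) {
  Dropout<float>::Config cfg(0.1f);
  Dropout<float> dropout(cfg, max_ele_num);

  dropout.SetTrainingMode(true);   // RATIO() == 0.1f
  dropout.dropout(d_out, d_in, count, stream);

  dropout.SetTrainingMode(false);  // RATIO() == 0.0f, i.e. keep everything
  dropout.dropout(d_out, d_in, count, stream);
}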
|
@@ -0,0 +1,68 @@
|
|||
#pragma once
|
||||
|
||||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "cublas_wrappers.h"
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class FeedForward {
|
||||
public:
|
||||
struct Config {
|
||||
int outputSize;
|
||||
int inputSize;
|
||||
std::array<int, 3> gemm_algos;
|
||||
Config(int outputs, int inputs)
|
||||
: outputSize(outputs),
|
||||
inputSize(inputs),
|
||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||
};
|
||||
|
||||
FeedForward(Config config) : config_(config) {}
|
||||
|
||||
~FeedForward() {}
|
||||
|
||||
void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
|
||||
cublasHandle_t &_cublasHandle) {
|
||||
float alpha = T(1.);
|
||||
float beta = T(0.);
|
||||
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
|
||||
bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
|
||||
out, cublasGemmAlgo_t(config_.gemm_algos[0]));
|
||||
}
|
||||
void Backward(int bsz, const T *out_grad, const T *input_ptr,
|
||||
const T *weights, T *weights_grad, T *bias_grad,
|
||||
cublasHandle_t &_cublasHandle, cudaStream_t &stream,
|
||||
T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
|
||||
bool compute_bias = true) {
|
||||
float alpha = (T)1.0, beta = (T)0.0;
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
|
||||
config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
|
||||
weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
|
||||
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
|
||||
bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
|
||||
inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
|
||||
if (compute_bias) {
|
||||
launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz,
|
||||
config_.outputSize, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void reset_size(int outputSize, int inputSize) {
|
||||
config_.outputSize = outputSize;
|
||||
config_.inputSize = inputSize;
|
||||
}
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
};
|
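FeedForward wraps the two cuBLAS GEMMs of a linear layer: Forward computes the output of a [bsz, inputSize] batch against the weight matrix, and Backward produces the weight, input and (optionally) bias gradients, the latter via launch_fuse_transpose_bias_kernel. A hedged sketch with hypothetical buffers and an assumed header name:

// Illustration only, based on the FeedForward<T> interface above.
#include "feed_forward.h"  // assumed header name

void feed_forward_demo(const float *d_in, const float *d_w, float *d_out,
                       const float *d_dout, float *d_dw, float *d_db,
                       float *d_din, int bsz, int in_size, int out_size,
                       cublasHandle_t handle, cudaStream_t stream) {
  FeedForward<float>::Config cfg(out_size, in_size);
  FeedForward<float> ff(cfg);

  ff.Forward(bsz, d_in, d_w, d_out, handle);
  ff.Backward(bsz, d_dout, d_in, d_w, d_dw, d_db, handle, stream, d_din);
}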
|
@@ -0,0 +1,274 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdexcept>
|
||||
|
||||
#define MAX_THREADS 1024
|
||||
#define WARP_SIZE 32
|
||||
|
||||
enum class ActivationType { kRelu, kGelu };
|
||||
|
||||
void launch_curand_init(int total_count, int dim, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
|
||||
const T *scale, const T *bias, int batch_size,
|
||||
int hidden_dim, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
|
||||
const T *residual_grad, const T *inp_or_out, const T *gamma,
|
||||
const T *betta, const T *vars, const T *means, int batch,
|
||||
int hidden_dim, cudaStream_t stream[2]);
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads,
|
||||
int from_len, int to_len, bool mask_future,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
|
||||
int softmax_len, cudaStream_t stream);
|
||||
|
||||
// [b, s, h] -> [b, nh, s, ad]
|
||||
template <typename T>
|
||||
void launch_transform_0213(T *output, const T *vals, int batch_size,
|
||||
int seq_length, int hidden_dim, int nhead,
|
||||
cudaStream_t stream);
|
||||
|
||||
// [b, s, 3, h] -> [3, b, nh, s, ad]
|
||||
template <typename T>
|
||||
void launch_bias_add_transform_20314(T *output, const T *input, const T *bias,
|
||||
int dim_0, int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream);
|
||||
|
||||
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
|
||||
template <typename T>
|
||||
void launch_transform4d_0213(T *output, const T *vals, int batch_size,
|
||||
int seq_len, int hidden_dim, int nhead,
|
||||
int trans_count, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count,
|
||||
float ratio, cudaStream_t stream, bool backward = false);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask,
|
||||
const T *bias, const T *residual,
|
||||
int total_count, int dim, float ratio,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <ActivationType, typename T>
|
||||
void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask,
|
||||
const T *bias, int total_count, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
|
||||
const uint8_t *mask, int row_size, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <ActivationType act_type, typename T>
|
||||
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
|
||||
const T *bias, const T *out_grad,
|
||||
const uint8_t *mask, int row_size, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols,
|
||||
cudaStream_t stream);
|
||||
|
||||
void launch_param_update(const float *input, __half *output, int size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0,
|
||||
int sz2, int sz1_1, int sz1_2, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size,
|
||||
int seq_len, int hidden_size, cudaStream_t &stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
|
||||
float *outputs_ptr, float *nll_loss_ptr,
|
||||
float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size,
|
||||
const int seq_len, const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr,
|
||||
const int padding_idx, const float epsilon,
|
||||
const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_lookup_scale_pos_dropout(
|
||||
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
|
||||
uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
|
||||
int padding_idx, float dropout_ratio, int step, cudaStream_t &stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_d_lookup_scale_pos_dropout(
|
||||
T *grad_embeddings, const T *grad_output, const int *input,
|
||||
const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
|
||||
int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream);
|
||||
|
||||
/* Convert 2-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) {
|
||||
return id1 * dim2 + id2;
|
||||
}
|
||||
|
||||
/* Convert 3-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
|
||||
int dim2, int dim3) {
|
||||
return id1 * dim2 * dim3 + id2 * dim3 + id3;
|
||||
}
|
||||
|
||||
/* Convert 4-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
|
||||
int id4, int dim2, int dim3,
|
||||
int dim4) {
|
||||
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
|
||||
int res = id4;
|
||||
|
||||
int ld = dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert 5-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3,
|
||||
int id4, int id5, int dim2,
|
||||
int dim3, int dim4,
|
||||
int dim5) {
|
||||
// return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) +
|
||||
// id4*dim5 + dim5;
|
||||
int res = id5;
|
||||
|
||||
int ld = dim5;
|
||||
res += id4 * ld;
|
||||
|
||||
ld *= dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert 6-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
|
||||
int id4, int id5, int id6,
|
||||
int dim2, int dim3, int dim4,
|
||||
int dim5, int dim6) {
|
||||
// return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) +
|
||||
// id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6;
|
||||
int res = id6;
|
||||
|
||||
int ld = dim6;
|
||||
res += id5 * ld;
|
||||
|
||||
ld *= dim5;
|
||||
res += id4 * ld;
|
||||
|
||||
ld *= dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert vector index to 6-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_6dim(
|
||||
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
|
||||
int *id1, int *id2, int *id3, int *id4, int *id5) {
|
||||
*id5 = src % dim5;
|
||||
src /= dim5;
|
||||
|
||||
*id4 = src % dim4;
|
||||
src /= dim4;
|
||||
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 5-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
|
||||
int dim2, int dim3,
|
||||
int dim4, int *id0,
|
||||
int *id1, int *id2,
|
||||
int *id3, int *id4) {
|
||||
*id4 = src % dim4;
|
||||
src /= dim4;
|
||||
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 4-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
|
||||
int dim2, int dim3,
|
||||
int *id0, int *id1,
|
||||
int *id2, int *id3) {
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 3-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
|
||||
int dim2, int *id0,
|
||||
int *id1, int *id2) {
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 2-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1,
|
||||
int *id0, int *id1) {
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
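A minimal host-side round-trip check of these index helpers (not part of the diff; it assumes the header above is included as, say, "kernels.h" and is compiled with nvcc so the __host__ __device__ qualifiers resolve):

#include <cassert>
#include <cstdio>

#include "kernels.h"  // hypothetical include path; wherever flat_4dim/decompose_4dim live

int main() {
  // Shape [2, 3, 4, 5]: flatten index (1, 2, 3, 4), then invert it again.
  int flat = flat_4dim(1, 2, 3, 4, /*dim2=*/3, /*dim3=*/4, /*dim4=*/5);
  assert(flat == 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4);  // 119

  int id0, id1, id2, id3;
  decompose_4dim(flat, /*dim1=*/3, /*dim2=*/4, /*dim3=*/5, &id0, &id1, &id2, &id3);
  assert(id0 == 1 && id1 == 2 && id2 == 3 && id3 == 4);
  printf("flat_4dim/decompose_4dim round trip ok (%d)\n", flat);
  return 0;
}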
|
@ -0,0 +1,12 @@
|
|||
// copied from https://github.com/dmlc/dgl/pull/2758
|
||||
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
|
||||
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_
|
||||
|
||||
#define CUB_NS_PREFIX namespace ls {
|
||||
#define CUB_NS_POSTFIX }
|
||||
#include "cub/cub.cuh"
|
||||
#include "cub/util_allocator.cuh"
|
||||
#undef CUB_NS_POSTFIX
|
||||
#undef CUB_NS_PREFIX
|
||||
|
||||
#endif
|
|
@ -0,0 +1,65 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
class Normalize_Layer {
|
||||
public:
|
||||
struct Config {
|
||||
uint32_t hidden_dim;
|
||||
bool use_mean;
|
||||
Config(uint32_t hidden_dim, bool use_mean = false)
|
||||
: hidden_dim(hidden_dim), use_mean(use_mean) {}
|
||||
};
|
||||
|
||||
Normalize_Layer(Config config, size_t max_rows)
|
||||
: config_(config), vars_(nullptr), means_(nullptr) {
|
||||
vars_ = cuda_malloc<T>(max_rows);
|
||||
if (config_.use_mean) {
|
||||
means_ = cuda_malloc<T>(max_rows);
|
||||
}
|
||||
}
|
||||
|
||||
~Normalize_Layer() {
|
||||
cuda_free(vars_);
|
||||
cuda_free(means_);
|
||||
}
|
||||
|
||||
void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
|
||||
int batch_size, cudaStream_t stream) {
|
||||
launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
|
||||
config_.hidden_dim, stream);
|
||||
}
|
||||
|
||||
/*
|
||||
residual_grad, inp_or_out, betta should be treated carefully.
|
||||
inp_or_out = input if use_mean else output
|
||||
residual_grad, betta can be nullptr.
|
||||
residual_grad will be added to dinp if it is not nullptr
|
||||
which is useful in a pre-LayerNorm transformer layer
|
||||
betta is only used to compute xhat,
|
||||
(use_mean == false) ^ (betta == nullptr) should be true
|
||||
*/
|
||||
void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
|
||||
const T *residual_grad, const T *inp_or_out, const T *gamma,
|
||||
const T *betta, int batch_size, cudaStream_t stream[2]) {
|
||||
launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
|
||||
inp_or_out, gamma, betta, vars_, means_, batch_size,
|
||||
config_.hidden_dim, stream);
|
||||
}
|
||||
|
||||
inline bool use_mean() const { return config_.use_mean; }
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
T *vars_;
|
||||
T *means_;
|
||||
};
|
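A hedged usage sketch for this wrapper (not part of the diff): it assumes the class declarations above, cuda_malloc/cuda_free and launch_layer_norm from the kernels header, and that all pointers are device pointers of the right size; the helper name is illustrative only.

#include <cuda_runtime.h>

void run_layer_norm_fw(const float *d_inp, const float *d_gamma,
                       const float *d_betta, float *d_out,
                       int batch_tokens, int hidden_dim, cudaStream_t stream) {
  // use_mean = true keeps the per-row mean buffer that Backward() needs.
  Normalize_Layer<float>::Config cfg(hidden_dim, /*use_mean=*/true);
  Normalize_Layer<float> ln(cfg, /*max_rows=*/batch_tokens);

  // Normalizes batch_tokens rows of hidden_dim elements and caches vars_/means_.
  ln.Forward(d_out, d_inp, d_gamma, d_betta, batch_tokens, stream);
  cudaStreamSynchronize(stream);
}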
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
class Softmax {
|
||||
public:
|
||||
struct Config {
|
||||
size_t nhead;
|
||||
Config(size_t nhead) : nhead(nhead) {}
|
||||
};
|
||||
|
||||
Softmax(Config config) : config_(config) {}
|
||||
|
||||
~Softmax() {}
|
||||
|
||||
void Forward(T *vals, const T *attn_mask, int batch_size, int from_len,
|
||||
int to_len, cudaStream_t &stream, bool mask_future = true) {
|
||||
launch_attn_softmax<T>(vals, attn_mask, batch_size, config_.nhead, from_len,
|
||||
to_len, mask_future, stream);
|
||||
}
|
||||
|
||||
void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len,
|
||||
int to_len, cudaStream_t stream) {
|
||||
launch_attn_softmax_bw<T>(out_grad, soft_out,
|
||||
batch_size * config_.nhead * from_len, to_len,
|
||||
stream);
|
||||
}
|
||||
|
||||
void reset_size(size_t nhead) {
|
||||
config_.nhead = nhead;
|
||||
}
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
};
|
|
@ -0,0 +1,99 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "cublas_wrappers.h"
|
||||
|
||||
template <typename T>
|
||||
class StridedBatchGemm {
|
||||
public:
|
||||
struct Config {
|
||||
int m;
|
||||
int n;
|
||||
int k;
|
||||
float alpha;
|
||||
float beta;
|
||||
cublasOperation_t op_A;
|
||||
cublasOperation_t op_B;
|
||||
std::array<int, 3> gemm_algos;
|
||||
|
||||
Config(float param_alpha, float param_beta, cublasOperation_t opA,
|
||||
cublasOperation_t opB)
|
||||
: alpha(param_alpha),
|
||||
beta(param_beta),
|
||||
op_A(opA),
|
||||
op_B(opB),
|
||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||
void SetConfig(int mm, int nn, int kk) {
|
||||
m = mm;
|
||||
n = nn;
|
||||
k = kk;
|
||||
}
|
||||
};
|
||||
|
||||
StridedBatchGemm(const Config &config) : _config(config) {}
|
||||
|
||||
virtual ~StridedBatchGemm() {}
|
||||
|
||||
void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
|
||||
cublasHandle_t handle) {
|
||||
int stride_a = _config.m * _config.k;
|
||||
int stride_b = _config.n * _config.k;
|
||||
int stride_c = _config.m * _config.n;
|
||||
|
||||
cublas_strided_batched_gemm(
|
||||
handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
|
||||
_buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
|
||||
stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
|
||||
}
|
||||
|
||||
void Backward(int bsz, const T *d_output, const T *_buffer_a,
|
||||
const T *_buffer_b, cublasHandle_t handle,
|
||||
T *inpGradA = nullptr, T *inpGradB = nullptr) {
|
||||
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
|
||||
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
|
||||
|
||||
int stride_a = mb * _config.n;
|
||||
int stride_b = _config.n * kb;
|
||||
int stride_c = _config.m * _config.k;
|
||||
|
||||
// B needs to be transposed.
|
||||
cublasOperation_t op_b =
|
||||
(_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
|
||||
|
||||
// Calculate d_A.
|
||||
cublas_strided_batched_gemm(
|
||||
handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
|
||||
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
|
||||
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
|
||||
CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
|
||||
cublasGemmAlgo_t(_config.gemm_algos[1]));
|
||||
|
||||
// A needs to be transposed.
|
||||
cublasOperation_t op_a =
|
||||
(_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
|
||||
|
||||
stride_a = _config.m * _config.k;
|
||||
stride_b = _config.m * _config.n;
|
||||
stride_c = _config.n * _config.k;
|
||||
|
||||
// Calculate d_B.
|
||||
cublas_strided_batched_gemm(
|
||||
handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
|
||||
_buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
|
||||
stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
|
||||
}
|
||||
|
||||
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
|
||||
|
||||
private:
|
||||
Config _config;
|
||||
};
|
File diff suppressed because it is too large
|
@ -0,0 +1,366 @@
|
|||
#include <math.h>
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "block_reduce.h"
|
||||
#include "kernels.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
const float EPSILON = 1e-8f;
|
||||
|
||||
/**
|
||||
@brief: softmax_kernel
|
||||
Softmax forward kernel for
|
||||
enc-self-attn, dec-self-attn, encdec-attn
|
||||
|
||||
@thread
|
||||
gridDim.x = dynamic
|
||||
gridDim.y = batch_size
|
||||
gridDim.z = nhead
|
||||
blockDim.x = from_len
|
||||
|
||||
@param
|
||||
inp: [batch_size, nhead, from_len, to_len], softmax input.
|
||||
attn_mask: [batch_size, to_len], padding tokens are -inf,
|
||||
non padding tokens are 0.
|
||||
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
|
||||
attn_mask=nullptr and mask_future=true for dec-self-attn training
|
||||
attn_mask=nullptr and mask_future=false for dec-self-attn infer
|
||||
*/
|
||||
template <typename T, int block_dim, int ele_per_thread>
|
||||
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
|
||||
int to_len, bool mask_future) {
|
||||
int batch_id = blockIdx.y;
|
||||
int head_id = blockIdx.z;
|
||||
const int nhead = gridDim.z;
|
||||
const int token_per_reduce = 1;
|
||||
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
typedef cub::BlockStore<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
T mval[ele_per_thread];
|
||||
if (attn_mask) {
|
||||
attn_mask += batch_id * to_len;
|
||||
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
|
||||
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
|
||||
token_id += gridDim.x * token_per_reduce) {
|
||||
T inp_val[token_per_reduce][ele_per_thread];
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
|
||||
REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
/* step 1. compute max */
|
||||
// thread local max
|
||||
float val[token_per_reduce][ele_per_thread];
|
||||
float l_max[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_max[i] = REDUCE_FLOAT_INF_NEG;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
if (attn_mask) {
|
||||
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
|
||||
} else {
|
||||
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
|
||||
val[i][j] = REDUCE_FLOAT_INF_NEG;
|
||||
} else {
|
||||
val[i][j] = (float)inp_val[i][j];
|
||||
}
|
||||
}
|
||||
l_max[i] = fmaxf(l_max[i], val[i][j]);
|
||||
}
|
||||
}
|
||||
// block reduce max
|
||||
blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
|
||||
// write shared
|
||||
__shared__ float s_max[token_per_reduce];
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
s_max[i] = l_max[i];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
/* step 2. compute sum */
|
||||
// thread local sum
|
||||
float l_sum[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_sum[i] = 0.f;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
val[i][j] = __expf(val[i][j] - s_max[i]);
|
||||
l_sum[i] += val[i][j];
|
||||
}
|
||||
}
|
||||
// block reduce sum
|
||||
blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
|
||||
// write shared
|
||||
__shared__ float s_sum[token_per_reduce];
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
/* step 3. compute final result */
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
|
||||
}
|
||||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
template <typename T, int block_dim, int ele_per_thread>
|
||||
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
|
||||
int to_len, bool mask_future) {
|
||||
int batch_id = blockIdx.y;
|
||||
int head_id = blockIdx.z;
|
||||
const int nhead = gridDim.z;
|
||||
const int token_per_reduce = 1;
|
||||
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
typedef cub::BlockStore<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
T mval[ele_per_thread];
|
||||
if (attn_mask) {
|
||||
attn_mask += batch_id * to_len;
|
||||
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
|
||||
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
|
||||
token_id += gridDim.x * token_per_reduce) {
|
||||
T inp_val[token_per_reduce][ele_per_thread];
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
|
||||
REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
/* step 1. compute max */
|
||||
// thread local max
|
||||
float val[token_per_reduce][ele_per_thread];
|
||||
float l_max[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_max[i] = REDUCE_FLOAT_INF_NEG;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
if (attn_mask) {
|
||||
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
|
||||
} else {
|
||||
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
|
||||
val[i][j] = REDUCE_FLOAT_INF_NEG;
|
||||
} else {
|
||||
val[i][j] = (float)inp_val[i][j];
|
||||
}
|
||||
}
|
||||
l_max[i] = fmaxf(l_max[i], val[i][j]);
|
||||
}
|
||||
}
|
||||
// warp reduce max
|
||||
warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
|
||||
|
||||
/* step 2. compute sum */
|
||||
// thread local sum
|
||||
float l_sum[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_sum[i] = 0.f;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
val[i][j] = __expf(val[i][j] - l_max[i]);
|
||||
l_sum[i] += val[i][j];
|
||||
}
|
||||
}
|
||||
// warp reduce sum
|
||||
warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
|
||||
|
||||
/* step 3. compute final result */
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
|
||||
}
|
||||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
/*
|
||||
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
|
||||
attn_mask=nullptr and mask_future=true for dec-self-attn training
|
||||
attn_mask=nullptr and mask_future=false for dec-self-attn infer
|
||||
*/
|
||||
template <>
|
||||
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
|
||||
int batch_size, int nhead, int from_len,
|
||||
int to_len, bool mask_future,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim(1, batch_size, nhead);
|
||||
if (to_len <= 32) {
|
||||
ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 64) {
|
||||
ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 128) {
|
||||
grid_dim.x = 16;
|
||||
ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 256) {
|
||||
grid_dim.x = 32;
|
||||
ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 512) {
|
||||
grid_dim.x = 64;
|
||||
ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Sequence length greater than 512 is currently not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
|
||||
int batch_size, int nhead, int from_len,
|
||||
int to_len, bool mask_future,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim(1, batch_size, nhead);
|
||||
if (to_len <= 32) {
|
||||
ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 64) {
|
||||
ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 128) {
|
||||
grid_dim.x = 8;
|
||||
ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 256) {
|
||||
grid_dim.x = 16;
|
||||
ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 512) {
|
||||
grid_dim.x = 32;
|
||||
ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Sequence length greater than 512 is currently not supported");
|
||||
}
|
||||
}
|
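A hedged host-side sketch of how these launchers might be driven (not part of the diff; example_attn_softmax is an illustrative name, and it assumes a live CUDA context plus the launch_attn_softmax declaration from the kernels header; to_len must be at most 512):

#include <cuda_runtime.h>

void example_attn_softmax(int batch_size, int nhead, int from_len, int to_len) {
  size_t n = (size_t)batch_size * nhead * from_len * to_len;
  float *d_scores = nullptr;
  cudaMalloc(&d_scores, n * sizeof(float));
  cudaMemset(d_scores, 0, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // attn_mask == nullptr with mask_future == true selects the causal
  // (decoder self-attention, training) path; softmax is applied in place.
  launch_attn_softmax<float>(d_scores, /*attn_mask=*/nullptr, batch_size, nhead,
                             from_len, to_len, /*mask_future=*/true, stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(d_scores);
}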
||||
|
||||
/**
|
||||
@brief: ker_attn_softmax_bw
|
||||
Softmax backward in self attention.
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size * nhead * seq_len / warps_per_block
|
||||
blockDim.x = WARP_SIZE
|
||||
blockDim.y = warps_per_block
|
||||
|
||||
@param
|
||||
grad: [batch_size, nhead, seq_len, seq_len], output grad.
|
||||
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
|
||||
*/
|
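For reference, the identity the kernel below implements: with $y = \mathrm{softmax}(x)$ and incoming gradient $g = \partial L / \partial y$ stored in `grad`, the Jacobian-vector product is
$$\frac{\partial L}{\partial x_i} = y_i \Big( g_i - \sum_j g_j \, y_j \Big),$$
which corresponds to `grad[i] = inp_reg[i] * (grad_reg[i] - sum)`, where `sum` is the warp-reduced dot product of `grad` and `inp` (the cached softmax output).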
||||
template <typename T, int ITERATIONS>
|
||||
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
|
||||
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
int offset = batch_idx * softmax_length + threadIdx.x;
|
||||
|
||||
grad += offset;
|
||||
inp += offset;
|
||||
|
||||
T grad_reg[ITERATIONS];
|
||||
T inp_reg[ITERATIONS];
|
||||
float sum = 0.0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITERATIONS; ++i) {
|
||||
int curr_idx = threadIdx.x + i * WARP_SIZE;
|
||||
if (curr_idx < softmax_length) {
|
||||
grad_reg[i] = grad[i * WARP_SIZE];
|
||||
inp_reg[i] = inp[i * WARP_SIZE];
|
||||
sum += (float)grad_reg[i] * (float)inp_reg[i];
|
||||
}
|
||||
}
|
||||
|
||||
cg::thread_block b = cg::this_thread_block();
|
||||
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITERATIONS; ++i) {
|
||||
int curr_idx = threadIdx.x + i * WARP_SIZE;
|
||||
if (curr_idx < softmax_length)
|
||||
grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
|
||||
int softmax_len, cudaStream_t stream) {
|
||||
const int warps_per_block = 4;
|
||||
// rows = batch_size * nhead * from_len
|
||||
dim3 grid_dim(rows / warps_per_block);
|
||||
dim3 block_dim(WARP_SIZE, warps_per_block);
|
||||
|
||||
if (softmax_len <= 32)
|
||||
ker_attn_softmax_bw<T, 1>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 64)
|
||||
ker_attn_softmax_bw<T, 2>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 128)
|
||||
ker_attn_softmax_bw<T, 4>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 256)
|
||||
ker_attn_softmax_bw<T, 8>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 384)
|
||||
ker_attn_softmax_bw<T, 12>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 512)
|
||||
ker_attn_softmax_bw<T, 16>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 768)
|
||||
ker_attn_softmax_bw<T, 24>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 1024)
|
||||
ker_attn_softmax_bw<T, 32>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 2048)
|
||||
ker_attn_softmax_bw<T, 64>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else
|
||||
throw std::runtime_error(
|
||||
std::string(
|
||||
"Special sequence length found in softmax backward, seq_len: ") +
|
||||
std::to_string(softmax_len));
|
||||
}
|
||||
|
||||
template void launch_attn_softmax_bw<__half>(__half *out_grad,
|
||||
const __half *soft_inp, int rows,
|
||||
int softmax_len,
|
||||
cudaStream_t stream);
|
||||
template void launch_attn_softmax_bw<float>(float *out_grad,
|
||||
const float *soft_inp, int rows,
|
||||
int softmax_len,
|
||||
cudaStream_t stream);
|
|
@ -0,0 +1,314 @@
|
|||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace cub;
|
||||
|
||||
/**
|
||||
@brief: transform_0213
|
||||
Split the attention heads and reshape input
|
||||
during the backward pass of encoder self-attention
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size
|
||||
gridDim.y = seq_len
|
||||
blockDim.x = min(hidden_dim, MAX_THREADS)
|
||||
|
||||
@param
|
||||
input: [batch_size, seq_len, hidden_dim]
|
||||
output: [batch_size, nhead, seq_len, head_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
nhead: number of attention heads
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
|
||||
int head_dim);
|
||||
|
||||
template <>
|
||||
__global__ void transform_0213<float>(float *output, const float *input,
|
||||
int hidden_dim, int head_dim) {
|
||||
int batch_id = blockIdx.x;
|
||||
int token_id = blockIdx.y;
|
||||
int seq_len = gridDim.y;
|
||||
int nhead = hidden_dim / head_dim;
|
||||
|
||||
// [b, s, h]
|
||||
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
|
||||
// [b, nh, s, ad]
|
||||
int trg_offset =
|
||||
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vinput4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinput4 = input4[src_offset + i];
|
||||
|
||||
int head_id = i / head_dim;
|
||||
int dim_id = i % head_dim;
|
||||
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
|
||||
res4[trg_offset + cur_trg_offset] = vinput4;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void transform_0213<__half>(__half *output, const __half *input,
|
||||
int hidden_dim, int head_dim) {
|
||||
int batch_id = blockIdx.x;
|
||||
int token_id = blockIdx.y;
|
||||
int seq_len = gridDim.y;
|
||||
int nhead = hidden_dim / head_dim;
|
||||
|
||||
// [b, s, h]
|
||||
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
|
||||
// [b, nh, s, ad]
|
||||
int trg_offset =
|
||||
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vinput4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinput4 = input4[src_offset + i];
|
||||
|
||||
int head_id = i / head_dim;
|
||||
int dim_id = i % head_dim;
|
||||
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
|
||||
res4[trg_offset + cur_trg_offset] = vinput4;
|
||||
}
|
||||
}
|
||||
|
||||
// [b, s, h] -> [b, nh, s, ad]
|
||||
template <>
|
||||
void launch_transform_0213<float>(float *output, const float *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, cudaStream_t stream) {
|
||||
hidden_dim >>= 2;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
|
||||
dim3 grid_dim(batch_size, seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
transform_0213<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_transform_0213<__half>(__half *output, const __half *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, cudaStream_t stream) {
|
||||
hidden_dim >>= 3;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
|
||||
dim3 grid_dim(batch_size, seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
transform_0213<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
|
||||
}
|
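// Note added for clarity: both launchers shrink hidden_dim before computing
// head_dim because the kernels read and write float4 vectors, i.e. 4 floats
// (hidden_dim >>= 2) or 8 halves (hidden_dim >>= 3) per access, so hidden_dim
// is assumed to be divisible by 4 for fp32 and by 8 for fp16.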
||||
|
||||
/**
|
||||
@brief: bias_add_transform_20314
|
||||
Add bias to input, transform from
|
||||
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
|
||||
|
||||
@thread
|
||||
gridDim.x = dim_0
|
||||
gridDim.y = dim_1
|
||||
gridDim.z = dim_2
|
||||
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
|
||||
|
||||
@param
|
||||
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
|
||||
bias: [dim_2, dim_3, dim_4]
|
||||
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
|
||||
*/
|
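// Added note: in the fused-QKV use below ([b, s, 3, h] -> [3, b, nh, s, ad]),
// dim_0 = batch_size, dim_1 = seq_len, dim_2 = 3 (Q/K/V), dim_3 = nhead and
// dim_4 = head_dim, so the output is laid out as [3, batch, nhead, seq, head_dim].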
||||
template <typename T>
|
||||
__global__ void bias_add_transform_20314(T *output, const T *input,
|
||||
const T *bias, int dim_3, int dim_4);
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<float>(float *output,
|
||||
const float *input,
|
||||
const float *bias, int dim_3,
|
||||
int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
int dim_0 = gridDim.x;
|
||||
int dim_1 = gridDim.y;
|
||||
int dim_2 = gridDim.z;
|
||||
int dim_34 = dim_3 * dim_4;
|
||||
|
||||
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
|
||||
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
|
||||
int bias_offset = flat_2dim(id2, 0, dim_34);
|
||||
|
||||
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
|
||||
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vqkv4;
|
||||
float4 vbias4;
|
||||
float4 vres4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
|
||||
vqkv4 = qkv4[src_offset + i];
|
||||
vbias4 = bias4[bias_offset + i];
|
||||
vres4.x = vqkv4.x + vbias4.x;
|
||||
vres4.y = vqkv4.y + vbias4.y;
|
||||
vres4.z = vqkv4.z + vbias4.z;
|
||||
vres4.w = vqkv4.w + vbias4.w;
|
||||
|
||||
int id3 = i / dim_4;
|
||||
int id4 = i % dim_4;
|
||||
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
|
||||
res4[trg_offset + cur_trg_offset] = vres4;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<__half>(__half *output,
|
||||
const __half *input,
|
||||
const __half *bias, int dim_3,
|
||||
int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
int dim_0 = gridDim.x;
|
||||
int dim_1 = gridDim.y;
|
||||
int dim_2 = gridDim.z;
|
||||
int dim_34 = dim_3 * dim_4;
|
||||
|
||||
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
|
||||
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
|
||||
int bias_offset = flat_2dim(id2, 0, dim_34);
|
||||
|
||||
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
|
||||
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vqkv4;
|
||||
float4 vbias4;
|
||||
float4 vres4;
|
||||
__half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
|
||||
__half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
|
||||
__half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
|
||||
vqkv4 = qkv4[src_offset + i];
|
||||
vbias4 = bias4[bias_offset + i];
|
||||
h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
|
||||
h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
|
||||
h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
|
||||
h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);
|
||||
|
||||
int id3 = i / dim_4;
|
||||
int id4 = i % dim_4;
|
||||
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
|
||||
res4[trg_offset + cur_trg_offset] = vres4;
|
||||
}
|
||||
}
|
||||
|
||||
// [b, s, 3, h] -> [3, b, nh, s, ad]
|
||||
template <>
|
||||
void launch_bias_add_transform_20314<float>(float *output, const float *input,
|
||||
const float *bias, int dim_0,
|
||||
int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream) {
|
||||
dim_4 >>= 2;
|
||||
|
||||
dim3 grid_dim(dim_0, dim_1, dim_2);
|
||||
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
|
||||
|
||||
bias_add_transform_20314<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_bias_add_transform_20314<__half>(__half *output,
|
||||
const __half *input,
|
||||
const __half *bias, int dim_0,
|
||||
int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream) {
|
||||
dim_4 >>= 3;
|
||||
|
||||
dim3 grid_dim(dim_0, dim_1, dim_2);
|
||||
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
|
||||
|
||||
bias_add_transform_20314<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: transform4d_0213
|
||||
Reshape the input matrix to merge the heads
|
||||
|
||||
@thread
|
||||
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
|
||||
blockDim.x = max_block_thread
|
||||
|
||||
@param
|
||||
input: [trans_count, batch_size, nhead, seq_len, head_dim]
|
||||
output: [batch_size, seq_len, trans_count, nhead, head_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
nhead: number of attention heads
|
||||
trans_count: 1 or 3, the number of matrices that need to be transformed
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
|
||||
int seq_len, int trans_count, int nhead,
|
||||
int head_dim, int num_all) {
|
||||
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (offset >= num_all) {
|
||||
return;
|
||||
}
|
||||
int trans_id, batch_id, head_id, token_id, dim_id;
|
||||
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
|
||||
&batch_id, &head_id, &token_id, &dim_id);
|
||||
// [b, s, tc, nh, ad]
|
||||
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
|
||||
seq_len, trans_count, nhead, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
res4[trg_offset] = input4[offset];
|
||||
}
|
||||
|
||||
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
|
||||
template <>
|
||||
void launch_transform4d_0213<float>(float *output, const float *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, int trans_count,
|
||||
cudaStream_t stream) {
|
||||
hidden_dim >>= 2;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
int num_all = batch_size * seq_len * trans_count * hidden_dim;
|
||||
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
|
||||
|
||||
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
|
||||
num_all);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_transform4d_0213<__half>(__half *output, const __half *input,
|
||||
int batch_size, int seq_len,
|
||||
int hidden_dim, int nhead, int trans_count,
|
||||
cudaStream_t stream) {
|
||||
hidden_dim >>= 3;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
int num_all = batch_size * seq_len * trans_count * hidden_dim;
|
||||
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
|
||||
|
||||
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
|
||||
num_all);
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
/* This code is from NVIDIA apex:
|
||||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include "compat.h"
|
||||
|
||||
namespace {
|
||||
|
||||
void compute_n1_n2(
|
||||
at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
int& n1,
|
||||
int& n2) {
|
||||
int idiff = input.ndimension() - normalized_shape.size();
|
||||
n2 = 1;
|
||||
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
|
||||
assert( input.sizes()[i+idiff] == normalized_shape[i] );
|
||||
n2 *= normalized_shape[i];
|
||||
}
|
||||
n1 = 1;
|
||||
for (int i = 0; i < idiff; ++i) {
|
||||
n1 *= input.sizes()[i];
|
||||
}
|
||||
}
|
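// Illustrative (hypothetical) call, added for clarity: for an input x of shape
// [batch, seq_len, hidden] with normalized_shape = [hidden], this yields
// n2 = hidden and n1 = batch * seq_len, i.e. one (mean, invvar) pair per token:
//   int n1, n2;
//   compute_n1_n2(x, {hidden}, n1, n2);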
||||
|
||||
void check_args(
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma,
|
||||
at::Tensor beta
|
||||
)
|
||||
{
|
||||
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
|
||||
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
|
||||
}
|
||||
|
||||
void check_args(
|
||||
at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
int& n1,
|
||||
int& n2
|
||||
)
|
||||
{
|
||||
int64_t normalized_ndim = normalized_shape.size();
|
||||
|
||||
if (normalized_ndim < 1) {
|
||||
std::stringstream ss;
|
||||
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
|
||||
<< "containing at least one element, but got normalized_shape="
|
||||
<< normalized_shape;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
auto input_shape = input.sizes();
|
||||
auto input_ndim = input.dim();
|
||||
|
||||
if (input_ndim < normalized_ndim ||
|
||||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
|
||||
std::stringstream ss;
|
||||
ss << "Given normalized_shape=" << normalized_shape
|
||||
<< ", expected input with shape [*";
|
||||
for (auto size : normalized_shape) {
|
||||
ss << ", " << size;
|
||||
}
|
||||
ss << "], but got input of size" << input_shape;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
compute_n1_n2(input,normalized_shape,n1,n2);
|
||||
}
|
||||
|
||||
|
||||
void check_args(
|
||||
at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma,
|
||||
at::Tensor beta,
|
||||
int& n1,
|
||||
int& n2
|
||||
)
|
||||
{
|
||||
check_args(input,normalized_shape,n1,n2);
|
||||
check_args(normalized_shape,gamma,beta);
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_layer_norm(
|
||||
at::Tensor* output,
|
||||
at::Tensor* mean,
|
||||
at::Tensor* invvar,
|
||||
at::Tensor* input,
|
||||
int n1,
|
||||
int n2,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor* gamma,
|
||||
at::Tensor* beta,
|
||||
double epsilon);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<at::Tensor> layer_norm_affine(
|
||||
at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma,
|
||||
at::Tensor beta,
|
||||
double epsilon) {
|
||||
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(gamma);
|
||||
CHECK_INPUT(beta);
|
||||
int n1, n2;
|
||||
check_args(input, normalized_shape, gamma, beta, n1, n2);
|
||||
|
||||
at::Tensor output = at::empty_like(
|
||||
input, gamma.options().dtype(gamma.scalar_type()));
|
||||
at::Tensor mean = at::empty(
|
||||
{n1}, input.options().dtype(at::ScalarType::Float));
|
||||
at::Tensor invvar = at::empty_like(mean);
|
||||
|
||||
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
|
||||
normalized_shape, &gamma, &beta, epsilon);
|
||||
|
||||
return {output, mean, invvar};
|
||||
|
||||
}
|
||||
|
||||
|
||||
void cuda_layer_norm_gradient(
|
||||
at::Tensor* dout,
|
||||
at::Tensor* mean,
|
||||
at::Tensor* invvar,
|
||||
at::Tensor* input,
|
||||
int n1,
|
||||
int n2,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor* gamma,
|
||||
at::Tensor* beta,
|
||||
double epsilon,
|
||||
at::Tensor* grad_input,
|
||||
at::Tensor* grad_gamma,
|
||||
at::Tensor* grad_beta
|
||||
);
|
||||
|
||||
std::vector<at::Tensor> layer_norm_gradient_affine(
|
||||
at::Tensor dout,
|
||||
at::Tensor mean,
|
||||
at::Tensor invvar,
|
||||
at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma,
|
||||
at::Tensor beta,
|
||||
double epsilon) {
|
||||
|
||||
CHECK_INPUT(dout);
|
||||
CHECK_INPUT(mean);
|
||||
CHECK_INPUT(invvar);
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(gamma);
|
||||
CHECK_INPUT(beta);
|
||||
int n1, n2;
|
||||
check_args(input, normalized_shape, gamma, beta, n1, n2);
|
||||
|
||||
at::Tensor grad_input = at::empty_like(input);
|
||||
at::Tensor grad_gamma = at::empty_like(gamma);
|
||||
at::Tensor grad_beta = at::empty_like(beta);
|
||||
|
||||
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
|
||||
normalized_shape, &gamma, &beta, epsilon,
|
||||
&grad_input, &grad_gamma, &grad_beta);
|
||||
|
||||
return {grad_input, grad_gamma, grad_beta};
|
||||
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward_affine", &layer_norm_affine,
|
||||
"LayerNorm forward (CUDA)");
|
||||
m.def("backward_affine", &layer_norm_gradient_affine,
|
||||
"LayerNorm backward (CUDA)");
|
||||
}
|
|
@ -0,0 +1,813 @@
|
|||
/* This code is from NVIDIA apex:
|
||||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/AccumulateType.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include <THC/THCDeviceUtils.cuh>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "type_shim.h"
|
||||
|
||||
template<typename U> __device__
|
||||
void cuWelfordOnlineSum(
|
||||
const U curr,
|
||||
U& mu,
|
||||
U& sigma2,
|
||||
U& count)
|
||||
{
|
||||
count = count + U(1);
|
||||
U delta = curr - mu;
|
||||
U lmean = mu + delta / count;
|
||||
mu = lmean;
|
||||
U delta2 = curr - lmean;
|
||||
sigma2 = sigma2 + delta * delta2;
|
||||
}
|
||||
|
||||
template<typename U> __device__
|
||||
void cuChanOnlineSum(
|
||||
const U muB,
|
||||
const U sigma2B,
|
||||
const U countB,
|
||||
U& mu,
|
||||
U& sigma2,
|
||||
U& count)
|
||||
{
|
||||
U delta = muB - mu;
|
||||
U nA = count;
|
||||
U nB = countB;
|
||||
count = count + countB;
|
||||
U nX = count;
|
||||
if (nX > U(0)) {
|
||||
nA = nA / nX;
|
||||
nB = nB / nX;
|
||||
mu = nA*mu + nB*muB;
|
||||
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
|
||||
} else {
|
||||
mu = U(0);
|
||||
sigma2 = U(0);
|
||||
}
|
||||
}
|
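The two device helpers above are Welford's online update and Chan et al.'s parallel merge. A small host-side re-implementation (illustrative only, not part of the diff) checks that merging two partial accumulators reproduces a single-pass result:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

static void welford_add(float x, float &mu, float &sigma2, float &count) {
  count += 1.0f;
  float delta = x - mu;
  mu += delta / count;         // running mean
  sigma2 += delta * (x - mu);  // running sum of squared deviations
}

static void chan_merge(float muB, float sigma2B, float countB,
                       float &mu, float &sigma2, float &count) {
  float delta = muB - mu;
  float nA = count, nB = countB;
  count += countB;
  if (count > 0.0f) {
    nA /= count;
    nB /= count;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * count;
  } else {
    mu = 0.0f;
    sigma2 = 0.0f;
  }
}

int main() {
  std::vector<float> xs = {1.f, 2.f, 4.f, 8.f, 16.f, 32.f};

  // Two partial accumulators, like two lanes of a warp.
  float muA = 0.f, s2A = 0.f, cA = 0.f, muB = 0.f, s2B = 0.f, cB = 0.f;
  for (size_t i = 0; i < xs.size(); ++i) {
    if (i % 2 == 0) welford_add(xs[i], muA, s2A, cA);
    else            welford_add(xs[i], muB, s2B, cB);
  }
  chan_merge(muB, s2B, cB, muA, s2A, cA);

  // Single-pass reference.
  float mu = 0.f, s2 = 0.f, c = 0.f;
  for (float x : xs) welford_add(x, mu, s2, c);

  assert(std::fabs(muA - mu) < 1e-4f && std::fabs(s2A - s2) < 1e-3f);
  printf("merged: mean=%f, var=%f\n", muA, s2A / cA);
  return 0;
}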
||||
|
||||
template<typename T, typename U> __device__
|
||||
void cuWelfordMuSigma2(
|
||||
const T* __restrict__ vals,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const int i1,
|
||||
U& mu,
|
||||
U& sigma2,
|
||||
U* buf)
|
||||
{
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
U count = U(0);
|
||||
mu= U(0);
|
||||
sigma2 = U(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const T* lvals = vals + i1*n2;
|
||||
int l = 4*thrx;
|
||||
for (; l+3 < n2; l+=4*numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
U curr = static_cast<U>(lvals[l+k]);
|
||||
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
U curr = static_cast<U>(lvals[l]);
|
||||
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x+(1<<l))&31;
|
||||
U muB = WARP_SHFL(mu, srcLaneB);
|
||||
U countB = WARP_SHFL(count, srcLaneB);
|
||||
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
U* ubuf = (U*)buf;
|
||||
U* ibuf = (U*)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
ubuf[2*wrt_y] = mu;
|
||||
ubuf[2*wrt_y+1] = sigma2;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
U muB = ubuf[2*threadIdx.y];
|
||||
U sigma2B = ubuf[2*threadIdx.y+1];
|
||||
U countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
ubuf[0] = mu;
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
mu = ubuf[0];
|
||||
sigma2 = ubuf[1]/U(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
mu = WARP_SHFL(mu, 0);
|
||||
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<> __device__
|
||||
void cuWelfordMuSigma2(
|
||||
const at::Half* __restrict__ vals,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const int i1,
|
||||
float& mu,
|
||||
float& sigma2,
|
||||
float* buf)
|
||||
{
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
float count = 0.0f;
|
||||
mu= float(0);
|
||||
sigma2 = float(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const at::Half* lvals = vals + i1*n2;
|
||||
int l = 8*thrx;
|
||||
if ((((size_t)lvals)&3) != 0) {
|
||||
// 16 bit alignment
|
||||
// first thread consumes first point
|
||||
if (thrx == 0) {
|
||||
float curr = static_cast<float>(lvals[0]);
|
||||
cuWelfordOnlineSum(curr,mu,sigma2,count);
|
||||
}
|
||||
++l;
|
||||
}
|
||||
// at this point, lvals[l] are 32 bit aligned for all threads.
|
||||
for (; l+7 < n2; l+=8*numx) {
|
||||
for (int k = 0; k < 8; k+=2) {
|
||||
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
|
||||
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
|
||||
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
float curr = static_cast<float>(lvals[l]);
|
||||
cuWelfordOnlineSum(curr,mu,sigma2,count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x+(1<<l))&31;
|
||||
float muB = WARP_SHFL(mu, srcLaneB);
|
||||
float countB = WARP_SHFL(count, srcLaneB);
|
||||
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
float* ubuf = (float*)buf;
|
||||
float* ibuf = (float*)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
ubuf[2*wrt_y] = mu;
|
||||
ubuf[2*wrt_y+1] = sigma2;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
float muB = ubuf[2*threadIdx.y];
|
||||
float sigma2B = ubuf[2*threadIdx.y+1];
|
||||
float countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
ubuf[0] = mu;
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
mu = ubuf[0];
|
||||
sigma2 = ubuf[1]/float(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
mu = WARP_SHFL(mu, 0);
|
||||
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename U> U rsqrt(U v) {
|
||||
return U(1) / sqrt(v);
|
||||
}
|
||||
template<> float rsqrt(float v) {
|
||||
return rsqrtf(v);
|
||||
}
|
||||
template<> double rsqrt(double v) {
|
||||
return rsqrt(v);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// This is the un-specialized struct. Note that we prevent instantiation of this
|
||||
// struct by putting an undefined symbol in the function body so it won't compile.
|
||||
// template <typename T>
|
||||
// struct SharedMemory
|
||||
// {
|
||||
// // Ensure that we won't compile any un-specialized types
|
||||
// __device__ T *getPointer()
|
||||
// {
|
||||
// extern __device__ void error(void);
|
||||
// error();
|
||||
// return NULL;
|
||||
// }
|
||||
// };
|
||||
// https://github.com/NVIDIA/apex/issues/246
|
||||
template <typename T>
|
||||
struct SharedMemory;
|
||||
|
||||
template <>
|
||||
struct SharedMemory <float>
|
||||
{
|
||||
__device__ float *getPointer()
|
||||
{
|
||||
extern __shared__ float s_float[];
|
||||
return s_float;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename V> __global__
|
||||
void cuApplyLayerNorm(
|
||||
V* __restrict__ output_vals,
|
||||
U* __restrict__ mean,
|
||||
U* __restrict__ invvar,
|
||||
const T* __restrict__ vals,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const U epsilon,
|
||||
const V* __restrict__ gamma,
|
||||
const V* __restrict__ beta
|
||||
)
|
||||
{
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensors are contiguous
|
||||
//
|
||||
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
U mu,sigma2;
|
||||
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
|
||||
const T* lvals = vals + i1*n2;
|
||||
V* ovals = output_vals + i1*n2;
|
||||
U c_invvar = rsqrt(sigma2 + epsilon);
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL && beta != NULL) {
|
||||
for (int i = thrx; i < n2; i+=numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = thrx; i < n2; i+=numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
mean[i1] = mu;
|
||||
invvar[i1] = c_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename V> __device__
|
||||
void cuLoadWriteStridedInputs(
|
||||
const int i1_block,
|
||||
const int thr_load_row_off,
|
||||
const int thr_load_col_off,
|
||||
const int i2_off,
|
||||
const int row_stride,
|
||||
U* warp_buf1,
|
||||
U* warp_buf2,
|
||||
const T* input,
|
||||
const V* dout,
|
||||
const int i1_end,
|
||||
const int n2,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar
|
||||
)
|
||||
{
|
||||
int i1 = i1_block+thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1*n2+i2;
|
||||
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
|
||||
if (i2<n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] = curr_dout;
|
||||
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
warp_buf1[write_idx] = U(0);
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
|
||||
warp_buf1[write_idx] = U(0);
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename V> __device__
|
||||
void cuLoadAddStridedInputs(
|
||||
const int i1_block,
|
||||
const int thr_load_row_off,
|
||||
const int thr_load_col_off,
|
||||
const int i2_off,
|
||||
const int row_stride,
|
||||
U* warp_buf1,
|
||||
U* warp_buf2,
|
||||
const T* input,
|
||||
const V* dout,
|
||||
const int i1_end,
|
||||
const int n2,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar
|
||||
)
|
||||
{
|
||||
int i1 = i1_block+thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1*n2+i2;
|
||||
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
|
||||
if (i2<n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] += curr_dout;
|
||||
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename V> __global__
|
||||
void cuComputePartGradGammaBeta(
|
||||
const V* __restrict__ dout,
|
||||
const T* __restrict__ input,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar,
|
||||
U epsilon,
|
||||
U* part_grad_gamma,
|
||||
U* part_grad_beta)
|
||||
{
|
||||
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
|
||||
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
|
||||
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
|
||||
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
|
||||
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
|
||||
const int row_stride = blockDim.x+1;
|
||||
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
|
||||
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
|
||||
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
|
||||
U* warp_buf1 = (U*)buf;
|
||||
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
|
||||
// compute partial sums from strided inputs
|
||||
// do this to increase number of loads in flight
|
||||
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
|
||||
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
|
||||
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
|
||||
}
|
||||
__syncthreads();
|
||||
// inter-warp reductions
|
||||
// sum within each warp
|
||||
U acc1 = U(0);
|
||||
U acc2 = U(0);
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int row1 = threadIdx.y + k*blockDim.y;
|
||||
int idx1 = row1*row_stride + threadIdx.x;
|
||||
acc1 += warp_buf1[idx1];
|
||||
acc2 += warp_buf2[idx1];
|
||||
}
|
||||
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
|
||||
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
|
||||
__syncthreads();
|
||||
// sum all warps
|
||||
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
|
||||
if (threadIdx.y < offset) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + offset;
|
||||
int idx1 = row1*row_stride + threadIdx.x;
|
||||
int idx2 = row2*row_stride + threadIdx.x;
|
||||
warp_buf1[idx1] += warp_buf1[idx2];
|
||||
warp_buf2[idx1] += warp_buf2[idx2];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (threadIdx.y == 0 && i2 < n2) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + 1;
|
||||
int idx1 = row1*row_stride + threadIdx.x;
|
||||
int idx2 = row2*row_stride + threadIdx.x;
|
||||
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
|
||||
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename U, typename V> __global__
|
||||
void cuComputeGradGammaBeta(
|
||||
const U* part_grad_gamma,
|
||||
const U* part_grad_beta,
|
||||
const int part_size,
|
||||
const int n1,
|
||||
const int n2,
|
||||
V* grad_gamma,
|
||||
V* grad_beta)
|
||||
{
|
||||
// sum partial gradients for gamma and beta
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i2 < n2) {
|
||||
// each warp does sequential reductions until reduced part_size is num_warps
|
||||
int num_warp_reductions = part_size / blockDim.y;
|
||||
U sum_gamma = U(0);
|
||||
U sum_beta = U(0);
|
||||
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
|
||||
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
|
||||
sum_beta += part_grad_beta_ptr[warp_offset*n2];
|
||||
}
|
||||
// inter-warp reductions
|
||||
const int nbsize3 = blockDim.x * blockDim.y / 2;
|
||||
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
|
||||
// top half write to shared memory
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
|
||||
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[write_idx] = sum_gamma;
|
||||
buf[write_idx+nbsize3] = sum_beta;
|
||||
}
|
||||
__syncthreads();
|
||||
// bottom half sums
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_gamma += buf[read_idx];
|
||||
sum_beta += buf[read_idx+nbsize3];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// write out fully summed gradients
|
||||
if (threadIdx.y == 0) {
|
||||
grad_gamma[i2] = sum_gamma;
|
||||
grad_beta[i2] = sum_beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename V> __global__
|
||||
void cuComputeGradInput(
|
||||
const V* __restrict__ dout,
|
||||
const T* __restrict__ input,
|
||||
const int n1,
|
||||
const int n2,
|
||||
const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar,
|
||||
U epsilon,
|
||||
const V* gamma,
|
||||
T* grad_input)
|
||||
{
|
||||
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
const U c_mean = mean[i1];
|
||||
const U c_invvar = invvar[i1];
|
||||
const T* k_input = input + i1*n2;
|
||||
const V* k_dout = dout + i1*n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL) {
|
||||
int l = 4*thrx;
|
||||
for (; l+3 < n2; l+=4*numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l+k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l+k]);
|
||||
sum_loss1 += c_loss * gamma[l+k];
|
||||
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss * gamma[l];
|
||||
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
} else {
|
||||
int l = 4*thrx;
|
||||
for (; l+3 < n2; l+=4*numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l+k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l+k]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
|
||||
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
|
||||
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
|
||||
}
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
|
||||
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[2*wrt_i] = sum_loss1;
|
||||
buf[2*wrt_i+1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_loss1 += buf[2*read_i];
|
||||
sum_loss2 += buf[2*read_i+1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.y == 0) {
|
||||
buf[2*threadIdx.x] = sum_loss1;
|
||||
buf[2*threadIdx.x+1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.y !=0) {
|
||||
sum_loss1 = buf[2*threadIdx.x];
|
||||
sum_loss2 = buf[2*threadIdx.x+1];
|
||||
}
|
||||
}
|
||||
// all threads now have the two sums over l
|
||||
U fH = (U)n2;
|
||||
U term1 = (U(1) / fH) * c_invvar;
|
||||
T* k_grad_input = grad_input + i1*n2;
|
||||
if (gamma != NULL) {
|
||||
for (int l = thrx; l < n2; l+=numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * gamma[l];
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
} else {
|
||||
for (int l = thrx; l < n2; l+=numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss;
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
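// For reference, the per-row math that cuComputeGradInput parallelizes above is compact.
// The following single-row CPU sketch is illustrative only (it assumes float inputs and a
// non-null gamma; it is not used by the extension):
inline void layernorm_grad_input_row_reference(const float* dout, const float* input,
                                               const float* gamma, float mean, float invvar,
                                               int n2, float* grad_input) {
  float sum_loss1 = 0.f;   // sum_k dy_k * gamma_k
  float sum_loss2 = 0.f;   // sum_k dy_k * gamma_k * (x_k - mean) * invvar
  for (int l = 0; l < n2; ++l) {
    const float c_loss = dout[l] * gamma[l];
    sum_loss1 += c_loss;
    sum_loss2 += c_loss * (input[l] - mean) * invvar;
  }
  const float term1 = invvar / n2;  // (1/H) * invvar, as in the kernel
  for (int l = 0; l < n2; ++l) {
    const float f = n2 * dout[l] * gamma[l] - sum_loss1
                    - (input[l] - mean) * invvar * sum_loss2;
    grad_input[l] = term1 * f;
  }
}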
||||
|
||||
|
||||
|
||||
|
||||
template<typename T, typename U, typename V>
|
||||
void HostApplyLayerNorm(
|
||||
V* output,
|
||||
U* mean,
|
||||
U* invvar,
|
||||
const T* input,
|
||||
int n1,
|
||||
int n2,
|
||||
double epsilon,
|
||||
const V* gamma,
|
||||
const V* beta
|
||||
)
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const dim3 threads(32,4,1);
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
int nshared =
|
||||
threads.y > 1 ?
|
||||
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
|
||||
0;
|
||||
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
|
||||
output,
|
||||
mean,
|
||||
invvar,
|
||||
input,
|
||||
n1,n2,
|
||||
U(epsilon),
|
||||
gamma,beta);
|
||||
}
|
||||
|
||||
|
||||
void cuda_layer_norm(
|
||||
at::Tensor* output,
|
||||
at::Tensor* mean,
|
||||
at::Tensor* invvar,
|
||||
at::Tensor* input,
|
||||
int n1,
|
||||
int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma,
|
||||
at::Tensor* beta,
|
||||
double epsilon)
|
||||
{
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
|
||||
HostApplyLayerNorm(
|
||||
output->DATA_PTR<scalar_t_out>(),
|
||||
mean->DATA_PTR<float>(),
|
||||
invvar->DATA_PTR<float>(),
|
||||
input->DATA_PTR<scalar_t_in>(),
|
||||
n1,n2,
|
||||
epsilon,
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
|
||||
)
|
||||
}
|
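// A minimal host-side usage sketch of cuda_layer_norm (illustrative only; it relies on the
// ATen headers already pulled in above). mean/invvar are fp32 as dispatched above and are
// assumed here to hold one value per normalized row; the {n1, n2} input layout and the
// helper name are assumptions of this example:
static void example_cuda_layer_norm() {
  const int n1 = 8, n2 = 1024;
  auto opts   = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto input  = at::randn({n1, n2}, opts);
  auto gamma  = at::ones({n2}, opts);
  auto beta   = at::zeros({n2}, opts);
  auto output = at::empty_like(input);
  auto stats  = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  auto mean   = at::empty({n1}, stats);
  auto invvar = at::empty({n1}, stats);
  const int64_t norm_size = n2;
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  at::IntArrayRef(&norm_size, 1), &gamma, &beta, /*epsilon=*/1e-5);
}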
||||
|
||||
|
||||
template<typename T, typename U, typename V>
|
||||
void HostLayerNormGradient(
|
||||
const V* dout,
|
||||
const U* mean,
|
||||
const U* invvar,
|
||||
at::Tensor* input,
|
||||
int n1,
|
||||
int n2,
|
||||
const V* gamma,
|
||||
const V* beta,
|
||||
double epsilon,
|
||||
T* grad_input,
|
||||
V* grad_gamma,
|
||||
V* grad_beta
|
||||
)
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
if (gamma != NULL && beta != NULL) {
|
||||
// compute grad_gamma(j) and grad_beta(j)
|
||||
const int part_size = 16;
|
||||
const dim3 threads2(32,4,1);
|
||||
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
|
||||
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
|
||||
(threads2.x + 1);
|
||||
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
|
||||
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
|
||||
at::Tensor part_grad_gamma = at::empty(
|
||||
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
|
||||
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
|
||||
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
|
||||
dout,
|
||||
input->DATA_PTR<T>(),
|
||||
n1,n2,
|
||||
mean,
|
||||
invvar,
|
||||
U(epsilon),
|
||||
part_grad_gamma.DATA_PTR<U>(),
|
||||
part_grad_beta.DATA_PTR<U>());
|
||||
|
||||
const dim3 threads3(32,8,1);
|
||||
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
|
||||
const int nshared3 = threads3.x * threads3.y * sizeof(U);
|
||||
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
|
||||
part_grad_gamma.DATA_PTR<U>(),
|
||||
part_grad_beta.DATA_PTR<U>(),
|
||||
part_size,
|
||||
n1,n2,
|
||||
grad_gamma,
|
||||
grad_beta);
|
||||
}
|
||||
|
||||
// compute grad_input
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
const dim3 threads1(32,4,1);
|
||||
int nshared =
|
||||
threads1.y > 1 ?
|
||||
threads1.y*threads1.x*sizeof(U) :
|
||||
0;
|
||||
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input->DATA_PTR<T>(),
|
||||
n1,n2,
|
||||
mean,
|
||||
invvar,
|
||||
U(epsilon),
|
||||
gamma,
|
||||
grad_input);
|
||||
}
|
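// Worked example of the launch configuration above (numbers are illustrative): with n2 = 1024
// and U = float, cuComputePartGradGammaBeta runs on blocks2 = (1024/32, 16, 1) = (32, 16, 1)
// blocks of 32x4 threads, each block row writing one of part_size = 16 partial rows of
// part_grad_gamma/part_grad_beta; its shared memory is
//   nshared2 = max(2 * 4 * 4 * 4 * (32 + 1), 32 * 4 * 4) = max(4224, 512) = 4224 bytes.
// cuComputeGradGammaBeta then folds the 16 partial rows into grad_gamma/grad_beta with
// blocks3 = (32, 1, 1) blocks of 32x8 threads and nshared3 = 32 * 8 * 4 = 1024 bytes.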
||||
|
||||
|
||||
void cuda_layer_norm_gradient(
|
||||
at::Tensor* dout,
|
||||
at::Tensor* mean,
|
||||
at::Tensor* invvar,
|
||||
at::Tensor* input,
|
||||
int n1,
|
||||
int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma,
|
||||
at::Tensor* beta,
|
||||
double epsilon,
|
||||
at::Tensor* grad_input,
|
||||
at::Tensor* grad_gamma,
|
||||
at::Tensor* grad_beta)
|
||||
{
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), gamma->scalar_type(),
|
||||
"cuda_layer_norm_gradient_kernel",
|
||||
HostLayerNormGradient(
|
||||
dout->DATA_PTR<scalar_t_out>(),
|
||||
mean->DATA_PTR<float>(),
|
||||
invvar->DATA_PTR<float>(),
|
||||
input,
|
||||
n1,n2,
|
||||
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
|
||||
// if gamma Tensor is NULL on input.
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
|
||||
epsilon,
|
||||
grad_input->DATA_PTR<scalar_t_in>(),
|
||||
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
|
||||
)
|
||||
}
|
|
@@ -0,0 +1,364 @@
|
|||
#include "multihead_attention_1d.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <c10d/Types.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "context.h"
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len,
|
||||
int hidden_size, int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm)
|
||||
: _layer_id(layer_id),
|
||||
_max_batch_tokens(max_batch_tokens),
|
||||
_max_seq_len(max_seq_len),
|
||||
_hidden_size(hidden_size),
|
||||
_heads(num_heads),
|
||||
_training(true),
|
||||
_pre_or_postLayerNorm(pre_or_postLayerNorm),
|
||||
_qkv_linear(typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false), _max_batch_tokens),
|
||||
_softmax(typename Softmax<T>::Config(num_heads)),
|
||||
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
|
||||
_max_batch_tokens * _heads * _max_seq_len),
|
||||
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
|
||||
_max_batch_tokens * _hidden_size),
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config((T(1.0) / T(sqrt(_hidden_size / _heads))),
|
||||
T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)),
|
||||
_attn_context(
|
||||
typename StridedBatchGemm<T>::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
assert(_hidden_size % _heads == 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::~MultiHeadAttention() {
|
||||
free_mem_buffer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
|
||||
T *output_ptr, T *buffer) {
|
||||
T *q_tf_ptr = _qkv_ptr;
|
||||
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
_stream);
|
||||
}
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle);
|
||||
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3,
|
||||
_heads / pg_size, _hidden_size / _heads, _stream);
|
||||
|
||||
// attention scores, q*k
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
|
||||
|
||||
// Softmax + Mask
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true);
|
||||
|
||||
// attn prob dropout.
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len,
|
||||
_stream);
|
||||
|
||||
// attention context, score * v
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle);
|
||||
|
||||
// [b, nh, s, ad] -> [b, s, nh, ad]
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size,
|
||||
_heads / pg_size, 1, _stream);
|
||||
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
} else {
|
||||
auto data_type = torch::kFloat;
|
||||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto output_tensor =
|
||||
torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
if (!_pre_or_postLayerNorm) {
|
||||
// in-place ln since ln-input will not be used in post-ln mode
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
|
||||
|
||||
attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
|
||||
const T *grad_output_ptr, T *grad_input_ptr, T *buffer) {
|
||||
cudaStream_t streams[2] = {_stream, _stream};
|
||||
|
||||
const T *q_tf_ptr = _qkv_ptr;
|
||||
const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
// batch_dim = batch_size * seq_len * hidden_size
|
||||
// buffer size: batch_dim * 3 + max(batch_dim * 3,
|
||||
// batch_size * head_num * seq_len * seq_len)
|
||||
T *grad_residual_ptr = buffer;
|
||||
buffer += _batch_dim;
|
||||
|
||||
T *grad_input_buf_ptr = buffer; // batch_dim
|
||||
T *grad_qkv_5d_ptr = buffer; // batch_dim * 3
|
||||
buffer += 3 * _batch_dim / pg_size;
|
||||
|
||||
T *grad_qkv_4d_ptr = buffer; // batch_dim * 3
|
||||
T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len
|
||||
// buffer += max(3 * _batch_dim,
|
||||
// batch_size * head_num * seq_len * seq_len);
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
} else {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr,
|
||||
nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
}
|
||||
|
||||
// bw of output project
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
_grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, _stream);
|
||||
|
||||
// bw of score * v
|
||||
_attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream);
|
||||
|
||||
// bw of q * k
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr);
|
||||
|
||||
// [3, b, nh, s, ad] -> [b, s, 3, h]
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 3, _stream);
|
||||
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr,
|
||||
_grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, true);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
} else {
|
||||
auto data_type = torch::kFloat;
|
||||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto grad_input_tensor =
|
||||
torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr,
|
||||
grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
streams);
|
||||
} else {
|
||||
// FIXME later
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size,
|
||||
_seq_len, _hidden_size, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr, T *grad_input_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *buffer = _shared_mem_ptr;
|
||||
|
||||
/*
|
||||
buffer size needed by attn bw:
|
||||
4 * _batch_dim + max(3 * _batch_dim,
|
||||
_batch_size * _head_num * _seq_len * _seq_len);
|
||||
*/
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
|
||||
// Dropout will be skipped when not in training mode.
|
||||
_attn_prob_dropout.SetTrainingMode(training);
|
||||
_attn_dropout.SetTrainingMode(training);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;
|
||||
|
||||
template class MultiHeadAttention<float>;
|
||||
template class MultiHeadAttention<__half>;
|
||||
|
||||
// x is torch::Tensor
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
|
||||
|
||||
template <typename T>
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim,
|
||||
int num_heads, float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio, bool pre_or_postLayerNorm,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
Context::Instance().set_stream(stream);
|
||||
auto layer = std::make_shared<MultiHeadAttention<T>>(
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio,
|
||||
hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
|
||||
layer->SetPG(pg_);
|
||||
|
||||
s_multihead_attention[layer_id] = layer;
|
||||
|
||||
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
const T *input_ptr = (const T *)input.data_ptr();
|
||||
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
|
||||
|
||||
auto output = torch::empty_like(input);
|
||||
T *out_ptr = (T *)output.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(input.size(0), input.size(1));
|
||||
layer->SetTrainingMode(training_mode);
|
||||
|
||||
layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
|
||||
layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
|
||||
layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
|
||||
layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
|
||||
layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
|
||||
layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();
|
||||
|
||||
layer->Forward(input_ptr, input_mask_ptr, out_ptr);
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
|
||||
const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
auto g_output = grad_dec_output.contiguous();
|
||||
CHECK_INPUT(g_output);
|
||||
CHECK_INPUT(output);
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
auto grad_input = torch::empty_like(input);
|
||||
auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
|
||||
auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
|
||||
auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
|
||||
auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
|
||||
auto grad_norm_weight = torch::empty_like(norm_weight);
|
||||
auto grad_norm_bias = torch::empty_like(norm_bias);
|
||||
|
||||
// inputs.
|
||||
const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
|
||||
const T *input_ptr = (const T *)input.data_ptr();
|
||||
const T *output_ptr = (const T *)output.data_ptr();
|
||||
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
|
||||
|
||||
// outputs.
|
||||
T *grad_input_ptr = (T *)grad_input.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
|
||||
|
||||
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
|
||||
layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
|
||||
layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
|
||||
layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
|
||||
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
|
||||
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
|
||||
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);
|
||||
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
|
||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
|
||||
"Multi-head Attention forward with fp32 (CUDA)");
|
||||
m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
|
||||
"Multi-head Attention forward with fp16 (CUDA)");
|
||||
m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
|
||||
"Multi-head Attention backward with fp32 (CUDA)");
|
||||
m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
|
||||
"Multi-head Attention backward with fp16 (CUDA)");
|
||||
m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
|
||||
"Create Multi-head Attention with fp32 (CUDA)");
|
||||
m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
|
||||
"Create Multi-head Attention with fp16 (CUDA)");
|
||||
}
|
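// A minimal end-to-end sketch of the flow the bindings above expose: register a layer, then
// run a forward pass. The tensor shapes and the zero-valued additive mask are assumptions
// made for illustration; a valid (e.g. single-rank) c10d process group is taken as a
// parameter because constructing one is outside the scope of this sketch.
static void example_multihead_attention_forward(c10::intrusive_ptr<c10d::ProcessGroup> pg) {
  const int layer_id = 0, bsz = 2, seq = 128, hidden = 1024, heads = 16;
  create_multihead_attention<float>(layer_id, /*max_batch_tokens=*/bsz * seq,
                                    /*max_seq_len=*/seq, hidden, heads,
                                    /*attn_prob_dropout_ratio=*/0.1f,
                                    /*hidden_dropout_ratio=*/0.1f,
                                    /*pre_or_postLayerNorm=*/true, pg);

  auto opts       = torch::dtype(torch::kFloat).device(torch::kCUDA);
  auto input      = torch::randn({bsz, seq, hidden}, opts);
  auto input_mask = torch::zeros({bsz, seq}, opts);  // additive padding mask; shape assumed
  auto qkv_w      = torch::randn({3 * hidden, hidden}, opts);
  auto qkv_b      = torch::zeros({3 * hidden}, opts);
  auto out_w      = torch::randn({hidden, hidden}, opts);
  auto out_b      = torch::zeros({hidden}, opts);
  auto ln_w       = torch::ones({hidden}, opts);
  auto ln_b       = torch::zeros({hidden}, opts);

  auto outputs = multihead_attention_fw<float>(layer_id, input, input_mask, qkv_w, qkv_b,
                                               out_w, out_b, ln_w, ln_b,
                                               /*training_mode=*/false,
                                               /*prelayernorm=*/true);
  // outputs[0] has the same shape as input.
}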
|
@@ -0,0 +1,153 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "cuda_util.h"
|
||||
#include "dropout.h"
|
||||
#include "feed_forward.h"
|
||||
#include "normalize_layer.h"
|
||||
#include "softmax.h"
|
||||
#include "strided_batch_gemm.h"
|
||||
|
||||
template <typename T>
|
||||
class MultiHeadAttention {
|
||||
public:
|
||||
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
|
||||
int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm);
|
||||
|
||||
virtual ~MultiHeadAttention();
|
||||
|
||||
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
|
||||
|
||||
void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr, T *grad_input_ptr);
|
||||
|
||||
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
|
||||
|
||||
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
|
||||
const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
|
||||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len) {
|
||||
_batch_size = batch_size;
|
||||
_seq_len = seq_len;
|
||||
_batch_tokens = batch_size * seq_len;
|
||||
_batch_heads = batch_size * _heads / pg_size;
|
||||
_batch_dim = _batch_tokens * _hidden_size;
|
||||
_attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
|
||||
_attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
|
||||
}
|
||||
|
||||
void SetTrainingMode(bool training);
|
||||
inline bool IsTrainingMode() const { return _training; }
|
||||
|
||||
void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
pg = pg_;
|
||||
pg_size = 1;
|
||||
if (pg != c10::detail::UniqueVoidPtr()) {
|
||||
pg_size = pg->getSize();
|
||||
}
|
||||
allocate_mem_buffer();
|
||||
}
|
||||
|
||||
// weights ptr
|
||||
const T *_attn_qkvw_ptr;
|
||||
const T *_attn_qkvb_ptr;
|
||||
const T *_attn_ow_ptr;
|
||||
const T *_attn_ob_ptr;
|
||||
const T *_attn_nw_ptr;
|
||||
const T *_attn_nb_ptr;
|
||||
|
||||
// grads ptr
|
||||
T *_grad_attn_qkvw_ptr;
|
||||
T *_grad_attn_qkvb_ptr;
|
||||
T *_grad_attn_ow_ptr;
|
||||
T *_grad_attn_ob_ptr;
|
||||
T *_grad_attn_nw_ptr;
|
||||
T *_grad_attn_nb_ptr;
|
||||
|
||||
private:
|
||||
void allocate_mem_buffer() {
|
||||
// allocate local gpu memory
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
|
||||
} else {
|
||||
_gemmQKV_inp_ptr = nullptr;
|
||||
}
|
||||
|
||||
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
|
||||
_soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
|
||||
|
||||
// buffer size needed by attn bw
|
||||
size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
|
||||
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
|
||||
_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
|
||||
if (!_shared_mem_ptr) {
|
||||
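// _shared_mem_ptr is nullptr in this branch, so the cuda_free below is effectively a no-op;
// the shared buffer is then allocated once and reused by every layer.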
cuda_free(_shared_mem_ptr);
|
||||
_shared_mem_ptr = cuda_malloc<T>(smem_size);
|
||||
}
|
||||
}
|
||||
|
||||
void free_mem_buffer() {
|
||||
// free local gpu memory
|
||||
cuda_free(_gemmQKV_inp_ptr);
|
||||
cuda_free(_qkv_ptr);
|
||||
cuda_free(_soft_out_ptr);
|
||||
cuda_free(_ctx_bufB_ptr);
|
||||
cuda_free(_attn_o_inp_ptr);
|
||||
|
||||
// free shared gpu memory between layers
|
||||
cuda_free(_shared_mem_ptr);
|
||||
_shared_mem_ptr = nullptr;
|
||||
}
|
||||
|
||||
// const parameter between batch
|
||||
const size_t _layer_id;
|
||||
const size_t _hidden_size;
|
||||
const size_t _heads;
|
||||
const size_t _max_batch_tokens;
|
||||
const size_t _max_seq_len;
|
||||
const bool _pre_or_postLayerNorm;
|
||||
// dynamic parameter between batch
|
||||
size_t _batch_size;
|
||||
size_t _seq_len;
|
||||
size_t _batch_tokens;
|
||||
size_t _batch_heads;
|
||||
size_t _batch_dim;
|
||||
bool _training;
|
||||
|
||||
cublasHandle_t _cublasHandle;
|
||||
cudaStream_t _stream;
|
||||
|
||||
// layers
|
||||
FeedForward<T> _qkv_linear;
|
||||
FeedForward<T> _attn_out_linear;
|
||||
Normalize_Layer<T> _attn_ln;
|
||||
Softmax<T> _softmax;
|
||||
Dropout<T> _attn_prob_dropout;
|
||||
Dropout<T> _attn_dropout;
|
||||
StridedBatchGemm<T> _attn_scores;
|
||||
StridedBatchGemm<T> _attn_context;
|
||||
|
||||
// local GPU memory
|
||||
T *_gemmQKV_inp_ptr;
|
||||
T *_qkv_ptr;
|
||||
T *_soft_out_ptr;
|
||||
T *_ctx_bufB_ptr;
|
||||
T *_attn_o_inp_ptr;
|
||||
// shared GPU memory between layer
|
||||
static T *_shared_mem_ptr;
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg;
|
||||
int pg_size;
|
||||
};
|
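// Worked example of the shared-buffer sizing in allocate_mem_buffer (values illustrative):
// with max_batch_tokens = 4096, hidden = 1024, heads = 16, max_seq_len = 256 and pg_size = 1,
//   4 * 4096 * 1024                       = 16,777,216 elements
//   max(3 * 4096 * 1024, 4096 * 16 * 256) = max(12,582,912, 16,777,216) = 16,777,216 elements
//   smem_size                             = 33,554,432 elements of T,
// i.e. 128 MiB for float, matching the buffer budget documented in MultiHeadAttention<T>::Backward().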
|
@@ -0,0 +1,84 @@
|
|||
/* This code is from NVIDIA Megatron:
|
||||
* with minor changes. */
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
|
||||
int get_batch_per_block_cuda(
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads);
|
||||
|
||||
torch::Tensor fwd(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor) {
|
||||
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
|
||||
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
|
||||
|
||||
return fwd_cuda(input, mask, scale_factor);
|
||||
}
|
||||
|
||||
torch::Tensor bwd(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor) {
|
||||
|
||||
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
|
||||
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
|
||||
|
||||
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||
}
|
||||
|
||||
int get_batch_per_block(
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads) {
|
||||
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
|
||||
}
|
||||
|
||||
} // end namespace scaled_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
|
||||
m.def("backward",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
|
||||
m.def("get_batch_per_block",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
|
||||
"Return Batch per block size."
|
||||
);
|
||||
}
|
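// A minimal host-side sketch of the checks enforced above (illustrative; the batch/head/seq
// sizes and the helper name are placeholders). The mask is uint8 with 1 marking a masked
// position, broadcast over heads via its singleton second dimension.
static void example_scaled_masked_softmax() {
  auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto mask_opts = torch::dtype(torch::kUInt8).device(torch::kCUDA);
  auto logits = torch::randn({/*batches=*/2, /*heads=*/8, /*q_len=*/128, /*k_len=*/128}, half_opts);
  auto mask   = torch::zeros({2, 1, 128, 128}, mask_opts);  // 0 = attend, 1 = masked
  auto probs  = multihead_attn::fused_softmax::scaled_masked_softmax::fwd(
      logits, mask, /*scale_factor=*/0.125f);
  auto grads  = multihead_attn::fused_softmax::scaled_masked_softmax::bwd(
      torch::ones_like(probs), probs, /*scale_factor=*/0.125f);
}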
|
@@ -0,0 +1,492 @@
|
|||
/* This code is from NVIDIA Megatron:
|
||||
* with minor changes. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
||||
{
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
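// Minimal usage sketch of warp_reduce (illustrative only; not launched by the extension).
// Every lane of a single warp contributes its lane id, and after the XOR butterfly each lane
// holds the full sum 0 + 1 + ... + 31 = 496.
__global__ void warp_reduce_demo(float *out)
{
    float sum[1] = { static_cast<float>(threadIdx.x) };   // WARP_BATCH == 1
    warp_reduce<float, 1, 32, Add>(sum);
    if (threadIdx.x == 0) *out = sum[0];                   // 496.0f when launched as <<<1, 32>>>
}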
||||
|
||||
/*
|
||||
 * Extended softmax (from native aten pytorch) with the following additional features
|
||||
* 1) input scaling
|
||||
* 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const acc_t scale,
|
||||
int micro_batch_size,
|
||||
int element_count,
|
||||
int pad_batches)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i*element_count+it*WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH] { 0.0f };
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_backward(
|
||||
output_t *gradInput,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
acc_t scale,
|
||||
int micro_batch_size,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end of anonymous namespace
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
constexpr int threads_per_block = 128;
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
|
||||
return batches_per_block;
|
||||
}
|
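// Worked example (illustrative): key_seq_len = 128 gives next_power_of_two = 128, so
// warp_size = 32, batches_per_warp = 2, warps_per_block = 128 / 32 = 4 and the function
// returns 8; key_seq_len = 512 gives batches_per_warp = 1 and a return value of 4. The
// forward dispatch below requires query_seq_len to be divisible by this value.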
||||
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const input_t scale,
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads,
|
||||
int pad_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside scaled_masked_softmax_warp_forward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside scaled_masked_softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize GPU utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
|
||||
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_backward(
|
||||
output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int query_seq_len,
|
||||
int key_seq_len,
|
||||
int batches,
|
||||
int attn_heads)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside scaled_masked_softmax_warp_backward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside scaled_masked_softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize GPU utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = batch_count/batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
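The backward dispatcher above derives its launch configuration purely from key_seq_len. As a quick illustration of that arithmetic, here is a plain-Python sketch (not part of the commit) that mirrors the same computation for one concrete shape; it assumes the usual C10_WARP_SIZE of 32.

# Illustrative only, not part of the commit: mirrors the launch-config
# arithmetic above (assumes C10_WARP_SIZE == 32).
def launch_config(key_seq_len, query_seq_len, batches, attn_heads, c10_warp_size=32):
    log2_elements = (key_seq_len - 1).bit_length()           # log2_ceil
    next_power_of_two = 1 << log2_elements
    batch_count = batches * attn_heads * query_seq_len
    warp_size = min(next_power_of_two, c10_warp_size)        # WARP_SIZE in the kernel
    batches_per_warp = 2 if next_power_of_two <= 128 else 1  # WARP_BATCH in the kernel
    threads_per_block = 128
    warps_per_block = threads_per_block // warp_size
    batches_per_block = warps_per_block * batches_per_warp
    blocks = batch_count // batches_per_block
    return blocks, (warp_size, warps_per_block, 1)

# key_seq_len=1024 -> one warp per row, 4 rows per 128-thread block
print(launch_config(1024, 1024, 8, 16))                      # (32768, (32, 4, 1))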
@ -0,0 +1,104 @@
|
|||
/* This code is from NVIDIA Megatron,
 * with minor changes. */
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_profiler_api.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include "scaled_masked_softmax.h"
|
||||
#include "type_shim.h"
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_masked_softmax {
|
||||
|
||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
|
||||
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor)
|
||||
{
|
||||
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
|
||||
const int batches = input.size(0);
|
||||
const int pad_batches = mask.size(0);
|
||||
const int attn_heads = input.size(1);
|
||||
const int query_seq_len = input.size(2);
|
||||
const int key_seq_len = input.size(3);
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
|
||||
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
|
||||
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
|
||||
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
|
||||
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
|
||||
|
||||
// Output
|
||||
auto act_options = input.options().requires_grad(false);
|
||||
torch::Tensor softmax_results =
|
||||
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
|
||||
|
||||
// Softmax Intermediate Result Ptr
|
||||
void* input_ptr = static_cast<void*>(input.data_ptr());
|
||||
void* mask_ptr = static_cast<void*>(mask.data_ptr());
|
||||
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
|
||||
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"dispatch_scaled_masked_softmax_forward",
|
||||
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(softmax_results_ptr),
|
||||
reinterpret_cast<const scalar_t*>(input_ptr),
|
||||
reinterpret_cast<const uint8_t*>(mask_ptr),
|
||||
scale_factor,
|
||||
query_seq_len,
|
||||
key_seq_len,
|
||||
batches,
|
||||
attn_heads,
|
||||
pad_batches);
|
||||
);
|
||||
return softmax_results;
|
||||
}
|
||||
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
|
||||
auto output_grads = output_grads_.contiguous();
|
||||
auto softmax_results = softmax_results_.contiguous();
|
||||
|
||||
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
|
||||
const int batches = output_grads.size(0);
|
||||
const int attn_heads = output_grads.size(1);
|
||||
const int query_seq_len = output_grads.size(2);
|
||||
const int key_seq_len = output_grads.size(3);
|
||||
|
||||
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
|
||||
|
||||
//Softmax Grad
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
output_grads_.scalar_type(),
|
||||
"dispatch_scaled_masked_softmax_backward",
|
||||
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor,
|
||||
query_seq_len,
|
||||
key_seq_len,
|
||||
batches,
|
||||
attn_heads);
|
||||
);
|
||||
|
||||
//backward pass is completely in-place
|
||||
return output_grads;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
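The asserts in fwd_cuda above spell out the contract for the non-causal kernel: a 4-D fp16/bf16 score tensor, a uint8 mask of shape [batches or 1, 1, query_seq_len, key_seq_len], and key_seq_len capped at 2048. The sketch below drives the built extension directly from Python; it is illustrative only, assumes the extension was compiled under the module name used by the Python wrapper later in this commit, assumes nonzero mask entries mark masked-out positions, and needs a CUDA device.

# Sketch, not part of the commit.
import torch
import colossal_scaled_masked_softmax as masked_softmax

b, heads, sq, sk = 4, 8, 64, 64                       # sk <= 2048, sq > 1
scores = torch.randn(b, heads, sq, sk, dtype=torch.half, device='cuda')
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device='cuda')
scale = torch.tensor([1.0])

probs = masked_softmax.forward(scores, mask, scale[0])
# backward is in-place on its first argument and returns it
grads = masked_softmax.backward(probs.clone(), probs, scale[0])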
@ -0,0 +1,59 @@
|
|||
/* This code is from NVIDIA Megatron,
 * with minor changes. */
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_upper_triang_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
|
||||
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
|
||||
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return fwd_cuda(input, scale_factor);
|
||||
}
|
||||
|
||||
torch::Tensor bwd(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor) {
|
||||
|
||||
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
|
||||
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
|
||||
|
||||
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||
}
|
||||
|
||||
} // end namespace scaled_upper_triang_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward",
|
||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
m.def("backward",
|
||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
}
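The pybind11 module above exposes the causal (upper-triangular) variant as forward/backward over 3-D fp16/bf16 tensors of shape [attn_batches, seq_len, seq_len]. A minimal call sketch follows; it assumes the extension was built under the module name imported by the Python wrapper later in this commit and that a CUDA device is available.

# Sketch, not part of the commit.
import torch
import colossal_scaled_upper_triang_masked_softmax as causal_softmax

attn_batches, seq_len = 32, 128                       # seq_len <= 2048
scores = torch.randn(attn_batches, seq_len, seq_len,
                     dtype=torch.half, device='cuda')

probs = causal_softmax.forward(scores, 1.0)           # scaled, causally masked softmax
grad_in = causal_softmax.backward(probs.clone(), probs, 1.0)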
@ -0,0 +1,500 @@
|
|||
/* This code is from NVIDIA Megatron,
 * with minor changes. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_zero_vector(Datatype *dst);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
|
||||
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
||||
{
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with the following additional features
|
||||
* 1) input scaling
|
||||
* 2) Implicit time (diagonal masking)
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const acc_t scale,
|
||||
int micro_batch_size,
|
||||
int stride,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if ((element_index + element) < batch_element_count) {
|
||||
elements[i][it+element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH] { 0.0f };
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
if (it < warp_iteration_limit) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < local_seq) {
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < local_seq) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
} else {
|
||||
out[element] = 0;
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
|
||||
} else if (element_index < element_count) {
|
||||
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
|
||||
output_t *gradInput,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
acc_t scale,
|
||||
int micro_batch_size,
|
||||
int stride,
|
||||
int element_count)
|
||||
{
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH)
|
||||
local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end of anonymous namespace
|
||||
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t scale,
|
||||
int softmax_elements,
|
||||
int softmax_elements_stride,
|
||||
int attn_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_backward(
|
||||
output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int softmax_elements,
|
||||
int softmax_elements_stride,
|
||||
int attn_batches)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
|
||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/* This code is from NVIDIA Megatron,
 * with minor changes. */
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_profiler_api.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include "scaled_upper_triang_masked_softmax.h"
|
||||
#include "type_shim.h"
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_upper_triang_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
float scale_factor)
|
||||
{
|
||||
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
||||
const int attn_batches = input.size(0);
|
||||
const int seq_len = input.size(1);
|
||||
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
|
||||
|
||||
// Output
|
||||
auto act_options = input.options().requires_grad(false);
|
||||
torch::Tensor softmax_results =
|
||||
torch::empty({attn_batches, seq_len, seq_len}, act_options);
|
||||
|
||||
// Softmax Intermediate Result Ptr
|
||||
void* input_ptr = static_cast<void*>(input.data_ptr());
|
||||
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
|
||||
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"dispatch_scaled_upper_triang_masked_softmax_forward",
|
||||
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(softmax_results_ptr),
|
||||
reinterpret_cast<const scalar_t*>(input_ptr),
|
||||
scale_factor,
|
||||
seq_len,
|
||||
seq_len,
|
||||
attn_batches);
|
||||
);
|
||||
return softmax_results;
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
|
||||
auto output_grads = output_grads_.contiguous();
|
||||
auto softmax_results = softmax_results_.contiguous();
|
||||
|
||||
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
||||
const int attn_batches = output_grads.size(0);
|
||||
const int seq_len = output_grads.size(1);
|
||||
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
|
||||
|
||||
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
|
||||
|
||||
//Softmax Grad
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
output_grads_.scalar_type(),
|
||||
"dispatch_scaled_upper_triang_masked_softmax_backward",
|
||||
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor,
|
||||
seq_len,
|
||||
seq_len,
|
||||
attn_batches);
|
||||
);
|
||||
|
||||
//backward pass is completely in-place
|
||||
return output_grads;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include "compat.h"
|
||||
|
||||
|
||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||
switch(TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
switch(TYPEIN) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_in = float; \
|
||||
switch(TYPEOUT) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_in = at::Half; \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_in = at::BFloat16; \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
"""This code is from NVIDIA apex:
|
||||
https://github.com/NVIDIA/apex
|
||||
with some changes. """
|
||||
|
||||
import numbers
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import init
|
||||
import importlib
|
||||
|
||||
global colossal_layer_norm_cuda
|
||||
colossal_layer_norm_cuda = None
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.eps = eps
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
|
||||
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
= colossal_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar,
|
||||
input_, ctx.normalized_shape,
|
||||
weight_, bias_, ctx.eps)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
|
||||
class MixedFusedLayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-5):
|
||||
super(MixedFusedLayerNorm, self).__init__()
|
||||
|
||||
global colossal_layer_norm_cuda
|
||||
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
|
||||
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.weight = Parameter(torch.Tensor(*normalized_shape))
|
||||
self.bias = Parameter(torch.Tensor(*normalized_shape))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
|
||||
init.ones_(self.weight)
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
|
||||
self.normalized_shape, self.eps)
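MixedFusedLayerNorm is a drop-in affine LayerNorm whose forward and backward both route through the colossal_layer_norm_cuda extension. The sketch below is illustrative only; it assumes the extension is built, a GPU is available, and the class defined above is importable from its package (the import path is deliberately omitted).

# Sketch, not part of the commit.
import torch

hidden = 1024
ln = MixedFusedLayerNorm(hidden).cuda()          # weight -> ones, bias -> zeros
x = torch.randn(8, 256, hidden, device='cuda', requires_grad=True)

y = ln(x)                                        # colossal_layer_norm_cuda.forward_affine
y.sum().backward()                               # colossal_layer_norm_cuda.backward_affine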
@ -0,0 +1,270 @@
|
|||
import math
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
def check_config(config):
|
||||
if config.hidden_size % config.nhead != 0:
|
||||
raise Exception(f"hidden_size % nhead != 0")
|
||||
|
||||
factor = 8 if config.fp16 else 4
|
||||
upbound = factor * 1024 * 4
|
||||
if config.hidden_size > upbound:
|
||||
# as required by ln backward kernel currently
|
||||
raise Exception(f"hidden_size > {upbound}")
|
||||
|
||||
head_dim = config.hidden_size // config.nhead
|
||||
if head_dim % factor != 0:
|
||||
# as required by reshape kernel
|
||||
raise Exception(f"head_dim({head_dim}) % {factor} != 0")
|
||||
|
||||
|
||||
def calc_offset(sizes):
|
||||
offsets = [0]
|
||||
tmp = 0
|
||||
for x in sizes:
|
||||
tmp += x
|
||||
offsets.append(tmp)
|
||||
return offsets
|
||||
|
||||
|
||||
colossal_multihead_attention = None
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
max_batch_tokens: int # max batch token numbers
|
||||
max_seq_len: int # max sequence length
|
||||
hidden_size: int # size of transformer hidden layers
|
||||
nhead: int # number of heads in attention
|
||||
attn_prob_dropout_ratio: float # attention score dropout ratio
|
||||
hidden_dropout_ratio: float # dropout ratio before the residual connection
norm_first: bool # whether LayerNorm is applied before attention (pre-norm)
fp16: bool # fp16 precision
|
||||
|
||||
|
||||
class MultiHeadAttention1DFunc(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
|
||||
out_proj_bias, norm_weight, norm_bias, config):
|
||||
cuda_module = colossal_multihead_attention
|
||||
forward_func = (cuda_module.multihead_attention_fw_fp16
|
||||
if config.fp16 else cuda_module.multihead_attention_fw_fp32)
|
||||
if config.fp16:
|
||||
input = input.to(torch.half)
|
||||
input_mask = input_mask.to(torch.half)
|
||||
|
||||
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
|
||||
out_proj_weight, out_proj_bias, norm_weight, norm_bias,
|
||||
config.training, config.norm_first)
|
||||
|
||||
if config.is_grad_enabled and config.training:
|
||||
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
|
||||
out_proj_weight, out_proj_bias, norm_weight, norm_bias)
|
||||
ctx.config = config
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
assert ctx.config.training
|
||||
|
||||
cuda_module = colossal_multihead_attention
|
||||
backward_func = (cuda_module.multihead_attention_bw_fp16
|
||||
if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32)
|
||||
|
||||
output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \
|
||||
out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors
|
||||
|
||||
grad_input = None
|
||||
grad_in_proj_weight = None
|
||||
grad_in_proj_bias = None
|
||||
grad_out_proj_weight = None
|
||||
grad_out_proj_bias = None
|
||||
grad_norm_weight = None
|
||||
grad_norm_bias = None
|
||||
|
||||
if ctx.config.fp16:
|
||||
grad_output = grad_output.to(torch.half)
|
||||
output = output.to(torch.half)
|
||||
input = input.to(torch.half)
|
||||
input_mask = input_mask.to(torch.half)
|
||||
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
|
||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
|
||||
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \
|
||||
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
|
||||
|
||||
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
|
||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Initialize the MultiHeadAttention.
|
||||
|
||||
Static variable:
|
||||
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
|
||||
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
|
||||
Arguments:
|
||||
hidden_size: Total dimension of the hidden states.
|
||||
nhead: Number of parallel attention heads.
|
||||
batch_size: Batch size for one forward pass
|
||||
max_seq_len: Max length of input sequence
|
||||
dropout: Dropout probability
|
||||
norm_first: perform LayerNorms before attention
|
||||
"""
|
||||
|
||||
layer_id = 0
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
nhead,
|
||||
batch_size,
|
||||
max_seq_len,
|
||||
dropout=0.0,
|
||||
norm_first=False,
|
||||
fp16=True,
|
||||
pg=None):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
|
||||
dropout, norm_first, fp16)
|
||||
check_config(self.config)
|
||||
self.pg = pg
|
||||
self.pg_size = 1
|
||||
if self.pg:
|
||||
self.pg_size = pg.size()
|
||||
self.config.layer_id = MultiHeadAttention.layer_id
|
||||
MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1
|
||||
|
||||
# Load cuda modules if needed
|
||||
global colossal_multihead_attention
|
||||
if colossal_multihead_attention is None:
|
||||
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
|
||||
|
||||
# create the layer in cuda kernels.
|
||||
cuda_module = colossal_multihead_attention
|
||||
create_layer_func = (cuda_module.create_multihead_attention_fp16
|
||||
if self.config.fp16 else cuda_module.create_multihead_attention_fp32)
|
||||
|
||||
create_layer_func(
|
||||
self.config.layer_id,
|
||||
self.config.max_batch_tokens,
|
||||
self.config.max_seq_len,
|
||||
self.config.hidden_size,
|
||||
self.config.nhead,
|
||||
self.config.attn_prob_dropout_ratio,
|
||||
self.config.hidden_dropout_ratio,
|
||||
self.config.norm_first,
|
||||
self.pg,
|
||||
)
|
||||
|
||||
hs = self.config.hidden_size
|
||||
|
||||
self.precision = torch.float32
|
||||
if self.config.fp16:
|
||||
self.precision = torch.half
|
||||
|
||||
self.hs_per_rank = int(hs / self.pg_size)
|
||||
|
||||
self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
|
||||
self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
|
||||
self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
|
||||
self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
|
||||
self.norm_weight = nn.Parameter(torch.Tensor(hs))
|
||||
self.norm_bias = nn.Parameter(torch.Tensor(hs))
|
||||
|
||||
self.reset_parameters()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def calc_bound(self, w):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
return bound
|
||||
|
||||
def reset_parameters(self):
|
||||
hs = self.config.hidden_size
|
||||
|
||||
nn.init.zeros_(self.out_proj_bias)
|
||||
|
||||
nn.init.ones_(self.norm_weight)
|
||||
nn.init.zeros_(self.norm_bias)
|
||||
|
||||
if self.pg_size > 1:
|
||||
rank_in_pg = torch.distributed.get_rank(self.pg)
|
||||
attn_qkvw_global = torch.empty(hs * 3, hs)
|
||||
attn_qkvb_global = torch.empty(hs * 3)
|
||||
nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
|
||||
bound = self.calc_bound(attn_qkvw_global)
|
||||
nn.init.uniform_(attn_qkvb_global, -bound, bound)
|
||||
|
||||
attn_qkvw_global = attn_qkvw_global.cuda()
|
||||
attn_qkvb_global = attn_qkvb_global.cuda()
|
||||
torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
|
||||
torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
|
||||
attn_qkvw_global = attn_qkvw_global.cpu()
|
||||
attn_qkvb_global = attn_qkvb_global.cpu()
|
||||
|
||||
with torch.no_grad():
|
||||
self.in_proj_weight.copy_(
|
||||
attn_qkvw_global.view(3, hs, hs)[:,
|
||||
int(hs * rank_in_pg /
|
||||
self.pg_size):int(hs * (rank_in_pg + 1) /
|
||||
self.pg_size), :])
|
||||
self.in_proj_bias.copy_(
|
||||
attn_qkvb_global.view(3, hs)[:,
|
||||
int(hs * rank_in_pg /
|
||||
self.pg_size):int(hs * (rank_in_pg + 1) /
|
||||
self.pg_size)])
|
||||
|
||||
attn_ow_global = torch.empty(hs, hs)
|
||||
nn.init.xavier_uniform_(attn_ow_global, 1.0)
|
||||
attn_ow_global = attn_ow_global.cuda()
|
||||
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
|
||||
attn_ow_global = attn_ow_global.cpu()
|
||||
with torch.no_grad():
|
||||
self.out_proj_weight.copy_(attn_ow_global[:,
|
||||
int(hs * rank_in_pg /
|
||||
self.pg_size):int(hs * (rank_in_pg + 1) /
|
||||
self.pg_size)])
|
||||
|
||||
else:
|
||||
attn_qkvw = self.in_proj_weight.view(-1, hs)
|
||||
nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
|
||||
bound = self.calc_bound(attn_qkvw)
|
||||
nn.init.uniform_(self.in_proj_bias, -bound, bound)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
destination = torch.nn.Module.state_dict(self,
|
||||
destination=destination,
|
||||
prefix=prefix,
|
||||
keep_vars=keep_vars)
|
||||
return destination
|
||||
|
||||
def forward(self, hidden_states, encoder_padding_mask):
|
||||
self.config.training = self.training
|
||||
self.config.is_grad_enabled = torch.is_grad_enabled()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous())
|
||||
|
||||
bs, sl, dim = hidden_states.size()
|
||||
if bs * sl > self.config.max_batch_tokens:
|
||||
raise ValueError(
|
||||
f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
|
||||
if sl > self.config.max_seq_len:
|
||||
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
|
||||
if len(encoder_padding_mask.size()) == 1:
|
||||
assert bs == 1 and sl == encoder_padding_mask.size(0)
|
||||
else:
|
||||
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
|
||||
|
||||
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
|
||||
self.in_proj_weight, self.in_proj_bias,
|
||||
self.out_proj_weight, self.out_proj_bias,
|
||||
self.norm_weight, self.norm_bias, self.config)
|
||||
|
||||
return output.to(self.precision)
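Putting the module above together: construct it once per layer (it registers a layer id with the CUDA extension), move parameters to the GPU, and feed [batch, seq, hidden] activations plus a padding mask that forward() scales by -1e8. The sketch below is illustrative only; it assumes the colossal_multihead_attention extension is built, that module and inputs are cast to fp16, and that a zero mask entry means keep while a nonzero entry means padded.

# Sketch, not part of the commit.
import torch

batch_size, seq_len, hidden_size, nhead = 4, 128, 1024, 16
attn = MultiHeadAttention(hidden_size, nhead, batch_size, seq_len,
                          dropout=0.1, norm_first=False, fp16=True).cuda().half()

hidden_states = torch.randn(batch_size, seq_len, hidden_size,
                            dtype=torch.half, device='cuda')
padding_mask = torch.zeros(batch_size, seq_len, device='cuda')   # assumed: 0 = keep, 1 = pad

out = attn(hidden_states, padding_mask)          # same shape as hidden_states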
@ -0,0 +1,184 @@
|
|||
"""This code from NVIDIA Megatron
|
||||
with some changes. """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import enum
|
||||
|
||||
|
||||
class AttnMaskType(enum.Enum):
|
||||
padding = 1
|
||||
causal = 2
|
||||
|
||||
|
||||
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
"""
|
||||
Fused operation which performs the following three operations in sequence:
|
||||
1. Scale the tensor.
|
||||
2. Apply upper triangular mask (typically used in gpt models).
|
||||
3. Perform softmax.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, scale):
|
||||
import colossal_scaled_upper_triang_masked_softmax
|
||||
|
||||
scale_t = torch.tensor([scale])
|
||||
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
|
||||
inputs, scale_t[0]
|
||||
)
|
||||
|
||||
ctx.save_for_backward(softmax_results, scale_t)
|
||||
return softmax_results
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
import colossal_scaled_upper_triang_masked_softmax
|
||||
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
|
||||
output_grads, softmax_results, scale_t[0]
|
||||
)
|
||||
|
||||
return input_grads, None
|
||||
|
||||
|
||||
class ScaledMaskedSoftmax(torch.autograd.Function):
|
||||
"""
|
||||
Fused operation which performs the following three operations in sequence:
|
||||
1. Scale the tensor.
|
||||
2. Apply the mask.
|
||||
3. Perform softmax.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, mask, scale):
|
||||
import colossal_scaled_masked_softmax
|
||||
|
||||
scale_t = torch.tensor([scale])
|
||||
|
||||
softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
|
||||
ctx.save_for_backward(softmax_results, scale_t)
|
||||
return softmax_results
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
import colossal_scaled_masked_softmax
|
||||
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
|
||||
input_grads = colossal_scaled_masked_softmax.backward(
|
||||
output_grads, softmax_results, scale_t[0]
|
||||
)
|
||||
return input_grads, None, None
|
||||
|
||||
|
||||
class FusedScaleMaskSoftmax(nn.Module):
|
||||
"""
|
||||
fused operation: scaling + mask + softmax
|
||||
|
||||
Arguments:
|
||||
input_in_fp16: flag to indicate if input is in fp16 data format.
input_in_bf16: flag to indicate if input is in bf16 data format.
|
||||
attn_mask_type: attention mask type (pad or causal)
|
||||
scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
|
||||
mask_func: mask function to be applied.
|
||||
softmax_in_fp32: if true, softmax is performed at fp32 precision.
|
||||
scale: scaling factor used in input tensor scaling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_in_fp16,
|
||||
input_in_bf16,
|
||||
attn_mask_type,
|
||||
scaled_masked_softmax_fusion,
|
||||
mask_func,
|
||||
softmax_in_fp32,
|
||||
scale,
|
||||
):
|
||||
super(FusedScaleMaskSoftmax, self).__init__()
|
||||
self.input_in_fp16 = input_in_fp16
|
||||
self.input_in_bf16 = input_in_bf16
|
||||
assert not (
|
||||
self.input_in_fp16 and self.input_in_bf16
|
||||
), "both fp16 and bf16 flags cannot be active at the same time."
|
||||
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
|
||||
self.attn_mask_type = attn_mask_type
|
||||
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
|
||||
self.mask_func = mask_func
|
||||
self.softmax_in_fp32 = softmax_in_fp32
|
||||
self.scale = scale
|
||||
|
||||
assert (
|
||||
self.scale is None or softmax_in_fp32
|
||||
), "softmax should be in fp32 when scaled"
|
||||
|
||||
def forward(self, input, mask):
|
||||
# [b, np, sq, sk]
|
||||
assert input.dim() == 4
|
||||
|
||||
if self.is_kernel_available(mask, *input.size()):
|
||||
return self.forward_fused_softmax(input, mask)
|
||||
else:
|
||||
return self.forward_torch_softmax(input, mask)
|
||||
|
||||
def is_kernel_available(self, mask, b, np, sq, sk):
|
||||
attn_batches = b * np
|
||||
|
||||
if (
|
||||
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
|
||||
):
|
||||
if 0 <= sk <= 2048:
|
||||
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
|
||||
|
||||
if self.attn_mask_type == AttnMaskType.causal:
|
||||
if attn_batches % batch_per_block == 0:
|
||||
return True
|
||||
else:
|
||||
if sq % batch_per_block == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def forward_fused_softmax(self, input, mask):
|
||||
b, np, sq, sk = input.size()
|
||||
scale = self.scale if self.scale is not None else 1.0
|
||||
|
||||
if self.attn_mask_type == AttnMaskType.causal:
|
||||
assert sq == sk, "causal mask is only for self attention"
|
||||
|
||||
# input is 3D tensor (attn_batches, sq, sk)
|
||||
input = input.view(-1, sq, sk)
|
||||
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
|
||||
return probs.view(b, np, sq, sk)
|
||||
else:
|
||||
# input is 4D tensor (b, np, sq, sk)
|
||||
return ScaledMaskedSoftmax.apply(input, mask, scale)
|
||||
|
||||
def forward_torch_softmax(self, input, mask):
|
||||
if self.input_in_float16 and self.softmax_in_fp32:
|
||||
input = input.float()
|
||||
|
||||
if self.scale is not None:
|
||||
input = input * self.scale
|
||||
mask_output = self.mask_func(input, mask) if mask is not None else input
|
||||
probs = torch.nn.Softmax(dim=-1)(mask_output)
|
||||
|
||||
if self.input_in_float16 and self.softmax_in_fp32:
|
||||
if self.input_in_fp16:
|
||||
probs = probs.half()
|
||||
else:
|
||||
probs = probs.bfloat16()
|
||||
|
||||
return probs
|
||||
|
||||
@staticmethod
|
||||
def get_batch_per_block(sq, sk, b, np):
|
||||
import colossal_scaled_masked_softmax
|
||||
|
||||
return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
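FusedScaleMaskSoftmax decides per call, via is_kernel_available(), whether the fused CUDA path applies and otherwise falls back to a plain PyTorch softmax. The construction sketch below is illustrative only; attention_mask_func is a hypothetical mask callable (not defined in this commit), the extensions built above must be available, and the fused path expects fp16/bf16 inputs on a CUDA device.

# Sketch, not part of the commit; attention_mask_func is hypothetical.
import torch

def attention_mask_func(attention_scores, attention_mask):
    # fill masked positions with a large negative value before softmax
    return attention_scores.masked_fill(attention_mask.bool(), -10000.0)

fused_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

b, heads, sq, sk = 4, 8, 64, 64
scores = torch.randn(b, heads, sq, sk, dtype=torch.half, device='cuda')
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device='cuda')
probs = fused_softmax(scores, mask)   # fused kernel when eligible, torch softmax otherwise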
@ -0,0 +1,3 @@
from .option import _set_jit_fusion_options

_set_jit_fusion_options()
@ -0,0 +1,24 @@
import torch


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
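The two scripted wrappers above exist so that the bias add, dropout, and residual add can be handled as one fused region by the TorchScript fuser, presumably because the training flag then becomes a constant the fuser can specialize on. A short usage sketch, not part of the commit:

# Sketch, not part of the commit.
import torch

x = torch.randn(8, 1024, device='cuda')
bias = torch.randn(1024, device='cuda')
residual = torch.randn(8, 1024, device='cuda')

out_train = bias_dropout_add_fused_train(x, bias, residual, 0.1)      # dropout active
out_eval = bias_dropout_add_fused_inference(x, bias, residual, 0.1)   # dropout disabled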
@ -0,0 +1,41 @@
import torch


###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff*g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
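bias_gelu_impl fuses the bias add with the tanh approximation of GELU and supplies a hand-written gradient. A quick forward-only comparison against the exact erf-based GELU quoted in the comments above (a sketch, not part of the commit):

# Sketch, not part of the commit: compare the fused tanh approximation
# against the exact erf-based GELU it approximates.
import torch

y = torch.randn(4, 1024)
bias = torch.randn(1024)

fused = bias_gelu_impl(y, bias)          # GeLUFunction.forward -> bias_gelu(bias, y)
x = y + bias
exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

print((fused - exact).abs().max())       # stays small, roughly 1e-3 or below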
@ -0,0 +1,28 @@
import torch

JIT_OPTIONS_SET = False

def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    global JIT_OPTIONS_SET
    if JIT_OPTIONS_SET == False:
        # flags required to enable jit fusion kernels
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
            # nvfuser
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_override_can_fuse_on_gpu(False)
            torch._C._jit_set_texpr_fuser_enabled(False)
            torch._C._jit_set_nvfuser_enabled(True)
            torch._C._debug_set_autodiff_subgraph_inlining(False)
        else:
            # legacy pytorch fuser
            torch._C._jit_set_profiling_mode(False)
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_override_can_fuse_on_cpu(True)
            torch._C._jit_override_can_fuse_on_gpu(True)

        JIT_OPTIONS_SET = True
setup.py
@ -131,5 +131,6 @@ setup(
    description='An integrated large-scale model training system with efficient parallelization techniques',
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
    package_data={'colossalai': ['kernel/cuda_native/csrc/*']},
    install_requires=install_requires,
)