diff --git a/extensions/csrc/common/cuda_type_utils.h b/extensions/csrc/common/cuda_type_utils.h
new file mode 100644
index 000000000..35d4c1492
--- /dev/null
+++ b/extensions/csrc/common/cuda_type_utils.h
@@ -0,0 +1,122 @@
+/*
+ * This code from NVIDIA FasterTransformer:
+ *     https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
+ */
+
+#pragma once
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+template <typename T>
+inline __device__ T add(T a, T b) {
+  return a + b;
+}
+
+template <>
+inline __device__ half2 add(half2 a, half2 b) {
+  return __hadd2(a, b);
+}
+
+template <>
+inline __device__ half add(half a, half b) {
+  return __hadd(a, b);
+}
+
+#if ENABLE_BF16
+template <>
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+  return bf16hadd2(a, b);
+}
+
+template <>
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+  return bf16hadd(a, b);
+}
+
+#endif  // ENABLE_BF16
+
+template <typename T>
+inline __device__ T mul(T a, T b, T c) {
+  return a * b * c;
+}
+
+template <>
+inline __device__ half2 mul(half2 a, half2 b, half2 c) {
+  return __hmul2(__hmul2(a, b), c);
+}
+
+#if ENABLE_BF16
+template <>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
+                                    __nv_bfloat16 c) {
+  return bf16hmul(a, b, c);
+}
+
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
+  return bf16hmul2(a, b, c);
+}
+#endif  // ENABLE_BF16
+
+template <typename T_OUT, typename T_IN>
+__device__ inline T_OUT cuda_cast(T_IN val) {
+  return val;
+}
+
+template <>
+__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
+  return make_float2(val.x, val.y);
+}
+template <>
+__device__ inline float2 cuda_cast<float2, float>(float val) {
+  return make_float2(val, val);
+}
+template <>
+__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
+  return __half22float2(val);
+}
+template <>
+__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
+  return __float22half2_rn(val);
+}
+template <>
+__device__ inline half2 cuda_cast<half2, float>(float val) {
+  return __float2half2_rn(val);
+}
+template <>
+__device__ inline half2 cuda_cast<half2, half>(half val) {
+  return __half2half2(val);
+}
+template <>
+__device__ inline float cuda_cast<float, half>(half val) {
+  return __half2float(val);
+}
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = at::Half;
+};
+
+template <>
+struct TypeConverter<at::Half> {
+  using Type = half2;
+};
+
+#if ENABLE_BF16
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = at::BFloat16;
+};
+
+template <>
+struct TypeConverter<at::BFloat16> {
+  using Type = __nv_bfloat162;
+};
+#endif  // ENABLE_BF16
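Illustrative sketch, not part of the diff: a toy kernel (hypothetical name scale_and_weight) showing how the helpers above are meant to compose. TypeConverter maps the ATen scalar type to its packed pair type, cuda_cast broadcasts a float into that packed type, and the three-operand mul does the packed math. The include path and the use of torch/extension.h for at::Half are assumptions.

#include <cuda_fp16.h>
#include <torch/extension.h>           // for at::Half (assumed to be available)
#include "common/cuda_type_utils.h"    // the header added above; path is illustrative

// Toy kernel: out[i] = in[i] * scale * w[i], processed two half values at a time.
__global__ void scale_and_weight(at::Half* __restrict__ out,
                                 const at::Half* __restrict__ in,
                                 const at::Half* __restrict__ w,
                                 float scale, int n) {  // n = number of scalars, assumed even
  using scalar2_t = TypeConverter<at::Half>::Type;      // at::Half -> half2
  scalar2_t* out2 = reinterpret_cast<scalar2_t*>(out);
  const scalar2_t* in2 = reinterpret_cast<const scalar2_t*>(in);
  const scalar2_t* w2 = reinterpret_cast<const scalar2_t*>(w);
  scalar2_t s2 = cuda_cast<scalar2_t>(scale);           // broadcast float -> half2
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n / 2;
       i += blockDim.x * gridDim.x) {
    out2[i] = mul(in2[i], s2, w2[i]);                   // packed a * b * c via __hmul2
  }
}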
diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu
index 0ab40f9f7..0e3e4e900 100644
--- a/extensions/csrc/cuda/rms_layernorm_kernel.cu
+++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -1,5 +1,5 @@
 /*This code from VLLM:
- * https://github.com/vllm-project/vllm/
+ * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
  * with minor changes.
  */
 #include <torch/extension.h>
@@ -10,8 +10,10 @@
 
 #include "block_reduce.h"
 #include "../common/micros.h"
+#include "../common/cuda_type_utils.h"
 
-template<typename scalar_t>
+// optimized for half and bf16
+template<typename scalar_t, int unroll_factor>
 __global__ void rms_layernorm_kernel(
   scalar_t* __restrict__ out,             // [..., hidden_size]
   const scalar_t* __restrict__ input,     // [..., hidden_size]
@@ -19,8 +21,9 @@ __global__ void rms_layernorm_kernel(
   const float epsilon,
   const int num_tokens,
   const int hidden_size) {
+  using scalar2_t = typename TypeConverter<scalar_t>::Type;
   __shared__ float s_variance;
-  float variance = 0.0f;
+
   /*
    * since the open-sourced LLM's hidden dimensions mainly range from
    * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
@@ -29,11 +32,22 @@
    * will cause problems for extremely large models, such as
    * Megatron-Turing NLG 530B with hidden dimensions up to 20480
    */
-  float x_local[8];
+  scalar2_t x_local[4];
 
-  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
-    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
-    variance += x_local[cnt] * x_local[cnt];
+  scalar2_t* out_ptr = (scalar2_t*)out;
+  const scalar2_t* input_ptr = (scalar2_t*)input;
+  const scalar2_t* weight_ptr = (const scalar2_t*)weight;
+
+  float variance = 0.0f;
+  int row_offset = blockIdx.x * hidden_size / 2;
+
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    x_local[cnt] = input_ptr[id];
+    float v1 = cuda_cast<float>(x_local[cnt].x);
+    float v2 = cuda_cast<float>(x_local[cnt].y);
+    variance += v1 * v1 + v2 * v2;
   }
   variance = blockReduceSum<float>(variance);
   if (threadIdx.x == 0) {
@@ -41,16 +55,19 @@ __global__ void rms_layernorm_kernel(
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
-    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
   }
 }
 
-template<typename scalar_t>
-__global__ void fused_add_rms_layernorm_kernel(
-  scalar_t* __restrict__ input,           // [..., hidden_size]
-  scalar_t* __restrict__ residual,        // [..., hidden_size]
-  const scalar_t* __restrict__ weight,    // [hidden_size]
+template<typename scalar_t, int unroll_factor>
+__global__ void rms_layernorm_kernel(
+  float* __restrict__ out,                // [..., hidden_size]
+  const float* __restrict__ input,        // [..., hidden_size]
+  const float* __restrict__ weight,       // [hidden_size]
   const float epsilon,
   const int num_tokens,
   const int hidden_size) {
@@ -58,11 +75,13 @@
   __shared__ float s_variance;
   float variance = 0.0f;
   float x_local[8];
+
+  int row_offset = blockIdx.x * hidden_size;
+
+#pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
-    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
-    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
+    int id = row_offset + idx;
+    x_local[cnt] = input[id];
     variance += x_local[cnt] * x_local[cnt];
-    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
   }
   variance = blockReduceSum<float>(variance);
   if (threadIdx.x == 0) {
@@ -70,8 +89,89 @@
   }
   __syncthreads();
 
+#pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
-    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+    int id = row_offset + idx;
+    out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+// optimized for half and bf16
+template<typename scalar_t, int unroll_factor>
+__global__ void fused_add_rms_layernorm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  __shared__ float s_variance;
+  scalar2_t x_local[4];
+
+  scalar2_t* input_ptr = (scalar2_t*)input;
+  scalar2_t* residual_ptr = (scalar2_t*)residual;
+  const scalar2_t* weight_ptr = (const scalar2_t*)weight;
+
+  float variance = 0.0f;
+  int row_offset = blockIdx.x * hidden_size / 2;
+
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    x_local[cnt] = input_ptr[id];
+    x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
+    float v1 = cuda_cast<float>(x_local[cnt].x);
+    float v2 = cuda_cast<float>(x_local[cnt].y);
+    variance += v1 * v1 + v2 * v2;
+    residual_ptr[id] = x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+  }
+}
+
+template<typename scalar_t, int unroll_factor>
+__global__ void fused_add_rms_layernorm_kernel(
+  float* __restrict__ input,              // [..., hidden_size]
+  float* __restrict__ residual,           // [..., hidden_size]
+  const float* __restrict__ weight,       // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  float x_local[8];
+
+  int row_offset = blockIdx.x * hidden_size;
+
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    x_local[cnt] = input[id];
+    x_local[cnt] += residual[id];
+    variance += x_local[cnt] * x_local[cnt];
+    residual[id] = x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+#pragma unroll unroll_factor
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    int id = row_offset + idx;
+    input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
   }
 }
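For reference, and not part of the diff: the semantics both the plain and the fused kernels above implement can be written as a few lines of host code (hypothetical helper rms_layernorm_ref), which is handy for spot-checking the vectorized paths. The fused variant additionally adds residual into each row first and writes that pre-normalization sum back into residual.

#include <cmath>

// Host reference: out[t][i] = in[t][i] * rsqrt(mean_i(in[t][i]^2) + eps) * w[i].
// Accumulates in double only to make the reference a stable comparison point.
void rms_layernorm_ref(float* out, const float* in, const float* w,
                       float eps, int num_tokens, int hidden_size) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* x = in + t * hidden_size;
    double sum_sq = 0.0;
    for (int i = 0; i < hidden_size; ++i) sum_sq += double(x[i]) * x[i];
    float inv_rms = 1.0f / std::sqrt(float(sum_sq / hidden_size) + eps);
    for (int i = 0; i < hidden_size; ++i) out[t * hidden_size + i] = x[i] * inv_rms * w[i];
  }
}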
@@ -88,16 +188,89 @@ void rms_layernorm(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-    input.scalar_type(),
-    "rms_layernorm_kernel",
-    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-      out.data_ptr<scalar_t>(),
-      input.data_ptr<scalar_t>(),
-      weight.data_ptr<scalar_t>(),
-      epsilon,
-      num_tokens,
-      hidden_size);)
+  if (num_tokens >= 512) {
+    if (input.scalar_type() == at::ScalarType::Float) {
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "rms_layernorm_kernel",
+        rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
+          out.data_ptr<scalar_t>(),
+          input.data_ptr<scalar_t>(),
+          weight.data_ptr<scalar_t>(),
+          epsilon,
+          num_tokens,
+          hidden_size);)
+    } else {
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "rms_layernorm_kernel",
+        rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+          out.data_ptr<scalar_t>(),
+          input.data_ptr<scalar_t>(),
+          weight.data_ptr<scalar_t>(),
+          epsilon,
+          num_tokens,
+          hidden_size);)
+    }
+  } else {
+    int unroll_factor = (hidden_size + block.x - 1) / block.x;
+    if (input.scalar_type() != at::ScalarType::Float) {
+      block.x = std::min(hidden_size / 2, 1024);
+      unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
+    }
+    switch (unroll_factor) {
+      case 1:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "rms_layernorm_kernel",
+          rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 2:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "rms_layernorm_kernel",
+          rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 4:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "rms_layernorm_kernel",
+          rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 8:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "rms_layernorm_kernel",
+          rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      default:
+        AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
+    }
+  }
 }
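Worked example for the unroll_factor selection used in both launch functions (the one above and the fused one below), assuming block.x was initialized to min(hidden_size, 1024) earlier in the launcher, which these hunks do not show: with hidden_size = 4096, the float path keeps block.x = 1024 and gets unroll_factor = (4096 + 1023) / 1024 = 4, so case 4 is taken; the half/bf16 path halves the element count, sets block.x = min(2048, 1024) = 1024, and gets unroll_factor = (2048 + 1023) / 1024 = 2, so case 2 is taken. The same arithmetic is what caps the supported hidden size at 8 * 1024 = 8192 scalars for the float kernels and 4 * 1024 packed pairs = 8192 scalars for the half/bf16 kernels.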
@@ -113,14 +286,87 @@ void fused_add_rms_layernorm(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-    input.scalar_type(),
-    "fused_add_rms_layernorm_kernel",
-    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-      input.data_ptr<scalar_t>(),
-      residual.data_ptr<scalar_t>(),
-      weight.data_ptr<scalar_t>(),
-      epsilon,
-      num_tokens,
-      hidden_size);)
+  if (num_tokens >= 512) {
+    if (input.scalar_type() == at::ScalarType::Float) {
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "fused_add_rms_layernorm_kernel",
+        fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
+          input.data_ptr<scalar_t>(),
+          residual.data_ptr<scalar_t>(),
+          weight.data_ptr<scalar_t>(),
+          epsilon,
+          num_tokens,
+          hidden_size);)
+    } else {
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "fused_add_rms_layernorm_kernel",
+        fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+          input.data_ptr<scalar_t>(),
+          residual.data_ptr<scalar_t>(),
+          weight.data_ptr<scalar_t>(),
+          epsilon,
+          num_tokens,
+          hidden_size);)
+    }
+  } else {
+    int unroll_factor = (hidden_size + block.x - 1) / block.x;
+    if (input.scalar_type() != at::ScalarType::Float) {
+      block.x = std::min(hidden_size / 2, 1024);
+      unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
+    }
+    switch (unroll_factor) {
+      case 1:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "fused_add_rms_layernorm_kernel",
+          fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
+            input.data_ptr<scalar_t>(),
+            residual.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 2:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "fused_add_rms_layernorm_kernel",
+          fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
+            input.data_ptr<scalar_t>(),
+            residual.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 4:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "fused_add_rms_layernorm_kernel",
+          fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+            input.data_ptr<scalar_t>(),
+            residual.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      case 8:
+        DISPATCH_FLOAT_HALF_AND_BFLOAT(
+          input.scalar_type(),
+          "fused_add_rms_layernorm_kernel",
+          fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
+            input.data_ptr<scalar_t>(),
+            residual.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            epsilon,
+            num_tokens,
+            hidden_size);)
+        break;
+      default:
+        AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
+    }
+  }
 }
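Finally, a host-side usage sketch, not part of the diff. The binding signatures are assumptions inferred from the launcher bodies (out/input/weight/epsilon and input/residual/weight/epsilon); shapes and dtype are only an example.

#include <torch/torch.h>

// Declarations as suggested by the launcher bodies above (argument order assumed).
void rms_layernorm(torch::Tensor& out, torch::Tensor& input,
                   torch::Tensor& weight, float epsilon);
void fused_add_rms_layernorm(torch::Tensor& input, torch::Tensor& residual,
                             torch::Tensor& weight, float epsilon);

void example() {
  const auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto input    = torch::randn({16, 4096}, opts);   // [num_tokens, hidden_size]
  auto residual = torch::randn({16, 4096}, opts);
  auto weight   = torch::ones({4096}, opts);
  auto out      = torch::empty_like(input);

  // 16 tokens < 512, half dtype -> packed kernel via the unroll_factor switch path.
  rms_layernorm(out, input, weight, 1e-6f);
  // Writes the normalized sum into input and the pre-normalization sum into residual.
  fused_add_rms_layernorm(input, residual, weight, 1e-6f);
}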