optimize rmsnorm: add vectorized elementwise op, feat loop unrolling (#5441)

pull/5451/head
Steve Luo 2024-03-12 17:48:02 +08:00 committed by GitHub
parent 368a2aa543
commit b699f54007
2 changed files with 406 additions and 38 deletions


@@ -0,0 +1,122 @@
/*
* This code from NVIDIA FasterTransformer:
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
*/
#pragma once

#include <ATen/ATen.h>  // at::Half / at::BFloat16 used in TypeConverter below
#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>  // __nv_bfloat16 / __nv_bfloat162
#endif
template <typename T>
inline __device__ T add(T a, T b) {
return a + b;
}
template <>
inline __device__ half2 add(half2 a, half2 b) {
return __hadd2(a, b);
}
template <>
inline __device__ half add(half a, half b) {
return __hadd(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
return bf16hadd2(a, b);
}
template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return bf16hadd(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T mul(T a, T b, T c) {
return a * b * c;
}
template <>
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
return __hmul2(__hmul2(a, b), c);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
__nv_bfloat16 c) {
return bf16hmul(a, b, c);
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val) {
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val) {
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val) {
return __half2half2(val);
}
template <>
__device__ inline float cuda_cast<float, half>(half val) {
return __half2float(val);
}
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = at::Half;
};
template <>
struct TypeConverter<at::Half> {
using Type = half2;
};
#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = at::BFloat16;
};
template <>
struct TypeConverter<at::BFloat16> {
using Type = __nv_bfloat162;
};
#endif // ENABLE_BF16
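
The helpers above are what the rewritten kernels below build on: TypeConverter maps a scalar type to its packed two-lane type (for example at::Half to half2), cuda_cast converts between float and the packed types, and add/mul operate on both lanes at once. A minimal, hypothetical usage sketch follows (not part of this commit; the kernel name and parameters are made up, and it assumes this header plus <cuda_fp16.h> are included and that n is even):

// Hypothetical illustration: y = x * s * w on a half buffer, two elements per
// thread iteration, using the packed helpers defined above.
__global__ void scale_mul_half_kernel(half* __restrict__ out,
                                      const half* __restrict__ in,
                                      const half* __restrict__ w,
                                      const float s,
                                      const int n) {  // n = number of half elements, assumed even
  half2* out2 = reinterpret_cast<half2*>(out);
  const half2* in2 = reinterpret_cast<const half2*>(in);
  const half2* w2 = reinterpret_cast<const half2*>(w);
  const half2 s2 = cuda_cast<half2>(s);  // broadcast the float scale to both lanes
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n / 2;
       i += gridDim.x * blockDim.x) {
    out2[i] = mul(in2[i], s2, w2[i]);  // (x * s) * w for two elements at a time
  }
}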


@@ -1,5 +1,5 @@
/*This code from VLLM:
 * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
 * with minor changes. */

#include <ATen/cuda/CUDAContext.h>
@@ -10,8 +10,10 @@
#include "block_reduce.h"
#include "../common/micros.h"
#include "../common/cuda_type_utils.h"

// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
@@ -19,8 +21,9 @@ __global__ void rms_layernorm_kernel(
const float epsilon,
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
__shared__ float s_variance;
/*
* since the open-sourced LLM's hidden dimensions mainly range from
* 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
@@ -29,11 +32,22 @@
* will cause problems for extremely large models, such as
* Megatron-Turing NLG 530B with hidden dimensions up to 20480
*/
scalar2_t x_local[4];
scalar2_t* out_ptr = (scalar2_t*)out;
const scalar2_t* input_ptr = (scalar2_t*)input;
const scalar2_t* weight_ptr = (const scalar2_t*)weight;

float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;

#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
float v1 = cuda_cast<float>(x_local[cnt].x);
float v2 = cuda_cast<float>(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
@@ -41,16 +55,19 @@
}
__syncthreads();

scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
}
}

template<int unroll_factor>
__global__ void rms_layernorm_kernel(
float* __restrict__ out, // [..., hidden_size]
const float* __restrict__ input, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -58,11 +75,13 @@ __global__ void fused_add_rms_layernorm_kernel(
float variance = 0.0f;
float x_local[8];
int row_offset = blockIdx.x * hidden_size;

#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
variance += x_local[cnt] * x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
@@ -70,8 +89,89 @@
}
__syncthreads();

#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
}
}
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void fused_add_rms_layernorm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
__shared__ float s_variance;
scalar2_t x_local[4];
scalar2_t* input_ptr = (scalar2_t*)input;
scalar2_t* residual_ptr = (scalar2_t*)residual;
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
float v1 = cuda_cast<float>(x_local[cnt].x);
float v2 = cuda_cast<float>(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
residual_ptr[id] = x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
}
}
template<int unroll_factor>
__global__ void fused_add_rms_layernorm_kernel(
float* __restrict__ input, // [..., hidden_size]
float* __restrict__ residual, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
float x_local[8];
int row_offset = blockIdx.x * hidden_size;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
x_local[cnt] += residual[id];
variance += x_local[cnt] * x_local[cnt];
residual[id] = x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
}
}
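
To make the per-thread cache sizing in the kernels above concrete: with the 1024-thread upper bound on block size used here, a hidden dimension of 8192 means 8 floats per thread in the float kernels (hence x_local[8]) and 4 packed half2/bfloat162 values per thread in the half/bf16 kernels (hence x_local[4], still 8 scalars). A small worked sketch with assumed values:

// Worked example (assumed values) of the register-cache sizing used above.
constexpr int kMaxBlock = 1024;
constexpr int kHiddenSize = 8192;  // e.g. LLAMA-65B
constexpr int kFloatsPerThread = kHiddenSize / kMaxBlock;       // 8192 / 1024 = 8
constexpr int kPairsPerThread = (kHiddenSize / 2) / kMaxBlock;  // 4096 / 1024 = 4
static_assert(kFloatsPerThread <= 8, "fits x_local[8] in the float kernels");
static_assert(kPairsPerThread <= 4, "fits x_local[4] in the half/bf16 kernels");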
@@ -88,16 +188,89 @@ void rms_layernorm(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
}
} else {
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
}
}
}

void fused_add_rms_layernorm(
@@ -113,14 +286,87 @@ void fused_add_rms_layernorm(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
}
} else {
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
}
}
}
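
For readability, the launch-configuration heuristic shared by both dispatch functions above can be summarized as a standalone sketch. The helper below is hypothetical (not part of the commit), and it assumes block.x starts at min(hidden_size, 1024) as set earlier in these functions, outside the hunks shown:

#include <algorithm>

// Hypothetical summary of the dispatch heuristic in rms_layernorm and
// fused_add_rms_layernorm above; returns the block size and the unroll_factor
// template argument picked by the switch (1, 2, 4 or 8).
struct RmsNormLaunch {
  int block_x;
  int unroll_factor;
};

inline RmsNormLaunch pick_rmsnorm_launch(int hidden_size, int num_tokens, bool is_float) {
  if (num_tokens >= 512) {
    // Large batches: hidden_size / 8 threads per block, unroll 8 for float
    // and 4 for the packed half/bf16 kernels.
    return {hidden_size / 8, is_float ? 8 : 4};
  }
  int block_x = std::min(hidden_size, 1024);
  int unroll = (hidden_size + block_x - 1) / block_x;
  if (!is_float) {
    // The half/bf16 kernels process two elements per thread iteration.
    block_x = std::min(hidden_size / 2, 1024);
    unroll = (hidden_size / 2 + block_x - 1) / block_x;
  }
  return {block_x, unroll};
}

// Example: hidden_size = 4096, num_tokens = 64, fp16 input
//   -> block_x = min(2048, 1024) = 1024, unroll_factor = ceil(2048 / 1024) = 2,
//      i.e. rms_layernorm_kernel<scalar_t, 2> with 1024 threads per block.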