mirror of https://github.com/hpcaitech/ColossalAI
optimize rmsnorm: add vectorized elementwise op, feat loop unrolling (#5441)
parent
368a2aa543
commit
b699f54007
|
@ -0,0 +1,122 @@
|
||||||
|
/*
|
||||||
|
* This code from NVIDIA FasterTransformer:
|
||||||
|
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T add(T a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half2 add(half2 a, half2 b) {
|
||||||
|
return __hadd2(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half add(half a, half b) {
|
||||||
|
return __hadd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||||
|
return bf16hadd2(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||||
|
return bf16hadd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // ENABLE_BF16
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T mul(T a, T b, T c) {
|
||||||
|
return a * b * c;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
|
||||||
|
return __hmul2(__hmul2(a, b), c);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
|
||||||
|
__nv_bfloat16 c) {
|
||||||
|
return bf16hmul(a, b, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||||
|
__nv_bfloat162 c) {
|
||||||
|
return bf16hmul2(a, b, c);
|
||||||
|
}
|
||||||
|
#endif // ENABLE_BF16
|
||||||
|
|
||||||
|
template <typename T_OUT, typename T_IN>
|
||||||
|
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
|
||||||
|
return make_float2(val.x, val.y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, float>(float val) {
|
||||||
|
return make_float2(val, val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
|
||||||
|
return __half22float2(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
|
||||||
|
return __float22half2_rn(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, float>(float val) {
|
||||||
|
return __float2half2_rn(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, half>(half val) {
|
||||||
|
return __half2half2(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float cuda_cast<float, half>(half val) {
|
||||||
|
return __half2float(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||||
|
template <typename T>
|
||||||
|
struct TypeConverter {
|
||||||
|
using Type = half2;
|
||||||
|
}; // keep for generality
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<half2> {
|
||||||
|
using Type = at::Half;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<at::Half> {
|
||||||
|
using Type = half2;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<__nv_bfloat162> {
|
||||||
|
using Type = at::BFloat16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<at::BFloat16> {
|
||||||
|
using Type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
#endif // ENABLE_BF16
|
|
@ -1,5 +1,5 @@
|
||||||
/*This code from VLLM:
|
/*This code from VLLM:
|
||||||
* https://github.com/vllm-project/vllm/
|
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
|
||||||
* with minor changes. */
|
* with minor changes. */
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
@ -10,8 +10,10 @@
|
||||||
|
|
||||||
#include "block_reduce.h"
|
#include "block_reduce.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
#include "../common/cuda_type_utils.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
// optimized for half and bf16
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
__global__ void rms_layernorm_kernel(
|
__global__ void rms_layernorm_kernel(
|
||||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
@ -19,8 +21,9 @@ __global__ void rms_layernorm_kernel(
|
||||||
const float epsilon,
|
const float epsilon,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
|
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||||
__shared__ float s_variance;
|
__shared__ float s_variance;
|
||||||
float variance = 0.0f;
|
|
||||||
/*
|
/*
|
||||||
* since the open-sourced LLM's hidden dimensions mainly range from
|
* since the open-sourced LLM's hidden dimensions mainly range from
|
||||||
* 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
|
* 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
|
||||||
|
@ -29,11 +32,22 @@ __global__ void rms_layernorm_kernel(
|
||||||
* will cause problems for extremely large models, such as
|
* will cause problems for extremely large models, such as
|
||||||
* Megatron-Turing NLG 530B with hidden dimensions up to 20480
|
* Megatron-Turing NLG 530B with hidden dimensions up to 20480
|
||||||
*/
|
*/
|
||||||
float x_local[8];
|
scalar2_t x_local[4];
|
||||||
|
|
||||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
scalar2_t* out_ptr = (scalar2_t*)out;
|
||||||
x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
|
const scalar2_t* input_ptr = (scalar2_t*)input;
|
||||||
variance += x_local[cnt] * x_local[cnt];
|
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
|
||||||
|
|
||||||
|
float variance = 0.0f;
|
||||||
|
int row_offset = blockIdx.x * hidden_size / 2;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = input_ptr[id];
|
||||||
|
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||||
|
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||||
|
variance += v1 * v1 + v2 * v2;
|
||||||
}
|
}
|
||||||
variance = blockReduceSum<float>(variance);
|
variance = blockReduceSum<float>(variance);
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
|
@ -41,16 +55,19 @@ __global__ void rms_layernorm_kernel(
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<int unroll_factor>
|
||||||
__global__ void fused_add_rms_layernorm_kernel(
|
__global__ void rms_layernorm_kernel(
|
||||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
float* __restrict__ out, // [..., hidden_size]
|
||||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
const float* __restrict__ input, // [..., hidden_size]
|
||||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
const float* __restrict__ weight, // [hidden_size]
|
||||||
const float epsilon,
|
const float epsilon,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
|
@ -58,11 +75,13 @@ __global__ void fused_add_rms_layernorm_kernel(
|
||||||
float variance = 0.0f;
|
float variance = 0.0f;
|
||||||
float x_local[8];
|
float x_local[8];
|
||||||
|
|
||||||
|
int row_offset = blockIdx.x * hidden_size;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
|
int id = row_offset + idx;
|
||||||
x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
|
x_local[cnt] = input[id];
|
||||||
variance += x_local[cnt] * x_local[cnt];
|
variance += x_local[cnt] * x_local[cnt];
|
||||||
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
|
|
||||||
}
|
}
|
||||||
variance = blockReduceSum<float>(variance);
|
variance = blockReduceSum<float>(variance);
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
|
@ -70,8 +89,89 @@ __global__ void fused_add_rms_layernorm_kernel(
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
int id = row_offset + idx;
|
||||||
|
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// optimized for half and bf16
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
|
__global__ void fused_add_rms_layernorm_kernel(
|
||||||
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||||
|
__shared__ float s_variance;
|
||||||
|
scalar2_t x_local[4];
|
||||||
|
|
||||||
|
scalar2_t* input_ptr = (scalar2_t*)input;
|
||||||
|
scalar2_t* residual_ptr = (scalar2_t*)residual;
|
||||||
|
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
|
||||||
|
|
||||||
|
float variance = 0.0f;
|
||||||
|
int row_offset = blockIdx.x * hidden_size / 2;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = input_ptr[id];
|
||||||
|
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
|
||||||
|
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||||
|
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||||
|
variance += v1 * v1 + v2 * v2;
|
||||||
|
residual_ptr[id] = x_local[cnt];
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int unroll_factor>
|
||||||
|
__global__ void fused_add_rms_layernorm_kernel(
|
||||||
|
float* __restrict__ input, // [..., hidden_size]
|
||||||
|
float* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const float* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
__shared__ float s_variance;
|
||||||
|
float variance = 0.0f;
|
||||||
|
float x_local[8];
|
||||||
|
|
||||||
|
int row_offset = blockIdx.x * hidden_size;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = input[id];
|
||||||
|
x_local[cnt] += residual[id];
|
||||||
|
variance += x_local[cnt] * x_local[cnt];
|
||||||
|
residual[id] = x_local[cnt];
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,16 +188,89 @@ void rms_layernorm(
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
if (num_tokens >= 512) {
|
||||||
input.scalar_type(),
|
if (input.scalar_type() == at::ScalarType::Float) {
|
||||||
"rms_layernorm_kernel",
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
input.scalar_type(),
|
||||||
out.data_ptr<scalar_t>(),
|
"rms_layernorm_kernel",
|
||||||
input.data_ptr<scalar_t>(),
|
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
weight.data_ptr<scalar_t>(),
|
out.data_ptr<scalar_t>(),
|
||||||
epsilon,
|
input.data_ptr<scalar_t>(),
|
||||||
num_tokens,
|
weight.data_ptr<scalar_t>(),
|
||||||
hidden_size);)
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
} else {
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||||
|
if (input.scalar_type() != at::ScalarType::Float) {
|
||||||
|
block.x = std::min(hidden_size / 2, 1024);
|
||||||
|
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||||
|
}
|
||||||
|
switch (unroll_factor) {
|
||||||
|
case 1:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void fused_add_rms_layernorm(
|
void fused_add_rms_layernorm(
|
||||||
|
@ -113,14 +286,87 @@ void fused_add_rms_layernorm(
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
if (num_tokens >= 512) {
|
||||||
input.scalar_type(),
|
if (input.scalar_type() == at::ScalarType::Float) {
|
||||||
"fused_add_rms_layernorm_kernel",
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
input.scalar_type(),
|
||||||
input.data_ptr<scalar_t>(),
|
"fused_add_rms_layernorm_kernel",
|
||||||
residual.data_ptr<scalar_t>(),
|
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
weight.data_ptr<scalar_t>(),
|
input.data_ptr<scalar_t>(),
|
||||||
epsilon,
|
residual.data_ptr<scalar_t>(),
|
||||||
num_tokens,
|
weight.data_ptr<scalar_t>(),
|
||||||
hidden_size);)
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
} else {
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||||
|
if (input.scalar_type() != at::ScalarType::Float) {
|
||||||
|
block.x = std::min(hidden_size / 2, 1024);
|
||||||
|
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||||
|
}
|
||||||
|
switch (unroll_factor) {
|
||||||
|
case 1:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue