fix rmsnorm template function invocation problem (partial specialization of function templates is not allowed in C++) and pass e2e precision test (#5454)

Steve Luo 2024-03-13 16:00:55 +08:00 committed by GitHub
parent 6fd355a5a6
commit ed431de4e4
2 changed files with 79 additions and 35 deletions
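
For context: C++ does not allow partial specialization of function templates, so a float-only version of a `template<typename scalar_t, int unroll_factor>` kernel cannot be expressed as a specialization under the same name. This commit instead gives the float fallback a distinct `general_`-prefixed name and selects it at the call site via a new dispatch macro. A minimal sketch of the language rule and the workaround (plain C++, illustrative names only):

    // A primary function template:
    template <typename scalar_t, int unroll_factor>
    void rms_layernorm_kernel(scalar_t* out) {}

    // Ill-formed: function templates cannot be partially specialized.
    // template <int unroll_factor>
    // void rms_layernorm_kernel<float, unroll_factor>(float* out) {}

    // Workaround taken here: a separately named template for the float path,
    // chosen by the caller (in this commit, by a dispatch macro).
    template <typename scalar_t, int unroll_factor>
    void general_rms_layernorm_kernel(scalar_t* out) {}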


@@ -12,6 +12,34 @@
#include "../common/micros.h"
#include "../common/cuda_type_utils.h"
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
} else { \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
general_##__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
} \
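
The `general_##__VA_ARGS__` token paste is what reroutes the float case: the variadic arguments begin with the kernel name, so prefixing them with `general_` rewrites the launch to the separately named float kernel. A hedged sketch of the two expansions for one call site (simplified, not literal preprocessor output):

    // DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
    //     input.element_size(), input.scalar_type(), "rms_layernorm_kernel",
    //     rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(...);)
    //
    // element_size == 2, Half:
    //   using scalar_t = at::Half;
    //   rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(...);
    //
    // element_size == 4, Float (note the general_ paste):
    //   using scalar_t = float;
    //   general_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(...);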
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(
@@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel(
}
}
template<int unroll_factor>
__global__ void rms_layernorm_kernel(
float* __restrict__ out, // [..., hidden_size]
const float* __restrict__ input, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
template<typename scalar_t, int unroll_factor>
__global__ void general_rms_layernorm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
x_local[cnt] = (float) input[id];
variance += x_local[cnt] * x_local[cnt];
}
variance = blockReduceSum<float>(variance);
@@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
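
For reference, the kernel computes standard RMSNorm with fp32 accumulation regardless of the storage type: out = x * rsqrt(mean(x^2) + eps) * weight. A minimal single-row CPU sketch of the intended semantics (hypothetical helper, not part of the commit):

    #include <cmath>

    template <typename scalar_t>
    void rms_norm_row_ref(scalar_t* out, const scalar_t* input,
                          const scalar_t* weight, float epsilon,
                          int hidden_size) {
      float variance = 0.f;
      for (int i = 0; i < hidden_size; ++i) {
        float x = (float)input[i];  // accumulate in fp32, as the kernel does
        variance += x * x;
      }
      float s_variance = 1.f / std::sqrt(variance / (float)hidden_size + epsilon);
      for (int i = 0; i < hidden_size; ++i)
        out[i] = (scalar_t)((float)input[i] * s_variance) * weight[i];
    }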
@@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel(
}
}
template<int unroll_factor>
__global__ void fused_add_rms_layernorm_kernel(
float* __restrict__ input, // [..., hidden_size]
float* __restrict__ residual, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
template<typename scalar_t, int unroll_factor>
__global__ void general_fused_add_rms_layernorm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
x_local[cnt] += residual[id];
x_local[cnt] = (float) input[id];
x_local[cnt] += (float) residual[id];
variance += x_local[cnt] * x_local[cnt];
residual[id] = x_local[cnt];
residual[id] = (scalar_t) x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
@@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
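
The fused variant folds the residual add into the first pass: the fp32 sum is written back to residual before normalization, so the next layer reads the updated residual stream while input receives the normalized output in place. A sketch of the row-level semantics, under the same assumptions as the helper above:

    #include <cmath>
    #include <vector>

    template <typename scalar_t>
    void fused_add_rms_norm_row_ref(scalar_t* input, scalar_t* residual,
                                    const scalar_t* weight, float epsilon,
                                    int hidden_size) {
      std::vector<float> x(hidden_size);
      float variance = 0.f;
      for (int i = 0; i < hidden_size; ++i) {
        x[i] = (float)input[i] + (float)residual[i];  // add in fp32
        residual[i] = (scalar_t)x[i];                 // persist the sum
        variance += x[i] * x[i];
      }
      float s_variance = 1.f / std::sqrt(variance / (float)hidden_size + epsilon);
      for (int i = 0; i < hidden_size; ++i)
        input[i] = (scalar_t)(x[i] * s_variance) * weight[i];
    }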
@@ -190,7 +218,8 @@ void rms_layernorm(
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -201,7 +230,8 @@ void rms_layernorm(
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -216,11 +246,12 @@ void rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -232,7 +263,8 @@ void rms_layernorm(
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -244,7 +276,8 @@ void rms_layernorm(
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -256,7 +289,8 @@ void rms_layernorm(
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
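
A note on the host-side change a few lines above: the old code declared a second `int unroll_factor` inside the non-float branch, which shadowed the outer variable, so the `switch (unroll_factor)` below never saw the adjusted value. Dropping the `int` turns the declaration into an assignment to the outer variable. A minimal reproduction of the shadowing bug (illustrative values):

    #include <cstdio>

    int main() {
      int unroll_factor = 8;
      bool is_half_path = true;
      if (is_half_path) {
        int unroll_factor = 4;  // BUG: declares a new local, shadowing the outer one
        (void)unroll_factor;
      }
      printf("%d\n", unroll_factor);  // prints 8: the adjustment was lost
      return 0;
    }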
@@ -288,7 +322,8 @@ void fused_add_rms_layernorm(
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -299,7 +334,8 @@ void fused_add_rms_layernorm(
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -314,11 +350,12 @@ void fused_add_rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -330,7 +367,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -342,7 +380,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -354,7 +393,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(


@@ -22,11 +22,15 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
)
).cuda()
.cuda()
.half()
)
model = model.eval()
inputs = [
@@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)