mirror of https://github.com/hpcaitech/ColossalAI
Fix the rmsnorm template function invocation problem (partial specialization of function templates is not allowed in C++) and, luckily, pass the e2e precision test (#5454)
parent
6fd355a5a6
commit
ed431de4e4
|
@ -12,6 +12,34 @@
|
|||
#include "../common/micros.h"
|
||||
#include "../common/cuda_type_utils.h"
|
||||
|
||||
// Dispatches an RMSNorm kernel launch on the tensor's element size and dtype.
//
//   DATA_SIZE  element size in bytes (call sites pass input.element_size()).
//   TYPE       at::ScalarType of the tensor (call sites pass
//              input.scalar_type()).
//   NAME       label used in the error message. It is stringified with #NAME,
//              so a string-literal argument keeps its quotes in the message.
//   ...        the launch expression. It may refer to `scalar_t`, which is
//              aliased to the selected element type right before the
//              expression is expanded.
//
// Behaviour:
//   * DATA_SIZE == 2: at::Half and at::BFloat16 expand the expression as
//     written.
//   * otherwise: only at::ScalarType::Float is accepted, and `general_` is
//     token-pasted onto the FIRST token of the expression, so
//     `foo_kernel<scalar_t, 8><<<...>>>(...)` becomes
//     `general_foo_kernel<scalar_t, 8><<<...>>>(...)`. The expression must
//     therefore start with a plain identifier (no `::` or namespace prefix).
//   * any other dtype is reported through AT_ERROR.
//
// DATA_SIZE is parenthesised so that compound arguments (e.g. `c ? 4 : 2`)
// are compared as a whole, and the final line deliberately carries no
// line-continuation backslash so the macro cannot absorb the next source line.
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...)  \
  if ((DATA_SIZE) == 2) {                                                   \
    switch (TYPE) {                                                         \
      case at::ScalarType::Half: {                                          \
        using scalar_t = at::Half;                                          \
        __VA_ARGS__;                                                        \
        break;                                                              \
      }                                                                     \
      case at::ScalarType::BFloat16: {                                      \
        using scalar_t = at::BFloat16;                                      \
        __VA_ARGS__;                                                        \
        break;                                                              \
      }                                                                     \
      default:                                                              \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");     \
    }                                                                       \
  } else {                                                                  \
    switch (TYPE) {                                                         \
      case at::ScalarType::Float: {                                         \
        using scalar_t = float;                                             \
        general_##__VA_ARGS__;                                              \
        break;                                                              \
      }                                                                     \
      default:                                                              \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");     \
    }                                                                       \
  }
|
||||
|
||||
// optimized for half and bf16
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void rms_layernorm_kernel(
|
||||
|
@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel(
|
|||
}
|
||||
}
|
||||
|
||||
template<int unroll_factor>
|
||||
__global__ void rms_layernorm_kernel(
|
||||
float* __restrict__ out, // [..., hidden_size]
|
||||
const float* __restrict__ input, // [..., hidden_size]
|
||||
const float* __restrict__ weight, // [hidden_size]
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void general_rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
|
@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel(
|
|||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input[id];
|
||||
x_local[cnt] = (float) input[id];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
|
@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel(
|
|||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
}
|
||||
}
|
||||
|
||||
template<int unroll_factor>
|
||||
__global__ void fused_add_rms_layernorm_kernel(
|
||||
float* __restrict__ input, // [..., hidden_size]
|
||||
float* __restrict__ residual, // [..., hidden_size]
|
||||
const float* __restrict__ weight, // [hidden_size]
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void general_fused_add_rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
|
@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input[id];
|
||||
x_local[cnt] += residual[id];
|
||||
x_local[cnt] = (float) input[id];
|
||||
x_local[cnt] += (float) residual[id];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
residual[id] = x_local[cnt];
|
||||
residual[id] = (scalar_t) x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
|
@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -190,7 +218,8 @@ void rms_layernorm(
|
|||
|
||||
if (num_tokens >= 512) {
|
||||
if (input.scalar_type() == at::ScalarType::Float) {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
|
@ -201,7 +230,8 @@ void rms_layernorm(
|
|||
num_tokens,
|
||||
hidden_size);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
|
@ -216,11 +246,12 @@ void rms_layernorm(
|
|||
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||
if (input.scalar_type() != at::ScalarType::Float) {
|
||||
block.x = std::min(hidden_size / 2, 1024);
|
||||
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
}
|
||||
switch (unroll_factor) {
|
||||
case 1:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
|
@ -232,7 +263,8 @@ void rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 2:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
|
@ -244,7 +276,8 @@ void rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 4:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
|
@ -256,7 +289,8 @@ void rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 8:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||
|
@ -288,7 +322,8 @@ void fused_add_rms_layernorm(
|
|||
|
||||
if (num_tokens >= 512) {
|
||||
if (input.scalar_type() == at::ScalarType::Float) {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
|
@ -299,7 +334,8 @@ void fused_add_rms_layernorm(
|
|||
num_tokens,
|
||||
hidden_size);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
|
@ -314,11 +350,12 @@ void fused_add_rms_layernorm(
|
|||
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||
if (input.scalar_type() != at::ScalarType::Float) {
|
||||
block.x = std::min(hidden_size / 2, 1024);
|
||||
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
}
|
||||
switch (unroll_factor) {
|
||||
case 1:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
|
@ -330,7 +367,8 @@ void fused_add_rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 2:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
|
@ -342,7 +380,8 @@ void fused_add_rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 4:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
|
@ -354,7 +393,8 @@ void fused_add_rms_layernorm(
|
|||
hidden_size);)
|
||||
break;
|
||||
case 8:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||
|
|
|
@ -22,11 +22,15 @@ def setup_seed(seed):
|
|||
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
model = (
|
||||
LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
)
|
||||
)
|
||||
).cuda()
|
||||
.cuda()
|
||||
.half()
|
||||
)
|
||||
model = model.eval()
|
||||
|
||||
inputs = [
|
||||
|
@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||
top_k = 50
|
||||
|
||||
if use_engine:
|
||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
|
||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
|
|
Loading…
Reference in New Issue