mirror of https://github.com/hpcaitech/ColossalAI
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)
parent 2b28b54ac6
commit f7aecc0c6b
@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaRMSNorm,
 )

 from colossalai.inference.batch_bucket import BatchBucket
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
     decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
+    rms_layernorm,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
@@ -124,7 +126,7 @@ def llama_model_forward(
     hidden_states = hidden_states[last_token_indexs - 1].contiguous()
     residual = residual[last_token_indexs - 1].contiguous()
     norm_output = torch.empty_like(hidden_states)
-    hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

     return hidden_states
@@ -167,7 +169,7 @@ def llama_decoder_layer_forward(
         use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
     """

-    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -185,12 +187,32 @@ def llama_decoder_layer_forward(
     )

     # Fully Connected
-    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     hidden_states = self.mlp(hidden_states)

     return hidden_states, residual


+def llama_rmsnorm_forward(
+    self: LlamaRMSNorm,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
+
+
 class NopadLlamaAttention(LlamaAttention):
     def __init__(
         self,
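For reference, a minimal sketch of how the new forward behaves when called directly (names such as hidden, res, and buf are illustrative; the op signatures follow the benchmark and unit test added in this commit, and the import path matches the policy imports below):

    # Sketch only: exercises both branches of llama_rmsnorm_forward.
    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm
    from colossalai.inference.modeling.models.nopadding_llama import llama_rmsnorm_forward

    norm = LlamaRMSNorm(hidden_size=512, eps=1e-5).cuda().half()
    hidden = torch.randn(4, 512, dtype=torch.float16, device="cuda")
    buf = torch.empty_like(hidden)  # preallocated norm_output

    # No residual: the CUDA op writes the normalized result into buf and the
    # pre-norm hidden states are returned as the next residual.
    out, new_residual = llama_rmsnorm_forward(norm, hidden, buf, residual=None, use_cuda_kernel=True)

    # With residual: fused_add_rms_layernorm updates hidden and res in place.
    res = torch.randn_like(hidden)
    out, res = llama_rmsnorm_forward(norm, hidden, buf, residual=res, use_cuda_kernel=True)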
@@ -1,6 +1,5 @@
 from functools import partial

-import torch
 from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
+    llama_rmsnorm_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

-try:
-    from colossalai.kernel.triton import rms_layernorm
-
-    HAS_TRITON_RMSNORM = True
-except:
-    print("you should install triton from https://github.com/openai/triton")
-    HAS_TRITON_RMSNORM = False
-
-
-def get_triton_rmsnorm_forward():
-    if HAS_TRITON_RMSNORM:
-
-        def _triton_rmsnorm_forward(
-            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
-        ):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
-
-        return _triton_rmsnorm_forward
-    else:
-        return None
-
-
 class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
@@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )

-        infer_forward = None
-        if HAS_TRITON_RMSNORM:
-            infer_forward = get_triton_rmsnorm_forward()
-
-        if infer_forward is not None:
-            method_replacement = {"forward": partial(infer_forward)}
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
-            )
+        infer_forward = llama_rmsnorm_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)

         return policy
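The policy no longer needs the Triton-only fallback: llama_rmsnorm_forward is registered unconditionally, and the Triton path is still reachable through its use_cuda_kernel=False branch. As a rough illustration (not the actual shardformer mechanism), the method replacement amounts to rebinding the module's forward:

    # Illustration only: what the replacement effectively achieves on one module.
    from functools import partial
    from transformers.models.llama.modeling_llama import LlamaRMSNorm
    from colossalai.inference.modeling.models.nopadding_llama import llama_rmsnorm_forward

    norm = LlamaRMSNorm(hidden_size=512, eps=1e-5).cuda().half()
    norm.forward = partial(llama_rmsnorm_forward, norm)  # self is bound to the module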
@@ -1,14 +1,14 @@
 import torch
-import triton

+from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import rms_layernorm

 try:
     import triton  # noqa
 except ImportError:
     print("please install triton from https://github.com/openai/triton")

+inference_ops = InferenceOpsLoader().load()

 # Triton benchmark plot attributions
 configs = [
@@ -19,16 +19,20 @@ configs = [
         line_vals=[
             "vllm_rms_layernorm",
             "triton_rms_layernorm",
-            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm",
             "vllm_rms_layernorm_with_residual",
+            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm_with_residual",
         ],
         line_names=[
             "vllm_rms_layernorm",
             "triton_rms_layernorm",
-            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm",
             "vllm_rms_layernorm_with_residual",
+            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm_with_residual",
         ],
-        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
         args={"HIDDEN_SIZE": 1024},
@@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
         fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
+    elif provider == "cuda_rms_layernorm":
+        out = torch.empty_like(x)
+        fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
     elif provider == "vllm_rms_layernorm_with_residual":
         fn = lambda: vllm_norm(x, residual=residual)
     elif provider == "triton_rms_layernorm_with_residual":
        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
+    elif provider == "cuda_rms_layernorm_with_residual":
+        fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
     else:
         raise ValueError("Undefined provider.")
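A minimal standalone timing sketch, assuming triton.testing.do_bench as the timer (the harness above builds an fn closure per provider in the same way; shapes are illustrative):

    # Sketch: time the CUDA kernel against the Triton kernel for one shape.
    import torch
    import triton
    from colossalai.kernel.kernel_loader import InferenceOpsLoader
    from colossalai.kernel.triton import rms_layernorm

    inference_ops = InferenceOpsLoader().load()

    x = torch.randn(32, 1024, dtype=torch.float16, device="cuda")
    weight = torch.ones(1024, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)
    eps = 1e-5

    cuda_ms = triton.testing.do_bench(lambda: inference_ops.rms_layernorm(out, x, weight, eps))
    triton_ms = triton.testing.do_bench(lambda: rms_layernorm(x, weight, eps=eps))
    print(f"cuda: {cuda_ms:.3f} ms, triton: {triton_ms:.3f} ms")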
@@ -11,8 +11,25 @@ void decode_kv_cache_memcpy(

 torch::Tensor silu_and_mul(const torch::Tensor& ins);

+void rms_layernorm(torch::Tensor& out,    // [..., hidden_size]
+                   torch::Tensor& input,  // [..., hidden_size]
+                   torch::Tensor& weight, // [hidden_size]
+                   float epsilon);
+
+void fused_add_rms_layernorm(torch::Tensor& input,    // [..., hidden_size]
+                             torch::Tensor& residual, // [..., hidden_size]
+                             torch::Tensor& weight,   // [hidden_size]
+                             float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");

   m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");

+  m.def("rms_layernorm", &rms_layernorm,
+        "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+
+  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
+        "In-place fused Add and RMS Normalization.");
 }
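Once the extension is rebuilt, both bindings are reached from Python through the kernel loader rather than a direct import of the compiled module. A minimal sketch, with shapes illustrative and the call signatures as declared above:

    # Sketch: call the newly bound CUDA ops from Python.
    import torch
    from colossalai.kernel.kernel_loader import InferenceOpsLoader

    inference_ops = InferenceOpsLoader().load()

    x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.ones(512, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)

    # out <- RMSNorm(x) * weight
    inference_ops.rms_layernorm(out, x, weight, 1e-5)
    # in place: residual <- x + residual, then x <- RMSNorm(x + residual) * weight
    inference_ops.fused_add_rms_layernorm(x, residual, weight, 1e-5)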
@@ -0,0 +1,126 @@
+/*This code from VLLM:
+ *     https://github.com/vllm-project/vllm/
+ *     with minor changes. */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <stdio.h>
+
+#include "block_reduce.h"
+#include "type_shim.h"
+
+template<typename scalar_t>
+__global__ void rms_layernorm_kernel(
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /*
+   * since the open-sourced LLM's hidden dimensions mainly range from
+   * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
+   * hidden dimension limit to 8192, and each thread's capacity
+   * for caching input tensors to 8 (8192 = 8 * 1024) which
+   * will cause problems for extremely large models, such as
+   * Megatron-Turing NLG 530B with hidden dimensions up to 20480
+   */
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+template<typename scalar_t>
+__global__ void fused_add_rms_layernorm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+void rms_layernorm(
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "rms_layernorm_kernel",
+    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      out.data_ptr<scalar_t>(),
+      input.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
+
+void fused_add_rms_layernorm(
+  torch::Tensor& input,     // [..., hidden_size]
+  torch::Tensor& residual,  // [..., hidden_size]
+  torch::Tensor& weight,    // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "fused_add_rms_layernorm_kernel",
+    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      input.data_ptr<scalar_t>(),
+      residual.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
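The x_local[8] comment above encodes a hard cap: each block launches at most 1024 threads (block(std::min(hidden_size, 1024))) and each thread caches at most 8 elements, so the kernels cover hidden sizes up to 8 * 1024 = 8192. A small check of that arithmetic, assuming the launch configuration shown above:

    # Worked check of the per-thread cache limit in the kernels above.
    MAX_THREADS_PER_BLOCK = 1024
    CACHE_PER_THREAD = 8  # float x_local[8]

    for hidden_size in (4096, 8192, 20480):
        block = min(hidden_size, MAX_THREADS_PER_BLOCK)
        elems_per_thread = -(-hidden_size // block)  # ceil division
        ok = elems_per_thread <= CACHE_PER_THREAD
        print(hidden_size, elems_per_thread, "supported" if ok else "exceeds x_local[8]")
    # 4096 -> 4 supported, 8192 -> 8 supported, 20480 -> 20 exceeds x_local[8]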
@@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension):
                 "cuda/colossal_inference_C_frontend.cpp",
                 "cuda/decode_kv_cache_memcpy_kernel.cu",
                 "cuda/activation_kernel.cu",
+                "cuda/rms_layernorm_kernel.cu",
             ]
         ]
         return ret

     def include_dirs(self):
-        ret = [self.get_cuda_home_include()]
+        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
         return ret

     def cxx_flags(self):
@@ -22,15 +22,11 @@ def setup_seed(seed):
 def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = (
-        LlamaForCausalLM(
-            LlamaConfig(
-                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
-            )
-        )
-        .cuda()
-        .half()
-    )
+    model = LlamaForCausalLM(
+        LlamaConfig(
+            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+        )
+    ).cuda()
     model = model.eval()

     inputs = [
@@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     top_k = 50

     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
@@ -0,0 +1,51 @@
+import pytest
+import torch
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.utils import get_current_device
+
+inference_ops = InferenceOpsLoader().load()
+
+
+@pytest.mark.parametrize("M", [2, 4, 8, 16])
+@pytest.mark.parametrize("N", [64, 128, 512])
+def test_rms_layernorm(M: int, N: int):
+    torch.manual_seed(123)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+
+    device = get_current_device()
+
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.ones(w_shape, dtype=dtype, device=device)
+    residual = torch.rand(x_shape, dtype=dtype, device=device)
+    residual_copy = residual.clone()
+    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    x_copy = x.clone()
+
+    y_cuda = torch.empty_like(x)
+    inference_ops.rms_layernorm(y_cuda, x, weight, eps)
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_cuda.shape == y_llama.shape
+    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
+
+    inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
+    y_cuda = x
+
+    x = x_copy + residual_copy
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_cuda.shape == y_llama.shape
+    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_rms_layernorm(16, 512)