mirror of https://github.com/hpcaitech/ColossalAI
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)
parent
2b28b54ac6
commit
f7aecc0c6b
|
@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
|
|||
LlamaForCausalLM,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
|
@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
|
|||
decoding_fused_rotary_embedding,
|
||||
flash_decoding_attention,
|
||||
get_xine_cache,
|
||||
rms_layernorm,
|
||||
rotary_embedding,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -124,7 +126,7 @@ def llama_model_forward(
|
|||
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
||||
residual = residual[last_token_indexs - 1].contiguous()
|
||||
norm_output = torch.empty_like(hidden_states)
|
||||
hidden_states, _ = self.norm(hidden_states, norm_output, residual)
|
||||
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
@ -167,7 +169,7 @@ def llama_decoder_layer_forward(
|
|||
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||
"""
|
||||
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
|
@ -185,12 +187,32 @@ def llama_decoder_layer_forward(
|
|||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
def llama_rmsnorm_forward(
|
||||
self: LlamaRMSNorm,
|
||||
hidden_states: torch.Tensor,
|
||||
norm_output: torch.Tensor,
|
||||
residual: torch.Tensor = None,
|
||||
use_cuda_kernel: bool = True,
|
||||
):
|
||||
if use_cuda_kernel:
|
||||
if residual is not None:
|
||||
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
|
||||
return hidden_states, residual
|
||||
|
||||
if norm_output is None:
|
||||
norm_output = torch.empty_like(hidden_states)
|
||||
inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
|
||||
return norm_output, hidden_states
|
||||
else:
|
||||
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
|
||||
|
||||
|
||||
class NopadLlamaAttention(LlamaAttention):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
||||
|
||||
|
@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
|
|||
llama_causal_lm_forward,
|
||||
llama_decoder_layer_forward,
|
||||
llama_model_forward,
|
||||
llama_rmsnorm_forward,
|
||||
)
|
||||
from colossalai.inference.utils import init_to_get_rotary
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
|
@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
|
|||
# import colossalai
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
|
||||
HAS_TRITON_RMSNORM = True
|
||||
except:
|
||||
print("you should install triton from https://github.com/openai/triton")
|
||||
HAS_TRITON_RMSNORM = False
|
||||
|
||||
|
||||
def get_triton_rmsnorm_forward():
|
||||
if HAS_TRITON_RMSNORM:
|
||||
|
||||
def _triton_rmsnorm_forward(
|
||||
self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
|
||||
):
|
||||
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
|
||||
|
||||
return _triton_rmsnorm_forward
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
def __init__(self) -> None:
|
||||
|
@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|||
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
||||
)
|
||||
|
||||
infer_forward = None
|
||||
if HAS_TRITON_RMSNORM:
|
||||
infer_forward = get_triton_rmsnorm_forward()
|
||||
|
||||
if infer_forward is not None:
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
|
||||
)
|
||||
infer_forward = llama_rmsnorm_forward
|
||||
method_replacement = {"forward": partial(infer_forward)}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import torch
|
||||
import triton
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
||||
except ImportError:
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
# Triton benchmark plot attributions
|
||||
configs = [
|
||||
|
@ -19,16 +19,20 @@ configs = [
|
|||
line_vals=[
|
||||
"vllm_rms_layernorm",
|
||||
"triton_rms_layernorm",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm",
|
||||
"vllm_rms_layernorm_with_residual",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm_with_residual",
|
||||
],
|
||||
line_names=[
|
||||
"vllm_rms_layernorm",
|
||||
"triton_rms_layernorm",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm",
|
||||
"vllm_rms_layernorm_with_residual",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm_with_residual",
|
||||
],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
|
||||
ylabel="ms",
|
||||
plot_name=f"RMSNorm benchmarking results",
|
||||
args={"HIDDEN_SIZE": 1024},
|
||||
|
@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
|
|||
fn = lambda: vllm_norm(x)
|
||||
elif provider == "triton_rms_layernorm":
|
||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||
elif provider == "cuda_rms_layernorm":
|
||||
out = torch.empty_like(x)
|
||||
fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
|
||||
elif provider == "vllm_rms_layernorm_with_residual":
|
||||
fn = lambda: vllm_norm(x, residual=residual)
|
||||
elif provider == "triton_rms_layernorm_with_residual":
|
||||
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
||||
elif provider == "cuda_rms_layernorm_with_residual":
|
||||
fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
|
||||
else:
|
||||
raise ValueError("Undefined provider.")
|
||||
|
|
@ -11,8 +11,25 @@ void decode_kv_cache_memcpy(
|
|||
|
||||
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||
|
||||
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||
"Copy the GPU memory of kvcache during the decode stage.");
|
||||
|
||||
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||
|
||||
m.def("rms_layernorm", &rms_layernorm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
|
||||
"In-place fused Add and RMS Normalization.");
|
||||
}
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
/*This code from VLLM:
|
||||
* https://github.com/vllm-project/vllm/
|
||||
* with minor changes. */
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <stdio.h>
|
||||
|
||||
|
||||
#include "block_reduce.h"
|
||||
#include "type_shim.h"
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/*
|
||||
* since the open-sourced LLM's hidden dimensions mainly range from
|
||||
* 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
|
||||
* hidden dimension limit to 8192, and each thread's capacity
|
||||
* for caching input tensors to 8 (8192 = 8 * 1024) which
|
||||
* will cause problems for extremely large models, such as
|
||||
* Megatron-Turing NLG 530B with hidden dimensions up to 20480
|
||||
*/
|
||||
float x_local[8];
|
||||
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void fused_add_rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
float x_local[8];
|
||||
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
|
||||
x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void rms_layernorm(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);)
|
||||
}
|
||||
|
||||
void fused_add_rms_layernorm(
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
residual.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);)
|
||||
}
|
|
@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
|||
"cuda/colossal_inference_C_frontend.cpp",
|
||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||
"cuda/activation_kernel.cu",
|
||||
"cuda/rms_layernorm_kernel.cu",
|
||||
]
|
||||
]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
ret = [self.get_cuda_home_include()]
|
||||
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
|
||||
return ret
|
||||
|
||||
def cxx_flags(self):
|
||||
|
|
|
@ -22,15 +22,11 @@ def setup_seed(seed):
|
|||
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = (
|
||||
LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
)
|
||||
model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||
)
|
||||
.cuda()
|
||||
.half()
|
||||
)
|
||||
).cuda()
|
||||
model = model.eval()
|
||||
|
||||
inputs = [
|
||||
|
@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||
top_k = 50
|
||||
|
||||
if use_engine:
|
||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
|
||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
import pytest
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [2, 4, 8, 16])
|
||||
@pytest.mark.parametrize("N", [64, 128, 512])
|
||||
def test_rms_layernorm(M: int, N: int):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
device = get_current_device()
|
||||
|
||||
dtype = torch.float16
|
||||
eps = 1e-5
|
||||
x_shape = (M, N)
|
||||
w_shape = (x_shape[-1],)
|
||||
weight = torch.ones(w_shape, dtype=dtype, device=device)
|
||||
residual = torch.rand(x_shape, dtype=dtype, device=device)
|
||||
residual_copy = residual.clone()
|
||||
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
x_copy = x.clone()
|
||||
|
||||
y_cuda = torch.empty_like(x)
|
||||
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
|
||||
y_llama = rms_norm.forward(x).to(dtype)
|
||||
|
||||
assert y_cuda.shape == y_llama.shape
|
||||
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
|
||||
|
||||
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
|
||||
y_cuda = x
|
||||
|
||||
x = x_copy + residual_copy
|
||||
y_llama = rms_norm.forward(x).to(dtype)
|
||||
|
||||
assert y_cuda.shape == y_llama.shape
|
||||
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
|
||||
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rms_layernorm(16, 512)
|
Loading…
Reference in New Issue