feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)

pull/5445/head
Steve Luo 2024-03-08 16:21:12 +08:00 committed by GitHub
parent 2b28b54ac6
commit f7aecc0c6b
8 changed files with 244 additions and 49 deletions

View File

@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
)
from colossalai.inference.batch_bucket import BatchBucket
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger
@@ -124,7 +126,7 @@ def llama_model_forward(
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
        residual = residual[last_token_indexs - 1].contiguous()
        norm_output = torch.empty_like(hidden_states)
    hidden_states, _ = self.norm(hidden_states, norm_output, residual)
    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

    return hidden_states
@@ -167,7 +169,7 @@ def llama_decoder_layer_forward(
        use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
    """

    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
@@ -185,12 +187,32 @@ def llama_decoder_layer_forward(
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    hidden_states = self.mlp(hidden_states)

    return hidden_states, residual

def llama_rmsnorm_forward(
    self: LlamaRMSNorm,
    hidden_states: torch.Tensor,
    norm_output: torch.Tensor,
    residual: torch.Tensor = None,
    use_cuda_kernel: bool = True,
):
    if use_cuda_kernel:
        if residual is not None:
            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
            return hidden_states, residual

        if norm_output is None:
            norm_output = torch.empty_like(hidden_states)
        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
        return norm_output, hidden_states
    else:
        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
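As an aside (not part of the diff), a minimal standalone sketch of the two CUDA paths dispatched above, using the op signatures exercised by the benchmark script and unit test later in this commit; shapes and values are illustrative only:

import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

tokens, hidden_size = 16, 4096
eps = 1e-5
x = torch.randn(tokens, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")

# plain RMSNorm: the result is written into a preallocated output buffer
out = torch.empty_like(x)
inference_ops.rms_layernorm(out, x, weight, eps)

# fused residual-add + RMSNorm: x is normalized in place and
# residual is overwritten with x + residual for the next layer
residual = torch.randn_like(x)
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)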

class NopadLlamaAttention(LlamaAttention):
    def __init__(
        self,

View File

@@ -1,6 +1,5 @@
from functools import partial
import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
    llama_model_forward,
    llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
@@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
try:
    from colossalai.kernel.triton import rms_layernorm

    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(
            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
        ):
            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)

        return _triton_rmsnorm_forward
    else:
        return None

class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
    def __init__(self) -> None:
@@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
        )

        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
            )
        infer_forward = llama_rmsnorm_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)

        return policy

View File

@@ -1,14 +1,14 @@
import torch
import triton
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rms_layernorm
try:
    import triton  # noqa
except ImportError:
    print("please install triton from https://github.com/openai/triton")
inference_ops = InferenceOpsLoader().load()
# Triton benchmark plot attributions
configs = [
@@ -19,16 +19,20 @@ configs = [
        line_vals=[
            "vllm_rms_layernorm",
            "triton_rms_layernorm",
            "triton_rms_layernorm_with_residual",
            "cuda_rms_layernorm",
            "vllm_rms_layernorm_with_residual",
            "triton_rms_layernorm_with_residual",
            "cuda_rms_layernorm_with_residual",
        ],
        line_names=[
            "vllm_rms_layernorm",
            "triton_rms_layernorm",
            "triton_rms_layernorm_with_residual",
            "cuda_rms_layernorm",
            "vllm_rms_layernorm_with_residual",
            "triton_rms_layernorm_with_residual",
            "cuda_rms_layernorm_with_residual",
        ],
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
        ylabel="ms",
        plot_name=f"RMSNorm benchmarking results",
        args={"HIDDEN_SIZE": 1024},
@@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
        fn = lambda: vllm_norm(x)
    elif provider == "triton_rms_layernorm":
        fn = lambda: rms_layernorm(x, weight, eps=eps)
    elif provider == "cuda_rms_layernorm":
        out = torch.empty_like(x)
        fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
    elif provider == "vllm_rms_layernorm_with_residual":
        fn = lambda: vllm_norm(x, residual=residual)
    elif provider == "triton_rms_layernorm_with_residual":
        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
    elif provider == "cuda_rms_layernorm_with_residual":
        fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
    else:
        raise ValueError("Undefined provider.")
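Not shown in this hunk: the configs above are presumably passed to triton.testing.perf_report, in which case the script would be driven by a call along these lines (a hedged sketch of the entry point, not confirmed by the diff):

if __name__ == "__main__":
    # assumes benchmark_rms_layernorm is decorated with @triton.testing.perf_report(configs)
    benchmark_rms_layernorm.run(save_path=".", print_data=True)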

View File

@@ -11,8 +11,25 @@ void decode_kv_cache_memcpy(
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out,     // [..., hidden_size]
                   torch::Tensor& input,   // [..., hidden_size]
                   torch::Tensor& weight,  // [hidden_size]
                   float epsilon);

void fused_add_rms_layernorm(torch::Tensor& input,     // [..., hidden_size]
                             torch::Tensor& residual,  // [..., hidden_size]
                             torch::Tensor& weight,    // [hidden_size]
                             float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the decode stage.");
  m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
  m.def("rms_layernorm", &rms_layernorm,
        "Apply Root Mean Square (RMS) Normalization to the input tensor.");
  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
        "In-place fused Add and RMS Normalization.");
}

View File

@@ -0,0 +1,126 @@
/* This code is adapted from vLLM:
 * https://github.com/vllm-project/vllm/
 * with minor changes. */
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include "block_reduce.h"
#include "type_shim.h"
template<typename scalar_t>
__global__ void rms_layernorm_kernel(
  scalar_t* __restrict__ out,           // [..., hidden_size]
  const scalar_t* __restrict__ input,   // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
  /*
   * Since open-source LLMs' hidden dimensions mainly range from
   * 4096 (LLaMA-7B) to 8192 (LLaMA-65B), we set the supported
   * hidden dimension limit to 8192, and each thread's capacity for
   * caching input elements to 8 (8192 = 8 * 1024 threads). This
   * will break for extremely large models such as
   * Megatron-Turing NLG 530B, whose hidden dimension reaches 20480.
   */
  float x_local[8];

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
    variance += x_local[cnt] * x_local[cnt];
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
  }
}

template<typename scalar_t>
__global__ void fused_add_rms_layernorm_kernel(
  scalar_t* __restrict__ input,         // [..., hidden_size]
  scalar_t* __restrict__ residual,      // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
  float x_local[8];

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x_local[cnt] * x_local[cnt];
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
  }
}

void rms_layernorm(
  torch::Tensor& out,     // [..., hidden_size]
  torch::Tensor& input,   // [..., hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
    input.scalar_type(),
    "rms_layernorm_kernel",
    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
      out.data_ptr<scalar_t>(),
      input.data_ptr<scalar_t>(),
      weight.data_ptr<scalar_t>(),
      epsilon,
      num_tokens,
      hidden_size);)
}

void fused_add_rms_layernorm(
  torch::Tensor& input,     // [..., hidden_size]
  torch::Tensor& residual,  // [..., hidden_size]
  torch::Tensor& weight,    // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
    input.scalar_type(),
    "fused_add_rms_layernorm_kernel",
    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(),
      residual.data_ptr<scalar_t>(),
      weight.data_ptr<scalar_t>(),
      epsilon,
      num_tokens,
      hidden_size);)
}
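For intuition, here is a rough PyTorch reference for what the two kernels above compute; this is only a sanity-check sketch (not part of the commit), mirroring the fp32 accumulation and the in-place residual update:

import torch

def rms_layernorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # accumulate in fp32, normalize, cast back, then scale by weight (as the kernel does)
    x_f32 = x.float()
    variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
    return (x_f32 * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rms_layernorm_ref(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float):
    # residual is overwritten with x + residual; the normalized sum becomes the new hidden states
    summed = x.float() + residual.float()
    variance = summed.pow(2).mean(dim=-1, keepdim=True)
    out = (summed * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out, summed.to(x.dtype)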

View File

@@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension):
"cuda/colossal_inference_C_frontend.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/activation_kernel.cu",
"cuda/rms_layernorm_kernel.cu",
]
]
return ret
def include_dirs(self):
ret = [self.get_cuda_home_include()]
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
return ret
def cxx_flags(self):

View File

@@ -22,15 +22,11 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
            )
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
        )
        .cuda()
        .half()
    )
    ).cuda()
    model = model.eval()
inputs = [
@@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
    top_k = 50

    if use_engine:
        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()


@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512])
def test_rms_layernorm(M: int, N: int):
    torch.manual_seed(123)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    device = get_current_device()
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.ones(w_shape, dtype=dtype, device=device)
    residual = torch.rand(x_shape, dtype=dtype, device=device)
    residual_copy = residual.clone()
    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    x_copy = x.clone()

    y_cuda = torch.empty_like(x)
    inference_ops.rms_layernorm(y_cuda, x, weight, eps)
    y_llama = rms_norm.forward(x).to(dtype)

    assert y_cuda.shape == y_llama.shape
    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)

    inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
    y_cuda = x

    x = x_copy + residual_copy
    y_llama = rms_norm.forward(x).to(dtype)

    assert y_cuda.shape == y_llama.shape
    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)


if __name__ == "__main__":
    test_rms_layernorm(16, 512)