From f7aecc0c6bac001d10c1dd00274e0152e4c86df6 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:21:12 +0800 Subject: [PATCH] feat rmsnorm cuda kernel and add unittest, benchmark script (#5417) --- .../modeling/models/nopadding_llama.py | 28 +++- .../modeling/policy/nopadding_llama.py | 35 +---- ...rmsnorm_triton.py => benchmark_rmsnorm.py} | 19 ++- .../cuda/colossal_inference_C_frontend.cpp | 17 +++ extensions/csrc/cuda/rms_layernorm_kernel.cu | 126 ++++++++++++++++++ extensions/inference/inference_ops_cuda.py | 3 +- tests/test_infer/test_inference_engine.py | 14 +- .../test_ops/cuda/test_rms_layernorm.py | 51 +++++++ 8 files changed, 244 insertions(+), 49 deletions(-) rename examples/inference/benchmark_ops/{benchmark_rmsnorm_triton.py => benchmark_rmsnorm.py} (79%) create mode 100644 extensions/csrc/cuda/rms_layernorm_kernel.cu create mode 100644 tests/test_infer/test_ops/cuda/test_rms_layernorm.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 876fed456..f84abab4b 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, LlamaMLP, LlamaModel, + LlamaRMSNorm, ) from colossalai.inference.batch_bucket import BatchBucket @@ -19,6 +20,7 @@ from colossalai.kernel.triton import ( decoding_fused_rotary_embedding, flash_decoding_attention, get_xine_cache, + rms_layernorm, rotary_embedding, ) from colossalai.logging import get_dist_logger @@ -124,7 +126,7 @@ def llama_model_forward( hidden_states = hidden_states[last_token_indexs - 1].contiguous() residual = residual[last_token_indexs - 1].contiguous() norm_output = torch.empty_like(hidden_states) - hidden_states, _ = self.norm(hidden_states, norm_output, residual) + hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) return hidden_states @@ -167,7 +169,7 @@ def llama_decoder_layer_forward( use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. 
""" - hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual) + hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -185,12 +187,32 @@ def llama_decoder_layer_forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) hidden_states = self.mlp(hidden_states) return hidden_states, residual +def llama_rmsnorm_forward( + self: LlamaRMSNorm, + hidden_states: torch.Tensor, + norm_output: torch.Tensor, + residual: torch.Tensor = None, + use_cuda_kernel: bool = True, +): + if use_cuda_kernel: + if residual is not None: + inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon) + return hidden_states, residual + + if norm_output is None: + norm_output = torch.empty_like(hidden_states) + inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon) + return norm_output, hidden_states + else: + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) + + class NopadLlamaAttention(LlamaAttention): def __init__( self, diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 13695b835..bb9a22b41 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,6 +1,5 @@ from functools import partial -import torch from torch.nn import Parameter from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm @@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import ( llama_causal_lm_forward, llama_decoder_layer_forward, llama_model_forward, + llama_rmsnorm_forward, ) from colossalai.inference.utils import init_to_get_rotary from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription @@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -try: - from colossalai.kernel.triton import rms_layernorm - - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -def get_triton_rmsnorm_forward(): - if HAS_TRITON_RMSNORM: - - def _triton_rmsnorm_forward( - self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None - ): - return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) - - return _triton_rmsnorm_forward - else: - return None - class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: @@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): description=method_replacement, policy=policy, target_key=LlamaDecoderLayer ) - infer_forward = None - if HAS_TRITON_RMSNORM: - infer_forward = get_triton_rmsnorm_forward() - - if infer_forward is not None: - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaRMSNorm - ) 
+ infer_forward = llama_rmsnorm_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm) return policy diff --git a/examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py b/examples/inference/benchmark_ops/benchmark_rmsnorm.py similarity index 79% rename from examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py rename to examples/inference/benchmark_ops/benchmark_rmsnorm.py index 9c60601b9..3b5166af0 100644 --- a/examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py +++ b/examples/inference/benchmark_ops/benchmark_rmsnorm.py @@ -1,14 +1,14 @@ import torch -import triton +from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import rms_layernorm try: import triton # noqa - except ImportError: print("please install triton from https://github.com/openai/triton") +inference_ops = InferenceOpsLoader().load() # Triton benchmark plot attributions configs = [ @@ -19,16 +19,20 @@ configs = [ line_vals=[ "vllm_rms_layernorm", "triton_rms_layernorm", - "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm", "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", ], line_names=[ "vllm_rms_layernorm", "triton_rms_layernorm", - "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm", "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", ], - styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")], ylabel="ms", plot_name=f"RMSNorm benchmarking results", args={"HIDDEN_SIZE": 1024}, @@ -62,10 +66,15 @@ def benchmark_rms_layernorm( fn = lambda: vllm_norm(x) elif provider == "triton_rms_layernorm": fn = lambda: rms_layernorm(x, weight, eps=eps) + elif provider == "cuda_rms_layernorm": + out = torch.empty_like(x) + fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps) elif provider == "vllm_rms_layernorm_with_residual": fn = lambda: vllm_norm(x, residual=residual) elif provider == "triton_rms_layernorm_with_residual": fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual) + elif provider == "cuda_rms_layernorm_with_residual": + fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps) else: raise ValueError("Undefined provider.") diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp index cc53d8b88..73ed49e6c 100644 --- a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp +++ b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp @@ -11,8 +11,25 @@ void decode_kv_cache_memcpy( torch::Tensor silu_and_mul(const torch::Tensor& ins); +void rms_layernorm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); + + m.def("rms_layernorm", &rms_layernorm, + "Apply Root Mean 
Square (RMS) Normalization to the input tensor.");
+
+  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
+        "In-place fused Add and RMS Normalization.");
 }
diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu
new file mode 100644
index 000000000..99d36575d
--- /dev/null
+++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -0,0 +1,126 @@
+/*This code from VLLM:
+ *     https://github.com/vllm-project/vllm/
+ *     with minor changes. */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <stdio.h>
+
+
+#include "block_reduce.h"
+#include "type_shim.h"
+
+template<typename scalar_t>
+__global__ void rms_layernorm_kernel(
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /*
+   * since the open-sourced LLM's hidden dimensions mainly range from
+   * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
+   * hidden dimension limit to 8192, and each thread's capacity
+   * for caching input tensors to 8 (8192 = 8 * 1024) which
+   * will cause problems for extremely large models, such as
+   * Megatron-Turing NLG 530B with hidden dimensions up to 20480
+   */
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+template<typename scalar_t>
+__global__ void fused_add_rms_layernorm_kernel(
+  scalar_t* __restrict__ input,           // [..., hidden_size]
+  scalar_t* __restrict__ residual,        // [..., hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+void rms_layernorm(
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "rms_layernorm_kernel",
+    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      out.data_ptr<scalar_t>(),
+      input.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
+
+void fused_add_rms_layernorm(
+  torch::Tensor& input,     // [..., hidden_size]
+  torch::Tensor& residual,  // [..., hidden_size]
+  torch::Tensor& weight,    // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "fused_add_rms_layernorm_kernel",
+    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      input.data_ptr<scalar_t>(),
+      residual.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py
index 2858d7160..042c598fb 100644
--- a/extensions/inference/inference_ops_cuda.py
+++ b/extensions/inference/inference_ops_cuda.py
@@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension):
                 "cuda/colossal_inference_C_frontend.cpp",
                 "cuda/decode_kv_cache_memcpy_kernel.cu",
                 "cuda/activation_kernel.cu",
+                "cuda/rms_layernorm_kernel.cu",
             ]
         ]
         return ret
 
     def include_dirs(self):
-        ret = [self.get_cuda_home_include()]
+        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
         return ret
 
     def cxx_flags(self):
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index edd92bb96..25b2c2f43 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -22,15 +22,11 @@ def setup_seed(seed):
 def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = (
-        LlamaForCausalLM(
-            LlamaConfig(
-                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
-            )
+    model = LlamaForCausalLM(
+        LlamaConfig(
+            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
-        .cuda()
-        .half()
-    )
+    ).cuda()
     model = model.eval()
 
     inputs = [
@@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
diff --git a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py
new file mode 100644
index 000000000..d14010600
--- /dev/null
+++ b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py
@@ -0,0 +1,51 @@
+import pytest
+import torch
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.utils import get_current_device
+
+inference_ops = InferenceOpsLoader().load()
+
+
+@pytest.mark.parametrize("M", [2, 4, 8, 16])
+@pytest.mark.parametrize("N", [64, 128, 512])
+def test_rms_layernorm(M: int, N: int):
+    torch.manual_seed(123)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+
+    device = get_current_device()
+
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.ones(w_shape, dtype=dtype, device=device)
+    residual = torch.rand(x_shape, dtype=dtype, device=device)
+    residual_copy = residual.clone()
+    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    x_copy = x.clone()
+
+    y_cuda = torch.empty_like(x)
+    inference_ops.rms_layernorm(y_cuda, x, weight, eps)
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_cuda.shape == y_llama.shape
+    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
+
+    inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
+    y_cuda = x
+
+    x = x_copy + residual_copy
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_cuda.shape == y_llama.shape
+    assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_rms_layernorm(16, 512)
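
Usage sketch: the snippet below mirrors how the unit test above drives the two new ops once the `inference_ops_cuda` extension has been built. It is a minimal illustration, not part of the patch; the hidden size, epsilon, and tensor shapes are placeholder values chosen for the example, and a CUDA device is assumed. Note the kernels cache at most 8 floats per thread, so hidden sizes above 8192 are not supported.

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

# Load the compiled inference ops (requires the CUDA extension added above).
inference_ops = InferenceOpsLoader().load()

hidden_size = 4096  # illustrative; must not exceed 8192 for these kernels
eps = 1e-5
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

# Plain RMSNorm: the result is written into a preallocated output buffer.
out = torch.empty_like(x)
inference_ops.rms_layernorm(out, x, weight, eps)

# Fused residual-add + RMSNorm, in place: x becomes RMSNorm(x + residual),
# and residual is overwritten with the pre-normalization sum (x + residual).
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)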