From 21ad4a27f91659220bec6c4d4f2d0f62f7093a45 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Fri, 2 Feb 2024 15:06:01 +0800
Subject: [PATCH] [Inference/opt]Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm

* fix bugs in rms_layernorm
---
 .../modeling/models/nopadding_llama.py     | 12 +++++++---
 .../modeling/models/padding_llama.py       | 12 +++++++---
 .../modeling/policy/nopadding_llama.py     |  4 ++--
 .../modeling/policy/padding_llama.py       |  4 ++--
 colossalai/kernel/triton/rms_layernorm.py  | 10 ++++----
 examples/inference/benchmark_llama.py      |  3 ++-
 examples/inference/run_benchmark.sh        | 24 +++++--------------
 7 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 6b108cd4d..5d0397ee8 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -95,6 +95,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,13 +109,15 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
 
     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
 
-    hidden_states = self.norm(hidden_states)
+    norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states, norm_output)
 
     return hidden_states
@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@ def llama_decoder_layer_forward(
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
""" residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states, norm_output) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -171,7 +177,7 @@ def llama_decoder_layer_forward( # Fully Connected residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states, norm_output) hidden_states = self.mlp(hidden_states, residual) return hidden_states diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py index 51d718a53..c53ff652c 100644 --- a/colossalai/inference/modeling/models/padding_llama.py +++ b/colossalai/inference/modeling/models/padding_llama.py @@ -135,6 +135,8 @@ def llama_model_forward( ) sm_scale = 1.0 / (batch.head_dim**0.5) + norm_output = torch.empty_like(hidden_states) + for layer_id, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, @@ -149,12 +151,14 @@ def llama_model_forward( cos_sin=cos_sin, fd_inter_tensor=batch.fd_inter_tensor, output_tensor=output_tensor, + norm_output=norm_output, sm_scale=sm_scale, ) if batch.is_prompts: hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() - hidden_states = self.norm(hidden_states) + norm_output = torch.empty_like(hidden_states) + hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) return hidden_states @@ -174,6 +178,7 @@ def llama_decoder_layer_forward( cos_sin: Tuple[torch.Tensor] = None, fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, + norm_output: torch.Tensor = None, sm_scale: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """This function will replace the forward function of LlamaDecoderLayer. @@ -191,11 +196,12 @@ def llama_decoder_layer_forward( cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. 
""" residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -217,7 +223,7 @@ def llama_decoder_layer_forward( # Fully Connected residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index aed72ef73..c8bb7dae3 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -29,8 +29,8 @@ except: def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon) + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor): + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output) return _triton_rmsnorm_forward else: diff --git a/colossalai/inference/modeling/policy/padding_llama.py b/colossalai/inference/modeling/policy/padding_llama.py index 9aa64f55b..fb009417b 100644 --- a/colossalai/inference/modeling/policy/padding_llama.py +++ b/colossalai/inference/modeling/policy/padding_llama.py @@ -27,8 +27,8 @@ except: def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon) + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_outpu: torch.Tensor): + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_outpu) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py index 71a724008..e4424eb33 100644 --- a/colossalai/kernel/triton/rms_layernorm.py +++ b/colossalai/kernel/triton/rms_layernorm.py @@ -50,12 +50,10 @@ if HAS_TRITON: tl.store(Y + cols, y.to(tl.float16), mask=mask) @torch.no_grad() - def rms_layernorm(x, weight, eps): + def rms_layernorm(x, weight, eps, norm_output=None): # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor, (total token, hidden_size) - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape + y = torch.empty_like(x) if norm_output is None else norm_output + M, N = x.shape # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() @@ -67,5 +65,5 @@ if HAS_TRITON: num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32) # enqueue kernel - _rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) return y diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index c49d98982..267e56231 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -9,7 +9,8 @@ from transformers import AutoTokenizer, GenerationConfig import colossalai from 
 from colossalai.accelerator import get_accelerator
-from colossalai.inference import InferenceEngine
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 
 GIGABYTE = 1024**3
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index 6870ed384..2a6e5a5d7 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -23,22 +23,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
 CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 
 # benchmark llama2-7b one single GPU
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
+for input_len in 128 512 1024; do
+    for output_len in 128 256; do
+        for bsz in 16 32 64; do
+            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
+        done
+    done
 done
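
After this patch, rms_layernorm no longer reshapes its input internally: callers hand it a 2D tensor of shape (total tokens, hidden_size) together with an optional preallocated norm_output buffer, so the intermediate tensor is allocated once per forward pass instead of once per layernorm call. A minimal sketch of that calling convention follows; rms_layernorm_ref is a pure-PyTorch stand-in for the Triton entry point, and the shapes, dtype, and eps value are illustrative assumptions rather than part of the patch.

import torch


def rms_layernorm_ref(x, weight, eps, norm_output=None):
    # Mirrors the updated signature: x is 2D (total tokens, hidden_size); when a
    # preallocated norm_output buffer is given, the result is written into it.
    y = torch.empty_like(x) if norm_output is None else norm_output
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    y.copy_((x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight)
    return y


total_tokens, hidden_size = 16, 4096
hidden_states = torch.randn(total_tokens, hidden_size, dtype=torch.float16)
weight = torch.ones(hidden_size, dtype=torch.float16)

# One buffer for the whole forward pass; it only needs to be reallocated when the
# token count changes (e.g. after the prompt-phase slicing in llama_model_forward).
norm_output = torch.empty_like(hidden_states)

for _ in range(3):  # stands in for the decoder-layer loop
    normed = rms_layernorm_ref(hidden_states, weight, eps=1e-6, norm_output=norm_output)
    assert normed.data_ptr() == norm_output.data_ptr()  # written in place, no per-call allocation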