
[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm

* fix bugs in rms_layernorm
yuehuayingxueluo authored 10 months ago, committed by GitHub · commit 21ad4a27f9 · branch pull/5356/head
Changed files (7):
  1. colossalai/inference/modeling/models/nopadding_llama.py (12 changes)
  2. colossalai/inference/modeling/models/padding_llama.py (12 changes)
  3. colossalai/inference/modeling/policy/nopadding_llama.py (4 changes)
  4. colossalai/inference/modeling/policy/padding_llama.py (4 changes)
  5. colossalai/kernel/triton/rms_layernorm.py (10 changes)
  6. examples/inference/benchmark_llama.py (3 changes)
  7. examples/inference/run_benchmark.sh (24 changes)

colossalai/inference/modeling/models/nopadding_llama.py (12 changes)

@@ -95,6 +95,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)

+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,13 +109,15 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )

     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
+        norm_output = torch.empty_like(hidden_states)

-    hidden_states = self.norm(hidden_states)
+    hidden_states = self.norm(hidden_states, norm_output)

     return hidden_states
@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@ def llama_decoder_layer_forward(
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
             storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """

     residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -171,7 +177,7 @@ def llama_decoder_layer_forward(
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
     hidden_states = self.mlp(hidden_states, residual)

     return hidden_states

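The change in this file threads a single preallocated norm_output buffer through every decoder layer, so the RMSNorm calls no longer allocate a fresh output tensor each time; the buffer is only re-created on the prompt path, where slicing out the last token of each sequence changes the shape of hidden_states. Below is a minimal sketch of that pattern with hypothetical names (rms_norm_into and layer_fn stand in for the Triton kernel and the real decoder layers), not the actual ColossalAI code:

import torch

def rms_norm_into(x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor) -> torch.Tensor:
    # Reference RMSNorm that writes into a caller-provided buffer instead of allocating one.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out.copy_((x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight)
    return out

def model_forward(hidden_states: torch.Tensor, layers, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    norm_output = torch.empty_like(hidden_states)  # one allocation per forward pass
    for layer_fn in layers:
        normed = rms_norm_into(hidden_states, weight, eps, norm_output)
        # Reuse is safe: `normed` is consumed by the layer before the buffer is overwritten again.
        hidden_states = hidden_states + layer_fn(normed)
    return hidden_states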
colossalai/inference/modeling/models/padding_llama.py (12 changes)

@@ -135,6 +135,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)

+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -149,12 +151,14 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )

     if batch.is_prompts:
         hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
+        norm_output = torch.empty_like(hidden_states)

-    hidden_states = self.norm(hidden_states)
+    hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)

     return hidden_states
@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@ def llama_decoder_layer_forward(
         cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """

     residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -217,7 +223,7 @@ def llama_decoder_layer_forward(
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     hidden_states = self.mlp(hidden_states)
     hidden_states = residual + hidden_states

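Because rms_layernorm no longer flattens its input internally (see the kernel change below), the padded model now reshapes its [batch, seq_len, hidden_size] activations to 2D before each norm call. A small, purely illustrative shape check of what that reshape does:

import torch

bsz, seq_len, hidden = 4, 16, 4096
hidden_states = torch.randn(bsz, seq_len, hidden, dtype=torch.float16)
norm_output = torch.empty_like(hidden_states)  # preallocated once per forward, same element count

x2d = hidden_states.reshape(-1, hidden_states.shape[-1])  # (bsz * seq_len, hidden_size)
assert x2d.shape == (bsz * seq_len, hidden)
# For a contiguous tensor this reshape is a view, so no extra memory is allocated or copied.
assert x2d.data_ptr() == hidden_states.data_ptr()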
colossalai/inference/modeling/policy/nopadding_llama.py (4 changes)

@@ -29,8 +29,8 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)

         return _triton_rmsnorm_forward
     else:

colossalai/inference/modeling/policy/padding_llama.py (4 changes)

@@ -27,8 +27,8 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_outpu: torch.Tensor):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_outpu)

         return _triton_rmsnorm_forward
     else:

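Both policies swap LlamaRMSNorm.forward for a Triton-backed function that takes the extra buffer, which is why every norm call in the model code above now passes norm_output. A hedged sketch of the idea using plain class-level method replacement rather than ColossalAI's actual policy machinery (TinyRMSNorm and _patched_forward are illustrative stand-ins):

import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):  # stand-in for transformers' LlamaRMSNorm
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

def _patched_forward(self, hidden_states: torch.Tensor, norm_output: torch.Tensor) -> torch.Tensor:
    # Same contract as _triton_rmsnorm_forward: normalize, scale, and write into the given buffer.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = (hidden_states.float() * torch.rsqrt(variance + self.variance_epsilon)).to(hidden_states.dtype)
    norm_output.copy_(normed * self.weight)
    return norm_output

TinyRMSNorm.forward = _patched_forward  # class-level replacement, as a model policy would apply

norm = TinyRMSNorm(8)
x = torch.randn(3, 8)
out = norm(x, torch.empty_like(x))  # the patched forward now requires the extra buffer
assert out.shape == x.shape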
colossalai/kernel/triton/rms_layernorm.py (10 changes)

@@ -50,12 +50,10 @@ if HAS_TRITON:
         tl.store(Y + cols, y.to(tl.float16), mask=mask)

     @torch.no_grad()
-    def rms_layernorm(x, weight, eps):
+    def rms_layernorm(x, weight, eps, norm_output=None):
         # allocate output
-        y = torch.empty_like(x)
-        # reshape input data into 2D tensor, (total token, hidden_size)
-        x_arg = x.reshape(-1, x.shape[-1])
-        M, N = x_arg.shape
+        y = torch.empty_like(x) if norm_output is None else norm_output
+        M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -67,5 +65,5 @@
         num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)

         # enqueue kernel
-        _rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
         return y

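The wrapper now accepts an optional norm_output buffer and writes into it, falling back to a fresh allocation when none is given; it also stops reshaping internally, so callers are expected to pass a 2D (num_tokens, hidden_size) tensor, as the model code above now does. An illustrative call pattern, guarded so it only runs when the kernel and a GPU are available (assuming rms_layernorm is exported from colossalai.kernel.triton, as the policies' try/except import suggests):

import torch

try:
    from colossalai.kernel.triton import rms_layernorm
    HAS_TRITON_RMSNORM = torch.cuda.is_available()
except ImportError:
    HAS_TRITON_RMSNORM = False

if HAS_TRITON_RMSNORM:
    tokens, hidden = 64, 4096
    x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
    weight = torch.ones(hidden, dtype=torch.float16, device="cuda")

    y1 = rms_layernorm(x, weight, eps=1e-6)                      # old behavior: allocates its own output
    buffer = torch.empty_like(x)
    y2 = rms_layernorm(x, weight, eps=1e-6, norm_output=buffer)  # new behavior: reuses the caller's buffer
    assert y2 is buffer
    assert torch.allclose(y1, y2, atol=1e-2)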
examples/inference/benchmark_llama.py (3 changes)

@@ -9,7 +9,8 @@ from transformers import AutoTokenizer, GenerationConfig

 import colossalai
 from colossalai.accelerator import get_accelerator
-from colossalai.inference import InferenceEngine
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

 GIGABYTE = 1024**3

examples/inference/run_benchmark.sh (24 changes)

@@ -23,22 +23,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
 CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1

 # benchmark llama2-7b one single GPU
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
-done
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
-done
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
-done
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
+for input_len in 128 512 1024; do
+    for output_len in 128 256; do
+        for bsz in 16 32 64; do
+            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
+        done
+    done
 done

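The four hard-coded loops are collapsed into a single nest that sweeps input length, output length, and batch size (3 x 2 x 3 = 18 runs), and the log-file names now lead with the sequence lengths. The same sweep enumerated in Python for reference; mode and gpu below are placeholders for the script's ${mode} and ${GPU} shell variables:

from itertools import product

mode, gpu = "colossalai", "A100"  # placeholders, not values taken from the script
configs = list(product((128, 512, 1024), (128, 256), (16, 32, 64)))
assert len(configs) == 18

for input_len, output_len, bsz in configs:
    print(f"logs/{input_len}_{output_len}_{mode}_{gpu}_{bsz}.txt")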