mirror of https://github.com/hpcaitech/ColossalAI

[Inference/opt] Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm
* fix bugs in rms_layernorm

Branch: pull/5356/head
Parent: 027aa1043f
Commit: 21ad4a27f9
First modeling file (llama_model_forward and llama_decoder_layer_forward):

@@ -95,6 +95,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,13 +109,15 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
 
     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
-    hidden_states = self.norm(hidden_states)
+    norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states, norm_output)
 
     return hidden_states
 
@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@ def llama_decoder_layer_forward(
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
             storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
 
     residual = hidden_states
 
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -171,7 +177,7 @@ def llama_decoder_layer_forward(
 
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
     hidden_states = self.mlp(hidden_states, residual)
 
     return hidden_states
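The change above preallocates a single norm_output buffer in llama_model_forward and passes it to every decoder layer, so the two RMSNorm calls per layer write into one reused mid tensor instead of allocating a fresh output each time; a new buffer is created only after the prompt-phase slice changes the tensor shape. The second modeling file below applies the same idea. A minimal, self-contained sketch of this reuse pattern (the helper rms_norm_into and the surrounding names are illustrative stand-ins, not ColossalAI's API):

import torch


def rms_norm_into(x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the Triton rms_layernorm: normalize each row of x and
    # write the result into the caller-provided `out` buffer.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    torch.mul((x.float() * torch.rsqrt(variance + eps)).to(x.dtype), weight, out=out)
    return out


def model_forward_sketch(hidden_states: torch.Tensor, layer_weights, eps: float = 1e-6) -> torch.Tensor:
    # One mid tensor, allocated once and reused by every layer's norm calls.
    norm_output = torch.empty_like(hidden_states)
    for weight in layer_weights:
        normed = rms_norm_into(hidden_states, weight, eps, norm_output)
        hidden_states = hidden_states + normed  # placeholder for the attention/MLP work
    # The prompt-phase slice in the real forward changes the shape, so the diff
    # allocates one fresh buffer before the final norm; mirrored here.
    norm_output = torch.empty_like(hidden_states)
    return rms_norm_into(hidden_states, layer_weights[-1], eps, norm_output)

The gain is that the per-layer torch.empty_like disappears from the hot loop; only the shape change after the prompt-phase slice forces a single re-allocation.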
Second modeling file:

@@ -135,6 +135,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -149,12 +151,14 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
 
     if batch.is_prompts:
         hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
-    hidden_states = self.norm(hidden_states)
+    norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
 
     return hidden_states
 
@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@ def llama_decoder_layer_forward(
         cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
 
     residual = hidden_states
 
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -217,7 +223,7 @@ def llama_decoder_layer_forward(
 
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     hidden_states = self.mlp(hidden_states)
     hidden_states = residual + hidden_states
 
First Triton RMSNorm patch file:

@@ -29,8 +29,8 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
 
-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
 
         return _triton_rmsnorm_forward
     else:
Second Triton RMSNorm patch file:

@@ -27,8 +27,8 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
 
-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
 
         return _triton_rmsnorm_forward
     else:
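Both patch files widen _triton_rmsnorm_forward to accept the extra norm_output buffer and forward it to rms_layernorm. The diff does not show how the returned function is attached to the LlamaRMSNorm modules; as a rough illustration only (plain types.MethodType, not ColossalAI's actual replacement machinery), the binding could look like:

import types
import torch


def bind_forward(module: torch.nn.Module, fn) -> None:
    # Rebind `fn` as the module's forward so calls like module(x, norm_output)
    # dispatch to the Triton-backed function with `self` bound.
    module.forward = types.MethodType(fn, module)


# Hypothetical usage, assuming `rms_norm` is a LlamaRMSNorm instance and
# `_triton_rmsnorm_forward` was obtained from get_triton_rmsnorm_forward():
#   bind_forward(rms_norm, _triton_rmsnorm_forward)
#   out = rms_norm(hidden_states, norm_output)  # matches the patched call sites above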
Triton rms_layernorm wrapper:

@@ -50,12 +50,10 @@ if HAS_TRITON:
         tl.store(Y + cols, y.to(tl.float16), mask=mask)
 
     @torch.no_grad()
-    def rms_layernorm(x, weight, eps):
+    def rms_layernorm(x, weight, eps, norm_output=None):
         # allocate output
-        y = torch.empty_like(x)
-        # reshape input data into 2D tensor, (total token, hidden_size)
-        x_arg = x.reshape(-1, x.shape[-1])
-        M, N = x_arg.shape
+        y = torch.empty_like(x) if norm_output is None else norm_output
+        M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
 
@@ -67,5 +65,5 @@ if HAS_TRITON:
         num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
 
         # enqueue kernel
-        _rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
         return y
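With the reshape removed from rms_layernorm, callers now pass in tensors already flattened to (total_tokens, hidden_size) — hence the .reshape(-1, hidden_states.shape[-1]) added at the call sites in the second modeling file — and can supply a preallocated norm_output so the wrapper skips its per-call torch.empty_like. A rough pure-PyTorch stand-in for the wrapper's semantics (the real version launches the _rmsnorm_kernel Triton kernel; shapes and epsilon below are arbitrary example values):

import torch


def rms_layernorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float, norm_output: torch.Tensor = None) -> torch.Tensor:
    # Reference semantics of the patched wrapper: x is expected to already be
    # 2D (total_tokens, hidden_size), and the output buffer is allocated only
    # when the caller did not supply one.
    y = torch.empty_like(x) if norm_output is None else norm_output
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    y.copy_((x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight)
    return y


# Hypothetical usage: allocate the mid tensor once and reuse it across calls.
x = torch.randn(8, 4096).half()
w = torch.ones(4096).half()
buf = torch.empty_like(x)
out = rms_layernorm_ref(x, w, 1e-6, buf)  # writes into buf, no fresh allocation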
Test file imports:

@@ -9,7 +9,8 @@ from transformers import AutoTokenizer, GenerationConfig
 
 import colossalai
+from colossalai.accelerator import get_accelerator
-from colossalai.inference import InferenceEngine
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 
 GIGABYTE = 1024**3
Benchmark script:

@@ -23,22 +23,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
 CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 
 # benchmark llama2-7b one single GPU
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
-done
-
-
-for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
-done
+for input_len in 128 512 1024; do
+    for output_len in 128 256; do
+        for bsz in 16 32 64; do
+            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
+        done
+    done
+done
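The rewritten sweep covers 3 input lengths × 2 output lengths × 3 batch sizes = 18 runs per mode, and the log files are now named logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt (for example, a 128-in/256-out run at batch size 16 lands in logs/128_256_${mode}_${GPU}_16.txt), so logs for the same sequence configuration group together.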