mirror of https://github.com/hpcaitech/ColossalAI
Optimized the gaps between CUDA kernel executions caused by view and memcopy (#5390)
* opt_view_and_memcopy
* fix bugs in ci
* fix ci bugs
* update benchmark scripts
* fix ci bugs

pull/5399/head
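The core idea, as a minimal PyTorch sketch (illustrative only, not the ColossalAI implementation; the helper name below is made up): fold the residual addition into the RMSNorm step and return both the normalized tensor and the updated residual, so each decoder layer threads a (hidden_states, residual) pair instead of launching separate add, view, and copy kernels between the matmuls.

import torch


def rmsnorm_with_residual(x, weight, eps=1e-5, residual=None):
    # Fused pattern: add the residual first, then normalize; return both the
    # normalized activations and the updated residual for the next layer.
    if residual is not None:
        x = x + residual
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight, x


if __name__ == "__main__":
    tokens, hidden = 16, 64
    weight = torch.ones(hidden)
    hidden_states = torch.randn(tokens, hidden)
    residual = None
    for _ in range(2):  # two toy "decoder layers"
        hidden_states, residual = rmsnorm_with_residual(hidden_states, weight, residual=residual)
        hidden_states = hidden_states @ torch.randn(hidden, hidden)  # stand-in for attention / MLP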
parent 730103819d
commit 2a718c8be8
@@ -2,7 +2,6 @@
 from typing import List, Optional, Tuple

 import torch
-from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -82,19 +81,21 @@ def llama_model_forward(

     if batch.is_prompts:
         output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     else:
         output_tensor = torch.zeros(
-            (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     sm_scale = 1.0 / (batch.head_dim**0.5)

     norm_output = torch.empty_like(hidden_states)
+    residual = None

     for layer_id, decoder_layer in enumerate(self.layers):
-        hidden_states = decoder_layer(
+        hidden_states, residual = decoder_layer(
             hidden_states,
+            residual=residual,
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
@@ -111,8 +112,9 @@ def llama_model_forward(
     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
+        residual = residual[last_token_indexs - 1].contiguous()
         norm_output = torch.empty_like(hidden_states)
-    hidden_states = self.norm(hidden_states, norm_output)
+    hidden_states, _ = self.norm(hidden_states, norm_output, residual)

     return hidden_states

@@ -120,6 +122,7 @@ def llama_model_forward(
 def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
+    residual: torch.Tensor,
     block_tables: torch.Tensor = None,
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
@@ -136,6 +139,7 @@ def llama_decoder_layer_forward(

    Args:
        hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+       residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
        block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
            storing mapping of token_position_id -> block_id. Defaults to None.
        k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -151,12 +155,10 @@ def llama_decoder_layer_forward(
        sm_scale (int, optional): Used for flash attention. Defaults to None.
    """

-   residual = hidden_states
-   hidden_states = self.input_layernorm(hidden_states, norm_output)
+   hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
-       residual=residual,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
@@ -170,11 +172,10 @@ def llama_decoder_layer_forward(
    )

    # Fully Connected
-   residual = hidden_states
-   hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
-   hidden_states = self.mlp(hidden_states, residual)
+   hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+   hidden_states = self.mlp(hidden_states)

-   return hidden_states
+   return hidden_states, residual


 class NopadLlamaAttention(LlamaAttention):
@@ -198,16 +199,18 @@ class NopadLlamaAttention(LlamaAttention):
            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
        """
        super().__init__(config, layer_idx)
-       self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
-       self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
-       self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
-       self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
+       self.q_proj_weight = attn_qproj_w
+       self.k_proj_weight = attn_kproj_w
+       self.v_proj_weight = attn_vproj_w
+       self.o_proj_weight = attn_oproj_w

        if self.num_heads == self.num_key_value_heads:
-           qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
+           qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
            self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
-           self.q_proj = None
-           self.k_proj = None
-           self.v_proj = None
+
+       self.q_proj = None
+       self.k_proj = None
+       self.v_proj = None

    @staticmethod
    def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
@@ -239,7 +242,6 @@ class NopadLlamaAttention(LlamaAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
-       residual: torch.Tensor,
        block_tables: torch.Tensor = None,
        k_cache: torch.Tensor = None,
        v_cache: torch.Tensor = None,
@@ -254,7 +256,6 @@ class NopadLlamaAttention(LlamaAttention):
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-           residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
            block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                storing mapping of token_position_id -> block_id. Defaults to None.
            k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -270,9 +271,9 @@ class NopadLlamaAttention(LlamaAttention):
        """

        if self.num_heads != self.num_key_value_heads:
-           query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
-           key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
-           value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
+           query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
+           key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+           value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
        else:
            # fused qkv
            token_nums = hidden_states.size(0)
@@ -324,8 +325,7 @@ class NopadLlamaAttention(LlamaAttention):
            sm_scale=sm_scale,
        )

-       attn_output = attn_output.view(-1, self.hidden_size)
-       attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
+       attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output

@@ -348,10 +348,11 @@ class NopadLlamaMLP(LlamaMLP):
            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
        """
        super().__init__(config)
-       self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
-       self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+       self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
+       self.down_proj_weight = mlp_dproj_w
        self.gate_proj = None
        self.up_proj = None
+       self.down_proj = None

    @staticmethod
    def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -375,14 +376,13 @@ class NopadLlamaMLP(LlamaMLP):

        return mlp_layer

-   def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+   def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-           residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
        """
        hidden_states = hidden_states.expand(2, -1, -1)
        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
        act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
        tmp_out = act_out * gate_up_proj_out[1]
-       return torch.addmm(residual, tmp_out, self.down_proj.weight)
+       return torch.mm(tmp_out, self.down_proj_weight)
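A small sketch of why the output projection above no longer needs addmm or a view (plain PyTorch with assumed shapes, not the repository's API): once the attention kernel writes a flattened [token_num, num_heads * head_dim] output, o_proj is a single torch.mm and the residual is added later inside the fused norm.

import torch

token_num, num_heads, head_dim, hidden_size = 8, 4, 16, 64
attn_output = torch.randn(token_num, num_heads * head_dim)      # kernel output, already 2D
o_proj_weight = torch.randn(num_heads * head_dim, hidden_size)  # transposed o_proj weight
out = torch.mm(attn_output, o_proj_weight)                      # no intermediate .view() or addmm
assert out.shape == (token_num, hidden_size)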
@@ -29,8 +29,10 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
+        def _triton_rmsnorm_forward(
+            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
+        ):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)

         return _triton_rmsnorm_forward
     else:
@@ -205,7 +205,7 @@ def context_attention_unpadded(
     assert k_cache.shape == v_cache.shape
     assert context_lengths.shape[0] == block_tables.shape[0]

-    num_tokens, num_heads, _ = q.shape
+    num_tokens, num_heads, head_dim = q.shape
     num_kv_heads = k.shape[-2]
     assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
     num_kv_group = num_heads // num_kv_heads
@@ -213,7 +213,9 @@ def context_attention_unpadded(
     num_seqs, max_blocks_per_seq = block_tables.shape
     max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
     sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
-    output = torch.zeros_like(q) if output is None else output
+    output = (
+        torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    )

     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)
@@ -243,8 +245,8 @@ def context_attention_unpadded(
        v.stride(1),
        v.stride(2),
        output.stride(0),
-       output.stride(1),
-       output.stride(2),
+       head_dim,
+       1,
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
@@ -211,7 +211,7 @@ def flash_decoding_attention(
            records the (kv) sequence lengths incorporating past kv sequence lengths.
        block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
        max_seq_len_in_batch (int): Maximum sequence length in the batch.
-       output (torch.Tensor): [bsz, num_heads, head_dim]
+       output (torch.Tensor): [bsz, num_heads * head_dim]
        mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
            Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
        mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
@@ -220,7 +220,7 @@ def flash_decoding_attention(
        num_kv_group (int, optional): Number of key/value groups. Defaults to 1.

    Returns:
-       Output tensor with shape [bsz, num_heads, head_dim]
+       Output tensor with shape [bsz, num_heads * head_dim]
    """
    q = q.squeeze() if q.dim() == 4 else q
    assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
@@ -261,7 +261,7 @@ def flash_decoding_attention(
    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
-   output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+   output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output

    _flash_decoding_fwd_kernel[grid](
        q,
@@ -294,7 +294,7 @@ def flash_decoding_attention(
        BLOCK_SIZE=block_size,
        HEAD_DIM=head_dim,
    )

    grid = (triton.next_power_of_2(bsz), num_heads)

    _flash_decoding_fwd_reduce_kernel[grid](
@@ -311,8 +311,8 @@ def flash_decoding_attention(
        mid_output_lse.stride(1),
        mid_output_lse.stride(2),
        output.stride(0),
-       output.stride(1),
-       output.stride(2),
+       head_dim,
+       1,
        BLOCK_KV=block_size,
        HEAD_DIM=head_dim,
    )
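Why the output strides passed to the kernels become (head_dim, 1): a contiguous [bsz, num_heads * head_dim] buffer has exactly those strides when interpreted as [bsz, num_heads, head_dim], so the kernels can write straight into the flattened tensor and no reshape or copy kernel runs afterwards. A quick check in plain PyTorch (illustrative sizes only):

import torch

bsz, num_heads, head_dim = 2, 4, 8
flat = torch.randn(bsz, num_heads * head_dim)       # flattened output buffer
three_d = flat.view(bsz, num_heads, head_dim)       # same memory, no copy
assert three_d.stride() == (num_heads * head_dim, head_dim, 1)
assert flat.data_ptr() == three_d.data_ptr()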
@@ -49,7 +49,50 @@ if HAS_TRITON:
             # Write output
             tl.store(Y + cols, y.to(tl.float16), mask=mask)

-    def rms_layernorm(x, weight, eps, norm_output=None):
+    @triton.jit
+    def _rmsnorm_with_residual_kernel(
+        X,  # pointer to the input
+        Y,  # pointer to the output
+        R,  # pointer to the residual
+        W,  # pointer to the weights
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        eps,  # epsilon to avoid division by zero
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
+
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        Y += row * stride
+        X += row * stride
+        R += row * stride
+        # Compute variance
+        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+            x = tl.where(cols < N, x, 0.0)
+            r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32)
+            r = tl.where(cols < N, r, 0.0)
+            x = x + r
+            _var += x * x
+            mask = cols < N
+            tl.store(X + cols, x.to(tl.float16), mask=mask)
+        var = tl.sum(_var, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        # Normalize and apply linear transformation
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            w = tl.load(W + cols, mask=mask)
+            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+            x_hat = x * rstd
+            y = x_hat * w
+            # Write output
+            tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+    def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
         y = torch.empty_like(x) if norm_output is None else norm_output
         M, N = x.shape
@@ -64,5 +107,10 @@ if HAS_TRITON:
         num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)

         # enqueue kernel
-        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        return y
+        if residual is None:
+            _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        else:
+            _rmsnorm_with_residual_kernel[(M,)](
+                x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+            )
+        return y, x
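A usage sketch for the updated rms_layernorm wrapper above (assumes triton and a CUDA device are available; shapes are arbitrary). The wrapper now always returns a tuple, and in the residual path it writes x + residual back into x and returns it as the new residual.

import torch
from colossalai.kernel.triton import rms_layernorm

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")
weight = torch.ones(128, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

y, _ = rms_layernorm(x, weight, eps=1e-5)                                  # plain RMSNorm
y2, new_residual = rms_layernorm(x, weight, eps=1e-5, residual=residual)  # fused residual add + RMSNorm
# new_residual holds x + residual (x is updated in place), ready for the next layer.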
@@ -95,7 +95,7 @@ def benchmark_inference(args):
    else:
        assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
-       tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+       tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    model = model.eval()

@@ -122,6 +122,7 @@ def benchmark_inference(args):
    elif args.mode == "vllm":
        engine = LLM(
            model=args.model_path,
+           tokenizer="hf-internal-testing/llama-tokenizer",
            max_num_seqs=mbsz,
            dtype="float16",
            enforce_eager=True,
@@ -100,10 +100,14 @@ def test_context_attention(
    k_cache_triton = torch.zeros_like(k_cache_ref)
    v_cache_triton = torch.zeros_like(v_cache_ref)

+   _, num_heads, head_dim = q_unpad.shape
+
    out_triton = context_attention_unpadded(
        q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
    )

+   out_triton = out_triton.view(-1, num_heads, head_dim)
+
    out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)

    assert out_torch.shape == out_triton.shape
@@ -3,6 +3,7 @@ import torch
 import triton
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm

 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
@@ -29,15 +30,28 @@ def test_layer_norm(M, N):
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+   residual = torch.rand(x_shape, dtype=dtype, device="cuda")
+   residual_copy = residual.clone()
    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+   x_copy = x.clone()

-   y_triton = rms_layernorm(x, weight, eps=eps)
+   y_triton, _ = rms_layernorm(x, weight, eps=eps)
    y_llama = rms_norm.forward(x).to(dtype)

    assert y_triton.shape == y_llama.shape
    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+
+   y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
+
+   x = x_copy + residual_copy
+
+   y_llama = rms_norm.forward(x).to(dtype)
+
+   assert y_triton.shape == y_llama.shape
+   assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+   assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)


 # Triton benchmark plot attributions
 configs = [
@@ -45,9 +59,19 @@ configs = [
        x_names=["SEQUENCE_TOTAL"],
        x_vals=[i for i in range(128, 1025, 128)],
        line_arg="provider",
-       line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
-       line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
-       styles=[("red", "-"), ("blue", "-")],
+       line_vals=[
+           "vllm_rms_layernorm",
+           "triton_rms_layernorm",
+           "triton_rms_layernorm_with_residual",
+           "vllm_rms_layernorm_with_residual",
+       ],
+       line_names=[
+           "vllm_rms_layernorm",
+           "triton_rms_layernorm",
+           "triton_rms_layernorm_with_residual",
+           "vllm_rms_layernorm_with_residual",
+       ],
+       styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
        ylabel="ms",
        plot_name=f"RMSNorm benchmarking results",
        args={"HIDDEN_SIZE": 1024},
@@ -68,13 +92,18 @@ def benchmark_rms_layernorm(
    eps = 1e-5
    x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
    w_shape = (x_shape[-1],)
+   residual = torch.rand(x_shape, dtype=dtype, device="cuda")
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-   torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+   vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-   if provider == "torch_rms_layernorm":
-       fn = lambda: torch_norm(x)
+   if provider == "vllm_rms_layernorm":
+       fn = lambda: vllm_norm(x)
    elif provider == "triton_rms_layernorm":
        fn = lambda: rms_layernorm(x, weight, eps=eps)
+   elif provider == "vllm_rms_layernorm_with_residual":
+       fn = lambda: vllm_norm(x, residual=residual)
+   elif provider == "triton_rms_layernorm_with_residual":
+       fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
    else:
        raise ValueError("Undefined provider.")
