Browse Source

Optimized the execution interval time between cuda kernels caused by view and memcopy (#5390)

* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
pull/5399/head
yuehuayingxueluo 9 months ago committed by GitHub
parent
commit
2a718c8be8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 58
      colossalai/inference/modeling/models/nopadding_llama.py
  2. 6
      colossalai/inference/modeling/policy/nopadding_llama.py
  3. 10
      colossalai/kernel/triton/context_attn_unpad.py
  4. 10
      colossalai/kernel/triton/flash_decoding.py
  5. 52
      colossalai/kernel/triton/rms_layernorm.py
  6. 3
      examples/inference/benchmark_llama.py
  7. 4
      tests/test_infer/test_ops/triton/test_context_attn_unpad.py
  8. 43
      tests/test_infer/test_ops/triton/test_rmsnorm_triton.py

58
colossalai/inference/modeling/models/nopadding_llama.py

@ -2,7 +2,6 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaAttention, LlamaAttention,
LlamaConfig, LlamaConfig,
@ -82,19 +81,21 @@ def llama_model_forward(
if batch.is_prompts: if batch.is_prompts:
output_tensor = torch.zeros( output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
) )
else: else:
output_tensor = torch.zeros( output_tensor = torch.zeros(
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
) )
sm_scale = 1.0 / (batch.head_dim**0.5) sm_scale = 1.0 / (batch.head_dim**0.5)
norm_output = torch.empty_like(hidden_states) norm_output = torch.empty_like(hidden_states)
residual = None
for layer_id, decoder_layer in enumerate(self.layers): for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer( hidden_states, residual = decoder_layer(
hidden_states, hidden_states,
residual=residual,
block_tables=block_tables, block_tables=block_tables,
k_cache=k_caches[layer_id], k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id], v_cache=v_caches[layer_id],
@ -111,8 +112,9 @@ def llama_model_forward(
if batch.is_prompts: if batch.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1) last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous() hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()
norm_output = torch.empty_like(hidden_states) norm_output = torch.empty_like(hidden_states)
hidden_states = self.norm(hidden_states, norm_output) hidden_states, _ = self.norm(hidden_states, norm_output, residual)
return hidden_states return hidden_states
@ -120,6 +122,7 @@ def llama_model_forward(
def llama_decoder_layer_forward( def llama_decoder_layer_forward(
self: LlamaDecoderLayer, self: LlamaDecoderLayer,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None, block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None, k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None, v_cache: torch.Tensor = None,
@ -136,6 +139,7 @@ def llama_decoder_layer_forward(
Args: Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None. storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@ -151,12 +155,10 @@ def llama_decoder_layer_forward(
sm_scale (int, optional): Used for flash attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None.
""" """
residual = hidden_states hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
hidden_states = self.input_layernorm(hidden_states, norm_output)
# Self Attention # Self Attention
hidden_states = self.self_attn( hidden_states = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual,
block_tables=block_tables, block_tables=block_tables,
k_cache=k_cache, k_cache=k_cache,
v_cache=v_cache, v_cache=v_cache,
@ -170,11 +172,10 @@ def llama_decoder_layer_forward(
) )
# Fully Connected # Fully Connected
residual = hidden_states hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
hidden_states = self.post_attention_layernorm(hidden_states, norm_output) hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, residual)
return hidden_states return hidden_states, residual
class NopadLlamaAttention(LlamaAttention): class NopadLlamaAttention(LlamaAttention):
@ -198,13 +199,15 @@ class NopadLlamaAttention(LlamaAttention):
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None. attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
""" """
super().__init__(config, layer_idx) super().__init__(config, layer_idx)
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False) self.q_proj_weight = attn_qproj_w
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False) self.k_proj_weight = attn_kproj_w
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False) self.v_proj_weight = attn_vproj_w
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False) self.o_proj_weight = attn_oproj_w
if self.num_heads == self.num_key_value_heads: if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight] qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0) self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.q_proj = None self.q_proj = None
self.k_proj = None self.k_proj = None
self.v_proj = None self.v_proj = None
@ -239,7 +242,6 @@ class NopadLlamaAttention(LlamaAttention):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None, block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None, k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None, v_cache: torch.Tensor = None,
@ -254,7 +256,6 @@ class NopadLlamaAttention(LlamaAttention):
""" """
Args: Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None. storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@ -270,9 +271,9 @@ class NopadLlamaAttention(LlamaAttention):
""" """
if self.num_heads != self.num_key_value_heads: if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim) query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
else: else:
# fused qkv # fused qkv
token_nums = hidden_states.size(0) token_nums = hidden_states.size(0)
@ -324,8 +325,7 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale, sm_scale=sm_scale,
) )
attn_output = attn_output.view(-1, self.hidden_size) attn_output = torch.mm(attn_output, self.o_proj_weight)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
return attn_output return attn_output
@ -348,10 +348,11 @@ class NopadLlamaMLP(LlamaMLP):
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
""" """
super().__init__(config) super().__init__(config)
self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False) self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False) self.down_proj_weight = mlp_dproj_w
self.gate_proj = None self.gate_proj = None
self.up_proj = None self.up_proj = None
self.down_proj = None
@staticmethod @staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@ -375,14 +376,13 @@ class NopadLlamaMLP(LlamaMLP):
return mlp_layer return mlp_layer
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
Args: Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
""" """
hidden_states = hidden_states.expand(2, -1, -1) hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
tmp_out = act_out * gate_up_proj_out[1] tmp_out = act_out * gate_up_proj_out[1]
return torch.addmm(residual, tmp_out, self.down_proj.weight) return torch.mm(tmp_out, self.down_proj_weight)

6
colossalai/inference/modeling/policy/nopadding_llama.py

@ -29,8 +29,10 @@ except:
def get_triton_rmsnorm_forward(): def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM: if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor): def _triton_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output) self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
return _triton_rmsnorm_forward return _triton_rmsnorm_forward
else: else:

10
colossalai/kernel/triton/context_attn_unpad.py

@ -205,7 +205,7 @@ def context_attention_unpadded(
assert k_cache.shape == v_cache.shape assert k_cache.shape == v_cache.shape
assert context_lengths.shape[0] == block_tables.shape[0] assert context_lengths.shape[0] == block_tables.shape[0]
num_tokens, num_heads, _ = q.shape num_tokens, num_heads, head_dim = q.shape
num_kv_heads = k.shape[-2] num_kv_heads = k.shape[-2]
assert num_kv_heads > 0 and num_heads % num_kv_heads == 0 assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
num_kv_group = num_heads // num_kv_heads num_kv_group = num_heads // num_kv_heads
@ -213,7 +213,9 @@ def context_attention_unpadded(
num_seqs, max_blocks_per_seq = block_tables.shape num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
output = torch.zeros_like(q) if output is None else output output = (
torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
)
# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
# the size of physical cache block (i.e. `block_size`) # the size of physical cache block (i.e. `block_size`)
@ -243,8 +245,8 @@ def context_attention_unpadded(
v.stride(1), v.stride(1),
v.stride(2), v.stride(2),
output.stride(0), output.stride(0),
output.stride(1), head_dim,
output.stride(2), 1,
k_cache.stride(0), k_cache.stride(0),
k_cache.stride(1), k_cache.stride(1),
k_cache.stride(2), k_cache.stride(2),

10
colossalai/kernel/triton/flash_decoding.py

@ -211,7 +211,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths. records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch. max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, num_heads, head_dim] output (torch.Tensor): [bsz, num_heads * head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@ -220,7 +220,7 @@ def flash_decoding_attention(
num_kv_group (int, optional): Number of key/value groups. Defaults to 1. num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
Returns: Returns:
Output tensor with shape [bsz, num_heads, head_dim] Output tensor with shape [bsz, num_heads * head_dim]
""" """
q = q.squeeze() if q.dim() == 4 else q q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
@ -261,7 +261,7 @@ def flash_decoding_attention(
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
_flash_decoding_fwd_kernel[grid]( _flash_decoding_fwd_kernel[grid](
q, q,
@ -311,8 +311,8 @@ def flash_decoding_attention(
mid_output_lse.stride(1), mid_output_lse.stride(1),
mid_output_lse.stride(2), mid_output_lse.stride(2),
output.stride(0), output.stride(0),
output.stride(1), head_dim,
output.stride(2), 1,
BLOCK_KV=block_size, BLOCK_KV=block_size,
HEAD_DIM=head_dim, HEAD_DIM=head_dim,
) )

52
colossalai/kernel/triton/rms_layernorm.py

@ -49,7 +49,50 @@ if HAS_TRITON:
# Write output # Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask) tl.store(Y + cols, y.to(tl.float16), mask=mask)
def rms_layernorm(x, weight, eps, norm_output=None): @triton.jit
def _rmsnorm_with_residual_kernel(
X, # pointer to the input
Y, # pointer to the output
R, # pointer to the residual
W, # pointer to the weights
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
Y += row * stride
X += row * stride
R += row * stride
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x, 0.0)
r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32)
r = tl.where(cols < N, r, 0.0)
x = x + r
_var += x * x
mask = cols < N
tl.store(X + cols, x.to(tl.float16), mask=mask)
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# Normalize and apply linear transformation
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
# allocate output # allocate output
y = torch.empty_like(x) if norm_output is None else norm_output y = torch.empty_like(x) if norm_output is None else norm_output
M, N = x.shape M, N = x.shape
@ -64,5 +107,10 @@ if HAS_TRITON:
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32) num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
# enqueue kernel # enqueue kernel
if residual is None:
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y else:
_rmsnorm_with_residual_kernel[(M,)](
x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y, x

3
examples/inference/benchmark_llama.py

@ -95,7 +95,7 @@ def benchmark_inference(args):
else: else:
assert args.model_path, "When testing pretrained weights, the model path must be provided.'" assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda() model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
tokenizer = AutoTokenizer.from_pretrained(args.model_path) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = model.eval() model = model.eval()
@ -122,6 +122,7 @@ def benchmark_inference(args):
elif args.mode == "vllm": elif args.mode == "vllm":
engine = LLM( engine = LLM(
model=args.model_path, model=args.model_path,
tokenizer="hf-internal-testing/llama-tokenizer",
max_num_seqs=mbsz, max_num_seqs=mbsz,
dtype="float16", dtype="float16",
enforce_eager=True, enforce_eager=True,

4
tests/test_infer/test_ops/triton/test_context_attn_unpad.py

@ -100,10 +100,14 @@ def test_context_attention(
k_cache_triton = torch.zeros_like(k_cache_ref) k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref) v_cache_triton = torch.zeros_like(v_cache_ref)
_, num_heads, head_dim = q_unpad.shape
out_triton = context_attention_unpadded( out_triton = context_attention_unpadded(
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
) )
out_triton = out_triton.view(-1, num_heads, head_dim)
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads) out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
assert out_torch.shape == out_triton.shape assert out_torch.shape == out_triton.shape

43
tests/test_infer/test_ops/triton/test_rmsnorm_triton.py

@ -3,6 +3,7 @@ import torch
import triton import triton
from packaging import version from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
from colossalai.kernel.triton import rms_layernorm from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize from colossalai.testing.utils import parameterize
@ -29,15 +30,28 @@ def test_layer_norm(M, N):
x_shape = (M, N) x_shape = (M, N)
w_shape = (x_shape[-1],) w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device="cuda") weight = torch.ones(w_shape, dtype=dtype, device="cuda")
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda() rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_triton = rms_layernorm(x, weight, eps=eps) y_triton, _ = rms_layernorm(x, weight, eps=eps)
y_llama = rms_norm.forward(x).to(dtype) y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
# Triton benchmark plot attributions # Triton benchmark plot attributions
configs = [ configs = [
@ -45,9 +59,19 @@ configs = [
x_names=["SEQUENCE_TOTAL"], x_names=["SEQUENCE_TOTAL"],
x_vals=[i for i in range(128, 1025, 128)], x_vals=[i for i in range(128, 1025, 128)],
line_arg="provider", line_arg="provider",
line_vals=["torch_rms_layernorm", "triton_rms_layernorm"], line_vals=[
line_names=["torch_rms_layernorm", "triton_rms_layernorm"], "vllm_rms_layernorm",
styles=[("red", "-"), ("blue", "-")], "triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"vllm_rms_layernorm_with_residual",
],
line_names=[
"vllm_rms_layernorm",
"triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"vllm_rms_layernorm_with_residual",
],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
ylabel="ms", ylabel="ms",
plot_name=f"RMSNorm benchmarking results", plot_name=f"RMSNorm benchmarking results",
args={"HIDDEN_SIZE": 1024}, args={"HIDDEN_SIZE": 1024},
@ -68,13 +92,18 @@ def benchmark_rms_layernorm(
eps = 1e-5 eps = 1e-5
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE) x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
w_shape = (x_shape[-1],) w_shape = (x_shape[-1],)
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
weight = torch.ones(w_shape, dtype=dtype, device="cuda") weight = torch.ones(w_shape, dtype=dtype, device="cuda")
torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda") vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
if provider == "torch_rms_layernorm": if provider == "vllm_rms_layernorm":
fn = lambda: torch_norm(x) fn = lambda: vllm_norm(x)
elif provider == "triton_rms_layernorm": elif provider == "triton_rms_layernorm":
fn = lambda: rms_layernorm(x, weight, eps=eps) fn = lambda: rms_layernorm(x, weight, eps=eps)
elif provider == "vllm_rms_layernorm_with_residual":
fn = lambda: vllm_norm(x, residual=residual)
elif provider == "triton_rms_layernorm_with_residual":
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
else: else:
raise ValueError("Undefined provider.") raise ValueError("Undefined provider.")

Loading…
Cancel
Save