mirror of https://github.com/hpcaitech/ColossalAI

Optimize the idle time between CUDA kernels caused by view and memcopy (#5390)

* opt_view_and_memcopy
* fix bugs in ci
* fix ci bugs
* update benchmark scripts
* fix ci bugs

Branch: pull/5399/head
parent 730103819d
commit 2a718c8be8
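This commit removes two sources of small glue ops that sat between the heavy CUDA kernels: attention and decode outputs are now allocated flat as [num_tokens, num_heads * head_dim], so the following projection GEMM consumes them directly without a view/copy step, and the residual addition is moved out of the addmm epilogues into the fused Triton RMSNorm kernel. A minimal before/after sketch of the attention epilogue; the names and shapes below are illustrative assumptions, not code from the repository:

import torch

# Illustration of the changed attention epilogue (all shapes are assumed).
num_tokens, num_heads, head_dim = 16, 8, 64
hidden_size = num_heads * head_dim
o_proj_w = torch.randn(hidden_size, hidden_size)  # transposed o_proj weight (assumed)
residual = torch.randn(num_tokens, hidden_size)

# Before: 3D attention output -> view -> addmm that also folds in the residual.
attn_out_3d = torch.randn(num_tokens, num_heads, head_dim)
out_before = torch.addmm(residual, attn_out_3d.view(-1, hidden_size), o_proj_w)

# After: the kernel writes a flat 2D buffer, the projection is a single mm, and
# the residual is added later inside the fused RMSNorm kernel instead.
attn_out_flat = attn_out_3d.view(-1, hidden_size)
out_after = torch.mm(attn_out_flat, o_proj_w)
assert torch.allclose(out_before, out_after + residual, atol=1e-5)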
@@ -2,7 +2,6 @@
 from typing import List, Optional, Tuple
 
 import torch
-from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -82,19 +81,21 @@ def llama_model_forward(
 
     if batch.is_prompts:
         output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     else:
         output_tensor = torch.zeros(
-            (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
     norm_output = torch.empty_like(hidden_states)
+    residual = None
 
     for layer_id, decoder_layer in enumerate(self.layers):
-        hidden_states = decoder_layer(
+        hidden_states, residual = decoder_layer(
             hidden_states,
+            residual=residual,
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
@@ -111,8 +112,9 @@ def llama_model_forward(
     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
+        residual = residual[last_token_indexs - 1].contiguous()
         norm_output = torch.empty_like(hidden_states)
-        hidden_states = self.norm(hidden_states, norm_output)
+        hidden_states, _ = self.norm(hidden_states, norm_output, residual)
 
     return hidden_states
 
|
@@ -120,6 +122,7 @@ def llama_model_forward(
 def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
+    residual: torch.Tensor,
     block_tables: torch.Tensor = None,
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
@@ -136,6 +139,7 @@ def llama_decoder_layer_forward(
 
     Args:
         hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+        residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
         block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
             storing mapping of token_position_id -> block_id. Defaults to None.
         k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -151,12 +155,10 @@ def llama_decoder_layer_forward(
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
 
-    residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states, norm_output)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
-        residual=residual,
         block_tables=block_tables,
         k_cache=k_cache,
         v_cache=v_cache,
@@ -170,11 +172,10 @@ def llama_decoder_layer_forward(
     )
 
     # Fully Connected
-    residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
-    hidden_states = self.mlp(hidden_states, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states = self.mlp(hidden_states)
 
-    return hidden_states
+    return hidden_states, residual
 
 
 class NopadLlamaAttention(LlamaAttention):
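With the fused norm, the decoder layer no longer materializes residual = hidden_states or launches a separate element-wise add: input_layernorm and post_attention_layernorm each take (hidden_states, norm_output, residual), return the normalized activations together with the updated residual, and that pair is threaded to the next layer. An eager-mode sketch of the semantics the fused call is expected to implement (reference only, not the Triton kernel; the helper name is made up):

from typing import Optional, Tuple

import torch

def fused_rmsnorm_residual_ref(
    x: torch.Tensor, weight: torch.Tensor, eps: float, residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reference semantics: add the residual into x (if given), RMS-normalize,
    # and return both the normalized output and the updated residual.
    if residual is not None:
        x = x + residual
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return y, x

x = torch.randn(4, 128)
w = torch.ones(128)
normed, new_residual = fused_rmsnorm_residual_ref(x, w, eps=1e-6, residual=torch.zeros_like(x))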
|
@@ -198,16 +199,18 @@ class NopadLlamaAttention(LlamaAttention):
             attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
         """
         super().__init__(config, layer_idx)
-        self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
-        self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
-        self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
-        self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
+        self.q_proj_weight = attn_qproj_w
+        self.k_proj_weight = attn_kproj_w
+        self.v_proj_weight = attn_vproj_w
+        self.o_proj_weight = attn_oproj_w
 
         if self.num_heads == self.num_key_value_heads:
-            qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
+            qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
             self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+
             self.q_proj = None
             self.k_proj = None
             self.v_proj = None
 
     @staticmethod
     def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
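Keeping the projections as plain transposed tensors (q_proj_weight, k_proj_weight, ...) rather than Parameters on nn.Linear modules lets the equal-heads path stack them once into a single [3, hidden_size, hidden_size] qkv_weight and drop the q/k/v modules. The body of the fused branch is not visible in this hunk, so the sketch below only illustrates the batched-matmul idea it presumably relies on; every name and shape here is an assumption:

import torch

# Sketch of a bmm-based fused QKV projection over a stacked weight (assumed).
token_num, hidden_size, num_heads = 8, 128, 4
head_dim = hidden_size // num_heads
hidden_states = torch.randn(token_num, hidden_size)
qkv_weight = torch.stack([torch.randn(hidden_size, hidden_size) for _ in range(3)], dim=0)

# One batched GEMM instead of three separate torch.mm calls.
qkv = torch.bmm(hidden_states.unsqueeze(0).expand(3, -1, -1), qkv_weight)
query_states = qkv[0].view(-1, num_heads, head_dim)
key_states = qkv[1].view(-1, num_heads, head_dim)
value_states = qkv[2].view(-1, num_heads, head_dim)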
|
@@ -239,7 +242,6 @@ class NopadLlamaAttention(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        residual: torch.Tensor,
         block_tables: torch.Tensor = None,
         k_cache: torch.Tensor = None,
         v_cache: torch.Tensor = None,
@@ -254,7 +256,6 @@ class NopadLlamaAttention(LlamaAttention):
         """
         Args:
             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
             block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                 storing mapping of token_position_id -> block_id. Defaults to None.
             k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -270,9 +271,9 @@ class NopadLlamaAttention(LlamaAttention):
         """
 
         if self.num_heads != self.num_key_value_heads:
-            query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
-            key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
-            value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
+            query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
+            key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
+            value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
         else:
             # fused qkv
             token_nums = hidden_states.size(0)
@@ -324,8 +325,7 @@ class NopadLlamaAttention(LlamaAttention):
             sm_scale=sm_scale,
         )
 
-        attn_output = attn_output.view(-1, self.hidden_size)
-        attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
+        attn_output = torch.mm(attn_output, self.o_proj_weight)
 
         return attn_output
 
|
@@ -348,10 +348,11 @@ class NopadLlamaMLP(LlamaMLP):
             mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
         """
         super().__init__(config)
-        self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
-        self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+        self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
+        self.down_proj_weight = mlp_dproj_w
         self.gate_proj = None
         self.up_proj = None
+        self.down_proj = None
 
     @staticmethod
     def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -375,14 +376,13 @@ class NopadLlamaMLP(LlamaMLP):
 
         return mlp_layer
 
-    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
         Args:
             hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
         """
         hidden_states = hidden_states.expand(2, -1, -1)
         gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
         act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
         tmp_out = act_out * gate_up_proj_out[1]
-        return torch.addmm(residual, tmp_out, self.down_proj.weight)
+        return torch.mm(tmp_out, self.down_proj_weight)
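The MLP keeps the earlier fused gate/up trick (expand the input to a batch of two and run one bmm against the stacked, transposed weights), but the residual no longer rides on the down projection: down_proj is a plain mm and the add now happens in the next fused norm. A small eager sketch of the forward shown above, with assumed shapes:

import torch
import torch.nn.functional as F

# Sketch of the fused gate/up projection (shapes are assumptions).
token_num, hidden_size, inter_size = 8, 128, 256
hidden_states = torch.randn(token_num, hidden_size)
gate_w = torch.randn(hidden_size, inter_size)  # transposed gate_proj weight
up_w = torch.randn(hidden_size, inter_size)    # transposed up_proj weight
down_w = torch.randn(inter_size, hidden_size)  # transposed down_proj weight

gate_up_weight = torch.stack([gate_w, up_w], dim=0)                   # [2, H, I]
gate_up = torch.bmm(hidden_states.expand(2, -1, -1), gate_up_weight)  # [2, T, I]
out = torch.mm(F.silu(gate_up[0]) * gate_up[1], down_w)               # [T, H]; residual added later in the fused norm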
@@ -29,8 +29,10 @@ except:
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
 
-        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
+        def _triton_rmsnorm_forward(
+            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
+        ):
+            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
 
         return _triton_rmsnorm_forward
     else:
@@ -205,7 +205,7 @@ def context_attention_unpadded(
     assert k_cache.shape == v_cache.shape
     assert context_lengths.shape[0] == block_tables.shape[0]
 
-    num_tokens, num_heads, _ = q.shape
+    num_tokens, num_heads, head_dim = q.shape
     num_kv_heads = k.shape[-2]
     assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
     num_kv_group = num_heads // num_kv_heads
@@ -213,7 +213,9 @@ def context_attention_unpadded(
     num_seqs, max_blocks_per_seq = block_tables.shape
     max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
     sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
-    output = torch.zeros_like(q) if output is None else output
+    output = (
+        torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    )
 
     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)
@@ -243,8 +245,8 @@ def context_attention_unpadded(
         v.stride(1),
         v.stride(2),
         output.stride(0),
-        output.stride(1),
-        output.stride(2),
+        head_dim,
+        1,
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
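Passing head_dim and 1 in place of output.stride(1) / output.stride(2) works because a contiguous [num_tokens, num_heads * head_dim] buffer has exactly the strides its logical [num_tokens, num_heads, head_dim] view would have, so the kernel keeps addressing output[token, head, dim] while the caller never materializes the 3D view. A quick check of that equivalence with illustrative sizes:

import torch

num_tokens, num_heads, head_dim = 16, 8, 64
out2d = torch.empty(num_tokens, num_heads * head_dim)
out3d = out2d.view(num_tokens, num_heads, head_dim)
# Same storage, and the 3D strides are (num_heads * head_dim, head_dim, 1),
# which is why the launch can hard-code head_dim and 1 without any copy.
assert out2d.data_ptr() == out3d.data_ptr()
assert out3d.stride() == (num_heads * head_dim, head_dim, 1)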
@@ -211,7 +211,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
-        output (torch.Tensor): [bsz, num_heads, head_dim]
+        output (torch.Tensor): [bsz, num_heads * head_dim]
         mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -220,7 +220,7 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
 
     Returns:
-        Output tensor with shape [bsz, num_heads, head_dim]
+        Output tensor with shape [bsz, num_heads * head_dim]
     """
     q = q.squeeze() if q.dim() == 4 else q
     assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
@@ -261,7 +261,7 @@ def flash_decoding_attention(
    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
-    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
 
     _flash_decoding_fwd_kernel[grid](
         q,
@@ -311,8 +311,8 @@ def flash_decoding_attention(
         mid_output_lse.stride(1),
         mid_output_lse.stride(2),
         output.stride(0),
-        output.stride(1),
-        output.stride(2),
+        head_dim,
+        1,
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
@@ -49,7 +49,50 @@ if HAS_TRITON:
             # Write output
             tl.store(Y + cols, y.to(tl.float16), mask=mask)
 
-    def rms_layernorm(x, weight, eps, norm_output=None):
+    @triton.jit
+    def _rmsnorm_with_residual_kernel(
+        X,  # pointer to the input
+        Y,  # pointer to the output
+        R,  # pointer to the residual
+        W,  # pointer to the weights
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        eps,  # epsilon to avoid division by zero
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
+
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        Y += row * stride
+        X += row * stride
+        R += row * stride
+        # Compute variance
+        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+            x = tl.where(cols < N, x, 0.0)
+            r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32)
+            r = tl.where(cols < N, r, 0.0)
+            x = x + r
+            _var += x * x
+            mask = cols < N
+            tl.store(X + cols, x.to(tl.float16), mask=mask)
+        var = tl.sum(_var, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        # Normalize and apply linear transformation
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            w = tl.load(W + cols, mask=mask)
+            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+            x_hat = x * rstd
+            y = x_hat * w
+            # Write output
+            tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+    def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
         y = torch.empty_like(x) if norm_output is None else norm_output
         M, N = x.shape
@@ -64,5 +107,10 @@ if HAS_TRITON:
         num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
 
         # enqueue kernel
-        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        return y
+        if residual is None:
+            _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        else:
+            _rmsnorm_with_residual_kernel[(M,)](
+                x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+            )
+        return y, x
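After this change rms_layernorm always returns a pair: the normalized output and the (possibly updated in place) input, which doubles as the residual for the next layer. A usage sketch matching the wrapper above; it assumes CUDA and Triton are available and uses fp16 because the kernels store fp16:

import torch

from colossalai.kernel.triton import rms_layernorm

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
norm_output = torch.empty_like(x)

# Plain RMSNorm: the second return value is just x, kept for a uniform interface.
y, _ = rms_layernorm(x, weight, eps=1e-5, norm_output=norm_output)

# Fused residual + RMSNorm: x is updated in place to x + residual and returned,
# so it can serve as the residual input of the next layer.
y, new_residual = rms_layernorm(x, weight, eps=1e-5, norm_output=norm_output, residual=residual)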
@@ -95,7 +95,7 @@ def benchmark_inference(args):
     else:
         assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
         model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
-        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
 
     model = model.eval()
 
@@ -122,6 +122,7 @@ def benchmark_inference(args):
     elif args.mode == "vllm":
         engine = LLM(
             model=args.model_path,
+            tokenizer="hf-internal-testing/llama-tokenizer",
             max_num_seqs=mbsz,
             dtype="float16",
             enforce_eager=True,
@@ -100,10 +100,14 @@ def test_context_attention(
     k_cache_triton = torch.zeros_like(k_cache_ref)
     v_cache_triton = torch.zeros_like(v_cache_ref)
 
+    _, num_heads, head_dim = q_unpad.shape
+
     out_triton = context_attention_unpadded(
         q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
     )
 
+    out_triton = out_triton.view(-1, num_heads, head_dim)
+
     out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
 
     assert out_torch.shape == out_triton.shape
@@ -3,6 +3,7 @@ import torch
 import triton
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
@@ -29,15 +30,28 @@ def test_layer_norm(M, N):
     x_shape = (M, N)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
+    residual_copy = residual.clone()
     rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    x_copy = x.clone()
 
-    y_triton = rms_layernorm(x, weight, eps=eps)
+    y_triton, _ = rms_layernorm(x, weight, eps=eps)
     y_llama = rms_norm.forward(x).to(dtype)
 
     assert y_triton.shape == y_llama.shape
     assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
 
+    y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
+
+    x = x_copy + residual_copy
+
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_triton.shape == y_llama.shape
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
+
 
 # Triton benchmark plot attributions
 configs = [
@@ -45,9 +59,19 @@ configs = [
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        line_names=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
         args={"HIDDEN_SIZE": 1024},
@@ -68,13 +92,18 @@ def benchmark_rms_layernorm(
     eps = 1e-5
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    if provider == "torch_rms_layernorm":
-        fn = lambda: torch_norm(x)
+    if provider == "vllm_rms_layernorm":
+        fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
+    elif provider == "vllm_rms_layernorm_with_residual":
+        fn = lambda: vllm_norm(x, residual=residual)
+    elif provider == "triton_rms_layernorm_with_residual":
+        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
     else:
         raise ValueError("Undefined provider.")
 