mirror of https://github.com/hpcaitech/ColossalAI
fix (#5311)
parent
4f28cb43c0
commit
7ddd8b37f0
|
@ -136,7 +136,7 @@ def fused_rotary_embedding(
|
|||
q_total_tokens, q_head_num, head_dim = q.shape
|
||||
assert q.size(0) == k.size(0)
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_SIZE = 16
|
||||
BLOCK_SIZE = 8
|
||||
cumsum_lens = torch.cumsum(lengths, dim=0)
|
||||
|
||||
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)
|
||||
|
|
|
@ -2,6 +2,22 @@ import torch
|
|||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
"""
|
||||
# Base autotune if needed
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=4),
|
||||
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":8,},num_warps=8),
|
||||
triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":8,},num_warps=8),
|
||||
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=16),
|
||||
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=32),
|
||||
triton.Config({'BLOCK_HEAD':16,"BLOCK_TOKENS":16,},num_warps=4),
|
||||
triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":16,},num_warps=8),
|
||||
],
|
||||
key=['HEAD_DIM','q_total_tokens','Q_HEAD_NUM']
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rotary_embedding_kernel(
|
||||
|
@ -26,43 +42,53 @@ def rotary_embedding_kernel(
|
|||
block_head_index = tl.program_id(0)
|
||||
block_token_index = tl.program_id(1)
|
||||
|
||||
rotary_data = q
|
||||
HEAD_NUM = Q_HEAD_NUM
|
||||
head_stride = q_head_stride
|
||||
token_stride = q_token_stride
|
||||
|
||||
if block_token_index * BLOCK_TOKENS >= q_total_tokens:
|
||||
block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS)
|
||||
rotary_data = k
|
||||
HEAD_NUM = K_HEAD_NUM
|
||||
head_stride = k_head_stride
|
||||
token_stride = k_token_stride
|
||||
|
||||
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
|
||||
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||
|
||||
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||
|
||||
off_data0 = (
|
||||
tokens_range[:, None, None] * token_stride
|
||||
+ head_range[None, :, None] * head_stride
|
||||
off_q0 = (
|
||||
tokens_range[:, None, None] * q_token_stride
|
||||
+ head_range[None, :, None] * q_head_stride
|
||||
+ dim_range0[None, None, :] * head_dim_stride
|
||||
)
|
||||
off_data1 = (
|
||||
tokens_range[:, None, None] * token_stride
|
||||
+ head_range[None, :, None] * head_stride
|
||||
off_q1 = (
|
||||
tokens_range[:, None, None] * q_token_stride
|
||||
+ head_range[None, :, None] * q_head_stride
|
||||
+ dim_range1[None, None, :] * head_dim_stride
|
||||
)
|
||||
off_k0 = (
|
||||
tokens_range[:, None, None] * k_token_stride
|
||||
+ head_range[None, :, None] * k_head_stride
|
||||
+ dim_range0[None, None, :] * head_dim_stride
|
||||
)
|
||||
off_k1 = (
|
||||
tokens_range[:, None, None] * k_token_stride
|
||||
+ head_range[None, :, None] * k_head_stride
|
||||
+ dim_range1[None, None, :] * head_dim_stride
|
||||
)
|
||||
|
||||
loaded_data0 = tl.load(
|
||||
rotary_data + off_data0,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
loaded_q0 = tl.load(
|
||||
q + off_q0,
|
||||
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
loaded_data1 = tl.load(
|
||||
rotary_data + off_data1,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
loaded_q1 = tl.load(
|
||||
q + off_q1,
|
||||
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
loaded_k0 = tl.load(
|
||||
k + off_k0,
|
||||
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
loaded_k1 = tl.load(
|
||||
k + off_k1,
|
||||
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
|
@ -71,19 +97,32 @@ def rotary_embedding_kernel(
|
|||
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||
|
||||
out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :]
|
||||
out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :]
|
||||
out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
|
||||
out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
|
||||
|
||||
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
|
||||
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
|
||||
|
||||
# concat
|
||||
tl.store(
|
||||
rotary_data + off_data0,
|
||||
out0,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
q + off_q0,
|
||||
out_q0,
|
||||
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
tl.store(
|
||||
rotary_data + off_data1,
|
||||
out1,
|
||||
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
q + off_q1,
|
||||
out_q1,
|
||||
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
tl.store(
|
||||
k + off_k0,
|
||||
out_k0,
|
||||
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
tl.store(
|
||||
k + off_k1,
|
||||
out_k1,
|
||||
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||
)
|
||||
|
||||
|
||||
|
@ -105,11 +144,13 @@ def rotary_embedding(
|
|||
q_total_tokens, q_head_num, head_dim = q.shape
|
||||
assert q.size(0) == k.size(0)
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_TOKENS = 8
|
||||
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
|
||||
BLOCK_TOKENS = 4
|
||||
grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
|
||||
|
||||
if head_dim >= 128:
|
||||
num_warps = 8
|
||||
if head_dim >= 256:
|
||||
num_warps = 32
|
||||
elif head_dim >= 128:
|
||||
num_warps = 16
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
|
@ -144,7 +185,6 @@ def rotary_embedding(
|
|||
BLOCK_HEAD=BLOCK_HEAD,
|
||||
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
return
|
||||
|
|
|
@ -5,9 +5,11 @@ import triton.language as tl
|
|||
|
||||
@triton.jit
|
||||
def prefill_cache_kernel(
|
||||
CaChe,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
cumsum_lengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
total_length,
|
||||
|
@ -22,15 +24,31 @@ def prefill_cache_kernel(
|
|||
# original seq_idx and pos
|
||||
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
|
||||
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
|
||||
_cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
|
||||
tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)
|
||||
cos_cache_part = tl.load(
|
||||
cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
|
||||
)
|
||||
sin_cache_part = tl.load(
|
||||
sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
|
||||
)
|
||||
tl.store(
|
||||
cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
|
||||
cos_cache_part,
|
||||
mask=idx < total_length,
|
||||
)
|
||||
tl.store(
|
||||
sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
|
||||
sin_cache_part,
|
||||
mask=idx < total_length,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def decoding_cache_kernel(
|
||||
CaChe,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
lengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
HIDDEN_DIM: tl.constexpr,
|
||||
|
@ -39,16 +57,28 @@ def decoding_cache_kernel(
|
|||
):
|
||||
idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,]
|
||||
_cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
|
||||
cos_cache_part = tl.load(
|
||||
cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
sin_cache_part = tl.load(
|
||||
sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
tl.store(
|
||||
output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
_cache,
|
||||
cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
cos_cache_part,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
tl.store(
|
||||
sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
|
||||
sin_cache_part,
|
||||
mask=idx[:, None] < NUM_SEQS,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
|
||||
def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False):
|
||||
"""
|
||||
Transform cos/sin cache into no pad sequence, with two different modes.
|
||||
Args:
|
||||
|
@ -60,28 +90,33 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
|||
For decoding mode:
|
||||
cos/sin cache is only needed for the last token.
|
||||
"""
|
||||
|
||||
_, hidden_dim = cache.shape
|
||||
assert cos_cache.shape[1] == sin_cache.shape[1]
|
||||
_, hidden_dim = cos_cache.shape
|
||||
num_seqs = lengths.numel()
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
if hidden_dim >= 128:
|
||||
if hidden_dim >= 256:
|
||||
num_warps = 16
|
||||
elif hidden_dim >= 128:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
cache_stride = cache.stride(0)
|
||||
hidden_stride = cache.stride(1)
|
||||
cache_stride = cos_cache.stride(0)
|
||||
hidden_stride = cos_cache.stride(1)
|
||||
|
||||
if is_prompts:
|
||||
BLOCK_SIZE = 16
|
||||
total_length = lengths.sum().item()
|
||||
cumsum_lens = torch.cumsum(lengths, dim=0)
|
||||
output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
|
||||
cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
|
||||
sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
|
||||
grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
|
||||
prefill_cache_kernel[grid](
|
||||
cache,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
cumsum_lens,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
total_length,
|
||||
|
@ -91,14 +126,17 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
|||
num_warps=num_warps,
|
||||
)
|
||||
else:
|
||||
# BUG: get memory access error whe using a deepcopy lengths to replace lengths
|
||||
BLOCK_SIZE = 4
|
||||
nlengths = torch.as_tensor(lengths) - 1
|
||||
output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
|
||||
cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
|
||||
sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
|
||||
grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
|
||||
decoding_cache_kernel[grid](
|
||||
cache,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
nlengths,
|
||||
output,
|
||||
cos_output,
|
||||
sin_output,
|
||||
cache_stride,
|
||||
hidden_stride,
|
||||
HIDDEN_DIM=hidden_dim,
|
||||
|
@ -107,4 +145,4 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool
|
|||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
return output
|
||||
return cos_output, sin_output
|
||||
|
|
|
@ -39,8 +39,8 @@ configs = [
|
|||
x_names=["max_num_tokens"],
|
||||
x_vals=[2**i for i in range(6, 12)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"],
|
||||
line_names=["torch_get_cos_sin_func", "triton_get_xine_func"],
|
||||
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="Get_cos-sin_func",
|
||||
|
@ -58,19 +58,15 @@ def benchmark_get_xine_cache(
|
|||
):
|
||||
warmup = 10
|
||||
rep = 1000
|
||||
max_token_per_seq = max_num_tokens // batch_size
|
||||
dtype = torch.float16
|
||||
cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
|
||||
sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
|
||||
lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda")
|
||||
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
|
||||
|
||||
if provider == "torch_get_cos_sin_func":
|
||||
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
|
||||
elif provider == "triton_get_xine_func":
|
||||
fn = lambda: [
|
||||
get_xine_cache(lengths, cos_cache, is_prompts=False),
|
||||
get_xine_cache(lengths, sin_cache, is_prompts=False),
|
||||
]
|
||||
if provider == "torch_get_cos_sin":
|
||||
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
||||
elif provider == "triton_get_cos_sin":
|
||||
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
||||
|
|
Loading…
Reference in New Issue