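# Triton kernels that gather rotary-embedding cos/sin cache rows into a
# packed (non-padded) layout: one output row per token in prefill mode, and
# one row per sequence (its last position) in decoding mode.
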
import torch

import triton
import triton.language as tl


@triton.jit
def prefill_cache_kernel(
    cos_cache,
    sin_cache,
    cumsum_lengths,
    cos_output,
    sin_output,
    cache_stride,
    hidden_stride,
    total_length,
    HIDDEN_DIM: tl.constexpr,
    N_ELEMENTS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance copies the cache row for one flattened token; the
    # grid is (cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE).
    idx0 = tl.program_id(axis=0)
    idx1 = tl.program_id(axis=1)
    idx = idx0 * BLOCK_SIZE + idx1

    # Recover the token's position within its own sequence: subtract the
    # largest cumulative sequence length that is <= the flattened index.
    cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
    ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
    cos_cache_part = tl.load(
        cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
    )
    sin_cache_part = tl.load(
        sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length
    )
    tl.store(
        cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
        cos_cache_part,
        mask=idx < total_length,
    )
    tl.store(
        sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,
        sin_cache_part,
        mask=idx < total_length,
    )


@triton.jit
def decoding_cache_kernel(
    cos_cache,
    sin_cache,
    lengths,
    cos_output,
    sin_output,
    cache_stride,
    hidden_stride,
    HIDDEN_DIM: tl.constexpr,
    NUM_SEQS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles BLOCK_SIZE sequences; for each one, copy
    # the cache row of its last token (position length - 1, computed on host).
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None)  # [BLOCK_SIZE,]
    cos_cache_part = tl.load(
        cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
        mask=idx[:, None] < NUM_SEQS,
    )
    sin_cache_part = tl.load(
        sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,
        mask=idx[:, None] < NUM_SEQS,
    )
    tl.store(
        cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
        cos_cache_part,
        mask=idx[:, None] < NUM_SEQS,
    )
    tl.store(
        sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
        sin_cache_part,
        mask=idx[:, None] < NUM_SEQS,
    )


def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False):
    """
    Transform the cos/sin cache into non-padded sequence form, in one of two modes.

    Args:
        lengths: shape (num_seqs,), stores the length of each sequence.
        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
        is_prompts: bool, marks whether we are in prefill mode.

    For prefill mode:
        the number of cache rows gathered for each sequence equals its length.
    For decoding mode:
        the cos/sin cache is only needed for the last token of each sequence.
    """
    assert cos_cache.shape[1] == sin_cache.shape[1]
    _, hidden_dim = cos_cache.shape
    num_seqs = lengths.numel()

    # Scale the warp count with the hidden dimension so each row copy has
    # enough threads behind it.
    if hidden_dim >= 256:
        num_warps = 16
    elif hidden_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    cache_stride = cos_cache.stride(0)
    hidden_stride = cos_cache.stride(1)

    if is_prompts:
        BLOCK_SIZE = 16
        total_length = lengths.sum().item()
        cumsum_lens = torch.cumsum(lengths, dim=0)
        n_elements = triton.next_power_of_2(num_seqs)
        if n_elements > num_seqs:
            # The kernel loads N_ELEMENTS entries without a mask, so zero-pad
            # the cumulative lengths up to that power-of-two size to keep the
            # load in bounds; the padded zeros are neutral in the where/max.
            cumsum_lens = torch.nn.functional.pad(cumsum_lens, (0, n_elements - num_seqs))
        cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
        sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
        grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
        prefill_cache_kernel[grid](
            cos_cache,
            sin_cache,
            cumsum_lens,
            cos_output,
            sin_output,
            cache_stride,
            hidden_stride,
            total_length,
            HIDDEN_DIM=hidden_dim,
            N_ELEMENTS=n_elements,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        BLOCK_SIZE = 4
        # Positions are zero-indexed, so the last token of a sequence of
        # length L reads cache row L - 1.
        nlengths = lengths - 1
        cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)
        sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)
        grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
        decoding_cache_kernel[grid](
            cos_cache,
            sin_cache,
            nlengths,
            cos_output,
            sin_output,
            cache_stride,
            hidden_stride,
            HIDDEN_DIM=hidden_dim,
            NUM_SEQS=num_seqs,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return cos_output, sin_output
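

# A minimal usage sketch, not part of the original module. It assumes a CUDA
# device and builds an illustrative RoPE-style cos/sin cache; `max_position`,
# `head_dim`, and `base` below are hypothetical values, not taken from the
# model code.
if __name__ == "__main__":
    max_position, head_dim, base = 2048, 128, 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(torch.arange(max_position, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # (max_position, head_dim)
    cos_cache = emb.cos().cuda()
    sin_cache = emb.sin().cuda()

    lengths = torch.tensor([5, 3, 7], device="cuda")

    # Prefill: one cache row per token, packed back to back -> (5 + 3 + 7, head_dim).
    cos_out, sin_out = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
    assert cos_out.shape == sin_out.shape == (15, head_dim)

    # Decoding: one cache row per sequence, at position length - 1 -> (3, head_dim).
    cos_out, sin_out = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
    assert cos_out.shape == sin_out.shape == (3, head_dim)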