mirror of https://github.com/hpcaitech/ColossalAI
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
|
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||
|
import torch
|
||
|
import triton
|
||
|
import triton.language as tl
|
||
|
|
||
|
|
||
|
@triton.jit
def _rotary_kernel(
    q,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    """Rotate one (BLOCK_SEQ tokens x BLOCK_HEAD heads) tile of ``q`` in place.

    Program axis 0 picks the block of heads, axis 1 the block of tokens. For
    each selected token/head, channel ``i`` of the first half of ``HEAD_DIM``
    is paired with channel ``i + HEAD_DIM // 2`` and the pair is rotated with
    the per-token values read from ``Cos`` and ``Sin`` (both addressed with the
    ``cos_*`` strides, using only the first ``HEAD_DIM // 2`` columns):

        lo_new = lo * cos - hi * sin
        hi_new = lo * sin + hi * cos

    Tokens at or beyond ``total_len`` and heads at or beyond ``HEAD_NUM`` are
    masked out of every load and store.
    """
    head_block = tl.program_id(0)
    token_block = tl.program_id(1)

    heads = head_block * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    tokens = token_block * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
    lo_dims = tl.arange(0, HEAD_DIM // 2)
    hi_dims = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    # Lift the index vectors onto a (token, head, channel) tile.
    token_col = tokens[:, None, None]
    head_col = heads[None, :, None]

    # Cos/Sin only depend on the token; q additionally needs the head in range.
    token_in_range = token_col < total_len
    tile_mask = token_in_range & (head_col < HEAD_NUM)

    row_offset = token_col * q_bs_stride + head_col * q_h_stride
    lo_offset = row_offset + lo_dims[None, None, :] * q_d_stride
    hi_offset = row_offset + hi_dims[None, None, :] * q_d_stride
    trig_offset = token_col * cos_bs_stride + lo_dims[None, None, :] * cos_d_stride

    lo = tl.load(q + lo_offset, mask=tile_mask, other=0.0)
    hi = tl.load(q + hi_offset, mask=tile_mask, other=0.0)
    cos_t = tl.load(Cos + trig_offset, mask=token_in_range, other=0.0)
    sin_t = tl.load(Sin + trig_offset, mask=token_in_range, other=0.0)

    # Both halves are already held in lo/hi, so writing the first half back
    # before the second does not affect the second result.
    tl.store(q + lo_offset, lo * cos_t - hi * sin_t, mask=tile_mask)
    tl.store(q + hi_offset, lo * sin_t + hi * cos_t, mask=tile_mask)
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def rotary_embedding_fwd(q, cos, sin):
    """Apply rotary position embedding to ``q`` in place using a Triton kernel.

    The last dimension of ``q`` is split into two halves of size
    ``h = head_dim // 2``; channel ``i`` of the first half is paired with
    channel ``i + h`` and each pair is rotated by the per-token cos/sin values::

        q[..., :h] <- q[..., :h] * cos - q[..., h:] * sin
        q[..., h:] <- q[..., :h] * sin + q[..., h:] * cos

    Args:
        q: Tensor indexed as ``(token, head, dim)``; overwritten with the result.
        cos: Tensor indexed as ``(token, dim)`` with at least ``head_dim // 2``
            columns; only the first ``head_dim // 2`` columns are read.
        sin: Tensor with the same shape as ``cos``.

    Returns:
        None. The rotated values are written back into ``q``.

    Raises:
        AssertionError: if the token counts of ``q``, ``cos`` and ``sin``
            differ, ``cos`` and ``sin`` differ in shape, ``head_dim`` is odd,
            or ``cos`` has fewer than ``head_dim // 2`` columns.

    Note:
        The kernel builds its channel ranges with ``tl.arange``, which
        generally requires power-of-two lengths, so ``head_dim`` presumably
        has to be a power of two as well -- confirm against the Triton version
        in use.
    """
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], (
        f"q shape {q.shape} cos shape {cos.shape} sin shape {sin.shape}"
    )
    assert cos.shape == sin.shape, f"cos shape {cos.shape} must equal sin shape {sin.shape}"
    assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
    # The kernel reads columns [0, head_dim // 2) of every cos/sin row; fewer
    # columns would make it read past the end of each row.
    assert cos.shape[1] >= head_dim // 2, (
        f"cos needs at least head_dim // 2 = {head_dim // 2} columns, got shape {cos.shape}"
    )

    # The kernel addresses Sin with cos's strides, so both tensors must share a
    # memory layout. When they do not, fall back to contiguous copies (equal
    # shapes then give equal strides on every dimension that is actually indexed).
    if cos.stride() != sin.stride():
        cos = cos.contiguous()
        sin = sin.contiguous()

    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    # Launch heuristic: more warps for wider heads.
    num_warps = 8 if head_dim >= 128 else 4

    _rotary_kernel[grid](
        q,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
|