mirror of https://github.com/hpcaitech/ColossalAI
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
|
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||
|
|
||
|
import time
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
from packaging import version
|
||
|
|
||
|
# Optional-dependency probe: HAS_TRITON records whether Triton and the
# ColossalAI rotary embedding kernel could be imported. The flag is consumed by
# the skipif marker on the test in this file.
HAS_TRITON = False
try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
except ImportError:
    # Keep the module importable so the test can be skipped rather than error.
    print("please install triton from https://github.com/openai/triton")
else:
    HAS_TRITON = True
|
||
|
|
||
|
# True only when PyTorch reports a CUDA toolkit version newer than 11.4.
# `torch.version.cuda` is None when PyTorch was built without CUDA, and
# `version.parse(None)` raises TypeError; that would abort collection of this
# module instead of letting the skipif marker skip the test. A missing CUDA
# version is therefore treated as "not supported".
TRITON_CUDA_SUPPORT = (torch.version.cuda is not None
                       and version.parse(torch.version.cuda) > version.parse('11.4'))
|
||
|
|
||
|
|
||
|
def torch_rotary_emb(x, cos, sin):
    """Apply a rotary embedding with plain PyTorch ops (reference result).

    ``x`` must be 3-D; its shape is unpacked as ``(seq_len, h, dim)``.
    ``cos`` and ``sin`` are viewed as ``(seq_len, 1, dim // 2)`` so that they
    broadcast across the middle axis of ``x``.

    The last axis of ``x`` is split into a lower and an upper half, the two
    halves are rotated against each other, and the results are concatenated
    back along the last axis. A new tensor is returned; ``x`` is not written.
    """
    seq_len, _, dim = x.shape
    half = dim // 2

    lower = x[..., :half]
    upper = x[..., half:dim]

    # Insert a singleton axis so one (cos, sin) row applies to every head.
    cos_b = cos.view((seq_len, 1, half))
    sin_b = sin.view((seq_len, 1, half))

    rotated_lower = lower * cos_b - upper * sin_b
    rotated_upper = lower * sin_b + upper * cos_b
    return torch.cat((rotated_lower, rotated_upper), dim=-1)
|
||
|
|
||
|
|
||
|
@pytest.mark.skipif(not (TRITON_CUDA_SUPPORT and HAS_TRITON),
                    reason="triton requires cuda version to be higher than 11.4")
def test_rotary_emb():
    """Compare ``rotary_embedding_fwd`` with the PyTorch reference on random data."""
    seq_len, head_num, head_dim = 1, 32, 128
    dtype = torch.half

    def shifted_randn(shape, offset):
        # Half-precision CUDA noise with standard deviation 0.5 around ``offset``.
        return offset + 0.5 * torch.randn(shape, dtype=dtype, device='cuda')

    # create data (x, then cos, then sin, so the RNG is consumed in that order)
    half_dim_shape = (seq_len, head_dim // 2)
    x = shifted_randn((seq_len, head_num, head_dim), -2.3)
    cos = shifted_randn(half_dim_shape, -1.2)
    sin = shifted_randn(half_dim_shape, -2.0)

    # forward pass
    # The kernel's result is read back from ``x`` itself, so the reference has
    # to be computed first. NOTE(review): presumably rotary_embedding_fwd
    # rotates ``x`` in place -- confirm against the kernel.
    expected = torch_rotary_emb(x, cos, sin)
    rotary_embedding_fwd(x, cos, sin)

    # compare with an absolute tolerance only
    assert torch.allclose(expected, x, atol=1e-2, rtol=0)
|
||
|
|
||
|
|
||
|
# Allow running this file directly as a script. The skipif marker is only
# honored by the pytest runner, so this call executes the test unconditionally.
if __name__ == "__main__":
    test_rotary_emb()
|