mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.5 KiB
67 lines
2.5 KiB
import pytest
|
|
import torch
|
|
from packaging import version
|
|
|
|
from colossalai.kernel.triton import get_xine_cache
|
|
|
|
# Triton is an optional dependency: record whether it is importable so the
# test below can be skipped (rather than error out) when it is missing.
try:
    import triton  # noqa

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

# torch.version.cuda is None on PyTorch builds without CUDA; guard against it so
# that importing this module does not raise inside version.parse() and the
# skipif marker can do its job.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")
|
|
|
|
|
|
@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
    """
    Get cos and sin for the cache, and return nopad format.

    Args:
        lengths: shape(num_seqs,), stores length of each sequence.
        cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constructed in model.
        sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constructed in model.
        is_prompts: bool, mark if in prefill mode.
        dtype: The data type of this inference process.

    Returns:
        Tuple (cos_output, sin_output) of rows gathered from the caches and cast to ``dtype``.
        In prefill mode the rows for positions ``0..length-1`` of every sequence are
        concatenated, giving ``sum(lengths)`` rows; in decoding mode only the row at
        position ``length - 1`` of each sequence is taken, giving ``num_seqs`` rows.
        An empty ``lengths`` yields outputs with zero rows.
    """
    device = cos_cache.device
    # Normalise once to a flat int64 tensor on the cache's device, so index
    # tensors are valid for indexing and no per-element host sync is needed.
    seq_lens = torch.as_tensor(lengths, dtype=torch.long, device=device).reshape(-1)

    if is_prompts:
        per_seq_positions = [torch.arange(n, device=device) for n in seq_lens.tolist()]
        # torch.cat rejects an empty list, so fall back to an empty index tensor.
        indices = torch.cat(per_seq_positions) if per_seq_positions else seq_lens.new_empty(0)
    else:
        # Decoding only needs the position of the latest token of each sequence.
        indices = seq_lens - 1

    cos_output = cos_cache[indices].to(dtype=dtype)
    sin_output = sin_cache[indices].to(dtype=dtype)

    return (cos_output, sin_output)
|
|
|
|
|
|
@pytest.mark.skipif(
    not (TRITON_CUDA_SUPPORT and HAS_TRITON), reason="triton requires cuda version to be higher than 11.4"
)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
    """Check get_xine_cache against the get_cos_sin reference in prefill and decoding modes."""
    cache_shape = (BATCH_SIZE * MAX_SEQ_LEN, HEAD_DIM)
    cos_cache = torch.randn(cache_shape, dtype=dtype, device="cuda")
    sin_cache = torch.randn(cache_shape, dtype=dtype, device="cuda")
    lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")

    # First pass exercises prefill (is_prompts=True), second pass decoding.
    for prefill in (True, False):
        expected_cos, expected_sin = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=prefill, dtype=dtype)
        actual_cos, actual_sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=prefill)
        assert torch.allclose(actual_cos, expected_cos)
        assert torch.allclose(actual_sin, expected_sin)
|
|
|
|
|
|
# Allow running this file directly, without pytest, as a quick manual check.
if __name__ == "__main__":
    test_get_xine_cache(BATCH_SIZE=4, MAX_SEQ_LEN=64, HEAD_DIM=256, dtype=torch.float32)
|