mirror of https://github.com/hpcaitech/ColossalAI
144 lines
4.2 KiB
Python
144 lines
4.2 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from packaging import version
|
|
|
|
try:
    import triton

    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton

    HAS_TRITON = True
except ImportError:
    # Triton (or the ColossalAI Triton kernels) are optional; the tests below are skipped without them.
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

# Both tests allocate tensors on "cuda" and need CUDA > 11.4.
# - torch.version.cuda is None on CPU-only / ROCm builds, and version.parse(None) would raise
#   TypeError while pytest is collecting this module.
# - A CUDA build without a visible GPU cannot run the tests either.
# Check both cheaply first so that version.parse is only reached when a CUDA version string exists.
TRITON_CUDA_SUPPORT = (
    torch.version.cuda is not None
    and torch.cuda.is_available()
    and version.parse(torch.version.cuda) > version.parse("11.4")
)
|
|
|
|
|
|
@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
    """Check ``qkv_gemm_4d_kernel`` against a PyTorch reference for ``scale * Q @ K^T``.

    A packed fp16 ``qkv`` tensor of shape (batch, seq, 3 * d_model) is split into its Q and K
    parts. Each part is viewed as (batch, seq, heads, head_size). The per-head scores are then
    computed twice, once with ``torch.einsum`` and once with the Triton kernel, and the two
    results are compared.
    """
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    # Keep the (batch, seq, heads, head_size) tensors for the Triton launch below.
    q_copy = q.clone()
    k_copy = k.clone()

    # PyTorch reference: q -> (batch, heads, seq, head_size) and k -> (batch, heads, head_size, seq),
    # so the einsum yields (batch, heads, seq, seq) scores.
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()
    torch_output = torch.einsum("bnij,bnjk->bnik", q, k)
    # Scale by the same value that is handed to the kernel. This was previously a separately
    # hard-coded 1.2, which would silently diverge from the kernel if ``scale`` changed.
    torch_output *= scale

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

    def grid(meta):
        # Launch grid: (batches, H, number of BLOCK_SIZE_M x BLOCK_SIZE_N tiles covering M x N).
        return (
            batches,
            H,
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    qkv_gemm_4d_kernel[grid](
        q,
        k,
        score_output,
        M,
        N,
        K,
        # q strides are passed in (batch, head, seq, head_size) order.
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),
        # k strides are passed in (batch, head, head_size, seq) order. This presumably lets the kernel
        # read K^T straight from the untransposed k. NOTE(review): confirm against the kernel signature.
        k.stride(0),
        k.stride(2),
        k.stride(3),
        k.stride(1),
        score_output.stride(0),
        score_output.stride(1),
        score_output.stride(2),
        score_output.stride(3),
        scale=scale,
        # currently manually setting, later on we can use auto-tune config to match best setting
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "the outputs of triton and torch are not matched"
|
|
|
|
|
|
def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
    """Pure-PyTorch reference for multi-head self-attention over a packed QKV tensor.

    Args:
        qkv: tensor indexed as (batch, seq, 3 * d_model). Q, K and V sit side by side along the
            last dimension, each ``d_model`` wide.
        input_mask: accepted only so the signature matches the Triton path. It is never read here.
        scale: factor applied to the raw Q @ K^T scores before the softmax.
        head_size: width of one attention head. ``d_model // head_size`` heads are formed.

    Returns:
        ``(output, scores, probs)``:
        - ``output`` has shape (batch, seq, d_model).
        - ``scores`` is ``scale * Q @ K^T``.
        - ``probs`` is the softmax of ``scores`` over the last dimension.
        ``scores`` and ``probs`` are both (batch, heads, seq, seq).
    """
    n_batch = qkv.shape[0]
    width = qkv.shape[-1] // 3
    n_heads = width // head_size

    def _split_heads(part):
        # (batch, seq, width) -> (batch, heads, seq, head_size)
        return part.view(n_batch, -1, n_heads, head_size).transpose(1, 2).contiguous()

    query = _split_heads(qkv[:, :, :width])
    key = _split_heads(qkv[:, :, width : 2 * width])
    value = _split_heads(qkv[:, :, 2 * width :])

    # (batch, heads, head_size, seq), so the product below has shape (batch, heads, seq, seq).
    key_t = key.transpose(-1, -2).contiguous()

    scores = torch.einsum("bnij,bnjk->bnik", query, key_t)
    scores *= scale
    probs = F.softmax(scores, dim=-1)

    context = torch.einsum("bnij,bnjk->bnik", probs, value)
    # Back to (batch, seq, heads, head_size), then merge the heads into d_model.
    merged = context.transpose(1, 2).contiguous()
    return merged.view(n_batch, -1, width), scores, probs
|
|
|
|
|
|
@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
    """Compare the Triton non-fused self-attention output with the PyTorch reference.

    NOTE(review): the reference path applies no causal mask (it has no triangular option and
    never reads input_mask), yet the Triton call passes ``triangular=True``. Confirm whether
    ``triangular`` is honored when ``use_flash=False``.
    """
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    head_size, scale = 32, 1.2

    # Each implementation gets its own copy of the input.
    expected, _scores, _probs = self_attention_compute_using_torch(
        qkv.clone(), input_mask=None, scale=scale, head_size=head_size
    )
    actual = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=head_size,
        scale=scale,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True,
    )

    matched = torch.allclose(actual.cpu(), expected.cpu(), rtol=1e-4, atol=1e-2)
    assert matched is True, "the triton output is not matched with torch output"
|
|
|
|
|
|
if __name__ == "__main__":
    # Run both checks directly, without pytest. The skipif markers only take effect under pytest,
    # so they are not applied here.
    for _check in (test_qkv_matmul, test_self_atttention_test):
        _check()
|