import pytest import torch import torch.nn.functional as F from packaging import version try: import triton from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") @pytest.mark.skipif( not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" ) def test_qkv_matmul(): qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size q = qkv[:, :, :d_model] k = qkv[:, :, d_model : d_model * 2] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) q_copy = q.clone() k_copy = k.clone() q = torch.transpose(q, 1, 2).contiguous() k = torch.transpose(k, 1, 2).contiguous() k = torch.transpose(k, 2, 3).contiguous() torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k) torch_ouput *= 1.2 q, k = q_copy, k_copy batches, M, H, K = q.shape N = k.shape[1] score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) grid = lambda meta: ( batches, H, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) K = q.shape[3] qkv_gemm_4d_kernel[grid]( q, k, score_output, M, N, K, q.stride(0), q.stride(2), q.stride(1), q.stride(3), k.stride(0), k.stride(2), k.stride(3), k.stride(1), score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), scale=scale, # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8, ) check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "the outputs of triton and torch are not matched" def self_attention_compute_using_torch(qkv, input_mask, scale, head_size): batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size q = qkv[:, :, :d_model] k = qkv[:, :, d_model : d_model * 2] v = qkv[:, :, d_model * 2 :] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) q = torch.transpose(q, 1, 2).contiguous() k = torch.transpose(k, 1, 2).contiguous() v = torch.transpose(v, 1, 2).contiguous() k = torch.transpose(k, -1, -2).contiguous() score_output = torch.einsum("bnij,bnjk->bnik", q, k) score_output *= scale softmax_output = F.softmax(score_output, dim=-1) res = torch.einsum("bnij,bnjk->bnik", softmax_output, v) res = torch.transpose(res, 1, 2) res = res.contiguous() return res.view(batches, -1, d_model), score_output, softmax_output @pytest.mark.skipif( not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" ) def test_self_atttention_test(): qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( qkv.clone(), input_mask=None, scale=1.2, head_size=32 ) data_output_triton = self_attention_compute_using_triton( qkv.clone(), alibi=None, head_size=32, scale=1.2, input_mask=None, layer_past=None, use_flash=False, triangular=True, ) check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) assert check is True, "the triton output is not matched with torch output" if __name__ == "__main__": test_qkv_matmul() test_self_atttention_test()