mirror of https://github.com/hpcaitech/ColossalAI
[Kernels] added triton implementation of self attention for colossal-ai (#4241)
* added softmax kernel
* added qkv_kernel
* added ops
* adding tests
* upload tests
* fix tests
* debugging
* debugging tests
* debugging
* added
* fixed errors
* added softmax kernel
* clean codes
* added tests
* update tests
* update tests
* added attention
* add
* fixed pytest checking
* add cuda check
* fix cuda version
* fix typo
parent 7ff11b5537
commit 4b977541a8
colossalai/kernel/triton/ops.py
@@ -0,0 +1,209 @@
import torch
from torch import nn

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    from .qkv_matmul_kernel import qkv_gemm_4d_kernel
    from .softmax_kernel import softmax_kernel

    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
        r"""A function to do the QKV attention calculation by calling the GEMM and softmax triton kernels.

        Args:
            q (torch.Tensor): Q embedding in the attention layer, shape should be (batch, seq_len, num_heads, head_size)
            k (torch.Tensor): K embedding in the attention layer, shape should be (batch, seq_len, num_heads, head_size)
            v (torch.Tensor): V embedding in the attention layer, shape should be (batch, seq_len, num_heads, head_size)
            input_mask (torch.Tensor): mask for the softmax layer, shape should be (batch, num_heads, seq_len, seq_len)
            scale (float): scale value multiplied with Q*K^T before the softmax

        Returns:
            output (torch.Tensor): output of shape (batch, seq_len, num_heads * head_size)
        """
        assert len(q.shape) == 4, "q must be a 4-dimensional tensor"
        assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
        assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
        assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"

        batches, M, H, K = q.shape
        N = k.shape[1]

        # d_model = head_size * num_heads
        d_model = q.shape[-1] * q.shape[-2]

        score_output = torch.empty(
            (batches, H, M, N), device=q.device, dtype=q.dtype)

        grid = lambda meta: (
            batches,
            H,
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

        # score = scale * (Q @ K^T); the stride arguments reinterpret the
        # (batch, seq_len, num_heads, head_size) layout as (batch, num_heads, seq_len, head_size)
        # without materializing a transpose
        qkv_gemm_4d_kernel[grid](
            q, k, score_output,
            M, N, K,
            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
            k.stride(0), k.stride(2), k.stride(3), k.stride(1),
            score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
            scale=scale,
            # currently set manually; an auto-tune config could pick the best setting later
            BLOCK_SIZE_M=64,
            BLOCK_SIZE_N=32,
            BLOCK_SIZE_K=32,
            GROUP_SIZE_M=8,
        )

        softmax_output = torch.empty(
            score_output.shape, device=score_output.device, dtype=score_output.dtype)
        score_output_shape = score_output.shape

        score_output = score_output.view(-1, score_output.shape[-1])
        n_rows, n_cols = score_output.shape

        if n_rows <= 350000:
            block_size = max(triton.next_power_of_2(n_cols), 2)
            if block_size >= 4096:
                num_warps = 16
            elif block_size >= 2048:
                num_warps = 8
            else:
                num_warps = 4

            softmax_kernel[(n_rows, )](
                softmax_output,
                score_output,
                score_output.stride(0),
                n_cols,
                mask_ptr=input_mask,
                num_warps=num_warps,
                BLOCK_SIZE=block_size,
            )
        else:
            # TODO: change the softmax kernel so that it also handles a large number of rows
            softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
            softmax_output = softmax_output.view(*score_output_shape)

        batches, H, M, K = softmax_output.shape
        N = v.shape[-1]

        output = torch.empty(
            (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)

        grid = lambda meta: (
            batches,
            H,
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

        # output = softmax(score) @ V; a non-positive scale disables the extra scaling inside the kernel
        qkv_gemm_4d_kernel[grid](
            softmax_output, v, output,
            M, N, K,
            softmax_output.stride(0),
            softmax_output.stride(1),
            softmax_output.stride(2),
            softmax_output.stride(3),
            v.stride(0),
            v.stride(2),
            v.stride(1),
            v.stride(3),
            output.stride(0),
            output.stride(2),
            output.stride(1),
            output.stride(3),
            BLOCK_SIZE_M=128,
            BLOCK_SIZE_N=64,
            BLOCK_SIZE_K=64,
            GROUP_SIZE_M=8,
            scale=-1,
        )
        return output.view(batches, -1, d_model)

    def self_attention_compute_using_triton(qkv,
                                            input_mask,
                                            layer_past,
                                            alibi,
                                            scale,
                                            head_size,
                                            triangular=False,
                                            use_flash=False):
        # NOTE: layer_past, triangular and use_flash are currently not used by this triton path
        assert qkv.is_contiguous()
        assert alibi is None, "current triton self-attention does not support alibi"
        batches = qkv.shape[0]
        d_model = qkv.shape[-1] // 3
        num_of_heads = d_model // head_size

        # split the fused qkv projection into q, k, v and reshape to (batch, seq_len, num_heads, head_size)
        q = qkv[:, :, :d_model]
        k = qkv[:, :, d_model:d_model * 2]
        v = qkv[:, :, d_model * 2:]
        q = q.view(batches, -1, num_of_heads, head_size)
        k = k.view(batches, -1, num_of_heads, head_size)
        v = v.view(batches, -1, num_of_heads, head_size)

        data_output_triton = self_attention_forward_without_fusion(
            q, k, v, input_mask, scale)

        return data_output_triton

    def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
        if mask is not None:
            assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
        assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"

        hidden_dim = input.shape[-1]
        output = torch.empty_like(input)
        input = input.view(-1, hidden_dim)
        if mask is not None:
            mask = mask.view(-1, hidden_dim)
            assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"

        num_rows, num_cols = input.shape
        block_size = max(triton.next_power_of_2(num_cols), 2)
        if block_size >= 4096:
            num_warps = 16
        elif block_size >= 2048:
            num_warps = 8
        else:
            num_warps = 4

        if num_rows <= 350000:
            grid = (num_rows, )
            softmax_kernel[grid](output, input, input.stride(0), num_cols, mask,
                                 BLOCK_SIZE=block_size, num_warps=num_warps)
        else:
            grid = lambda meta: (
                triton.cdiv(num_rows, meta["BLOCK_M"]),
            )

            BLOCK_M = 32
            if block_size >= 4096:
                BLOCK_M = 4
            elif block_size >= 2048:
                BLOCK_M = 8

            # NOTE: softmax_kernel_2 (a multi-row variant for very large inputs) is referenced here
            # but is not defined in this commit's softmax_kernel.py
            softmax_kernel_2[grid](output_ptr=output,
                                   input_ptr=input,
                                   row_stride=input.stride(0),
                                   n_rows=num_rows,
                                   n_cols=num_cols,
                                   mask_ptr=mask,
                                   # block sizes are currently set manually
                                   BLOCK_M=BLOCK_M,
                                   BLOCK_SIZE=block_size)

        return output
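
Below is a minimal usage sketch of the self_attention_compute_using_triton wrapper (illustrative only, not part of the patch). It mirrors the shapes used in the tests added by this PR: a fused QKV activation of shape (batch, seq_len, 3 * d_model) with head_size 32, so d_model = 64 and num_heads = 2; the random values, the CUDA device and the fixed scale of 1.2 are just placeholders.

import torch

from colossalai.kernel.triton.ops import self_attention_compute_using_triton

# fused QKV activations: (batch, seq_len, 3 * d_model); here d_model = 64, so head_size 32 gives 2 heads
qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)

# alibi is not supported yet and input_mask may be None; the result has shape (batch, seq_len, d_model)
out = self_attention_compute_using_triton(qkv,
                                          input_mask=None,
                                          layer_past=None,
                                          alibi=None,
                                          scale=1.2,
                                          head_size=32,
                                          triangular=False,
                                          use_flash=False)
print(out.shape)  # torch.Size([4, 24, 64])
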
colossalai/kernel/triton/qkv_matmul_kernel.py
@@ -0,0 +1,109 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
    '''
    this kernel function is modified from
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    '''
    @triton.jit
    def qkv_gemm_4d_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_ab,
        stride_ah,
        stride_am,
        stride_ak,
        stride_bb,
        stride_bh,
        stride_bk,
        stride_bn,
        stride_cb,
        stride_ch,
        stride_cm,
        stride_cn,
        scale,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr = 64,
        BLOCK_SIZE_N: tl.constexpr = 32,
        BLOCK_SIZE_K: tl.constexpr = 32,
        GROUP_SIZE_M: tl.constexpr = 8,
    ):
        r"""A kernel function used to do the batched matmul for Q*K^T or score_matrix * V in the attention layer,
        where score_matrix is softmax(Q*K^T/sqrt(hidden_size)).

        Args:
            a_ptr (torch.Tensor): pointer to input tensor array A of shape (bs, M, h, K) or (bs, h, M, K)
            b_ptr (torch.Tensor): pointer to input tensor array B of shape (bs, N, h, K) or (bs, h, N, K)
            c_ptr (torch.Tensor): pointer to output tensor array of shape (bs, M, h, N) or (bs, h, M, N)
            stride_ab (int): stride of the bs-dimension of tensor array A
            stride_ah (int): stride of the h-dimension of tensor array A
            stride_am (int): stride of the m-dimension of tensor array A
            stride_ak (int): stride of the k-dimension of tensor array A
            stride_bb (int): stride of the bs-dimension of tensor array B
            stride_bh (int): stride of the h-dimension of tensor array B
            stride_bk (int): stride of the k-dimension of tensor array B
            stride_bn (int): stride of the n-dimension of tensor array B
            stride_cb (int): stride of the bs-dimension of the output tensor array
            stride_ch (int): stride of the h-dimension of the output tensor array
            stride_cm (int): stride of the m-dimension of the output tensor array
            stride_cn (int): stride of the n-dimension of the output tensor array
            BLOCK_SIZE_M: tiling size of the M-dimension of tensor array A
            BLOCK_SIZE_N: tiling size of the N-dimension of tensor array B
            BLOCK_SIZE_K: tiling size of the K-dimension of A and B
            GROUP_SIZE_M: group size used to reduce cache misses; see the Triton matmul tutorial linked above for details
        """

        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        batch = tl.program_id(axis=0)
        head = tl.program_id(axis=1)
        pid = tl.program_id(axis=2)

        # the following is from the tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
                  (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
        b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
                  (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_SIZE_K):
            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
            b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
            a = tl.load(a_ptrs, mask=a_mask, other=0.)
            b = tl.load(b_ptrs, mask=b_mask, other=0.)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

        accumulator = accumulator.to(c_ptr.dtype.element_ty)
        if scale > 0:
            accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)

        offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
                  stride_cn * offs_accumu_n[None, :])
        accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=accumulator_mask)
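
The kernel can also be launched on its own for a single Q*K^T product. The sketch below (illustrative, not from the patch) mirrors what test_qkv_matmul further down does: the stride arguments reinterpret the (batch, seq_len, num_heads, head_size) layout as (batch, num_heads, seq_len, head_size) for A and as its transpose for B, so no explicit transpose is materialized, and the result is checked against torch.einsum. Shapes, block sizes and tolerances are taken from that test.

import torch
import triton

from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel

batch, seq_len, num_heads, head_size = 4, 24, 2, 32
q = torch.randn((batch, seq_len, num_heads, head_size), device="cuda", dtype=torch.float16)
k = torch.randn((batch, seq_len, num_heads, head_size), device="cuda", dtype=torch.float16)

M = N = seq_len
K = head_size
score = torch.empty((batch, num_heads, M, N), device="cuda", dtype=torch.float16)

# one program per (batch, head, output tile)
grid = lambda meta: (batch, num_heads,
                     triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]))

# A strides reorder (batch, seq, head, head_size) into (batch, head, seq, head_size);
# B strides additionally swap the last two axes so K is read transposed
qkv_gemm_4d_kernel[grid](
    q, k, score,
    M, N, K,
    q.stride(0), q.stride(2), q.stride(1), q.stride(3),
    k.stride(0), k.stride(2), k.stride(3), k.stride(1),
    score.stride(0), score.stride(1), score.stride(2), score.stride(3),
    scale=1.0,
    BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8,
)

reference = torch.einsum('bmhk,bnhk->bhmn', q, k)  # Q @ K^T per batch and head
assert torch.allclose(score, reference, rtol=1e-3, atol=1e-5)
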
colossalai/kernel/triton/softmax_kernel.py
@@ -0,0 +1,44 @@
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    '''
    softmax kernel is modified based on
    https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
    '''
    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
        r"""The kernel function implementing a row-wise softmax operator.

        Args:
            output_ptr: pointer to the output of the softmax operation, shape (N, hidden_dim)
            input_ptr: pointer to the input tensor, shape should be (N, hidden_dim)
            row_stride: stride between consecutive rows of the input
            n_cols (int): the number of columns of the input
            mask_ptr: pointer to an additive mask with the same row layout as the input, or None
            BLOCK_SIZE (tl.constexpr): the block size of the hidden_dim dimension; BLOCK_SIZE must be >= hidden_dim
        """
        row_idx = tl.program_id(0)
        row_start_ptr = input_ptr + row_idx * row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
        row_minus_max = row - tl.max(row, axis=0)

        if mask_ptr is not None:
            # load the mask into SRAM and add it to the shifted scores
            mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
            mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
            row_minus_max = row_minus_max + mask

        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        # Write the result back to DRAM
        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
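
A direct launch of this kernel, outside the softmax wrapper in ops.py, looks like the sketch below (illustrative, not part of the patch): one program instance per row, BLOCK_SIZE rounded up to the next power of two, and an optional additive mask with the same row layout as the input. The shapes and the 0/-inf mask convention are assumptions for the example.

import torch
import triton

from colossalai.kernel.triton.softmax_kernel import softmax_kernel

scores = torch.randn((8, 128), device="cuda", dtype=torch.float32)
# additive mask with the same row layout as the input: 0 keeps a position, -inf drops it
mask = torch.zeros_like(scores)
mask[:, 64:] = float('-inf')

out = torch.empty_like(scores)
n_rows, n_cols = scores.shape
block_size = max(triton.next_power_of_2(n_cols), 2)

# one program instance per row, as in the softmax wrapper in ops.py
softmax_kernel[(n_rows, )](out, scores, scores.stride(0), n_cols,
                           mask_ptr=mask,
                           BLOCK_SIZE=block_size,
                           num_warps=4)

reference = torch.softmax(scores + mask, dim=-1)
assert torch.allclose(out, reference, rtol=1e-3, atol=1e-3)
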
@@ -0,0 +1,136 @@
import pytest
from packaging import version
import torch
from torch import nn
import torch.nn.functional as F

from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model:d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    q_copy = q.clone()
    k_copy = k.clone()
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()

    torch_output = torch.einsum('bnij,bnjk->bnik', q, k)
    torch_output *= scale

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty(
        (batches, H, M, N), device=q.device, dtype=q.dtype)

    grid = lambda meta: (
        batches,
        H,
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

    K = q.shape[3]
    qkv_gemm_4d_kernel[grid](
        q, k, score_output,
        M, N, K,
        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
        k.stride(0), k.stride(2), k.stride(3), k.stride(1),
        score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
        scale=scale,
        # currently set manually; an auto-tune config could pick the best setting later
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "the outputs of triton and torch are not matched"

def self_attention_compute_using_torch(qkv,
                                       input_mask,
                                       scale,
                                       head_size
                                       ):

    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model:d_model * 2]
    v = qkv[:, :, d_model * 2:]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)

    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    v = torch.transpose(v, 1, 2).contiguous()

    k = torch.transpose(k, -1, -2).contiguous()

    score_output = torch.einsum('bnij,bnjk->bnik', q, k)
    score_output *= scale

    softmax_output = F.softmax(score_output, dim=-1)
    res = torch.einsum('bnij,bnjk->bnik', softmax_output, v)
    res = torch.transpose(res, 1, 2)
    res = res.contiguous()

    return res.view(batches, -1, d_model), score_output, softmax_output

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_self_attention():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
        qkv.clone(),
        input_mask=None,
        scale=1.2,
        head_size=32
    )

    data_output_triton = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=32,
        scale=1.2,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True)

    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
    assert check is True, "the triton output is not matched with torch output"


if __name__ == "__main__":
    test_qkv_matmul()
    test_self_attention()
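
Note that the tests above only exercise input_mask=None, and the triangular argument is currently not acted on inside the triton path, so causal masking would have to be supplied through input_mask itself. Below is a hedged sketch of how a caller might build such an additive mask with the (batch, num_heads, seq_len, seq_len) shape the ops.py docstring expects; the shapes, the 0/-inf convention and the 1/sqrt(head_size) scale are illustrative assumptions, not part of the patch.

import torch

from colossalai.kernel.triton.ops import self_attention_compute_using_triton

batch, num_heads, seq_len, head_size = 4, 2, 24, 32

# additive causal mask: 0 where attention is allowed, -inf above the diagonal
causal_mask = torch.zeros((batch, num_heads, seq_len, seq_len), device="cuda", dtype=torch.float16)
causal_mask = causal_mask.masked_fill(
    torch.triu(torch.ones(seq_len, seq_len, device="cuda", dtype=torch.bool), diagonal=1),
    float('-inf'))

qkv = torch.randn((batch, seq_len, num_heads * head_size * 3), device="cuda", dtype=torch.float16)

# the softmax kernel adds this mask to the scaled scores before normalizing,
# so passing it as input_mask yields causal attention
out = self_attention_compute_using_triton(qkv,
                                          input_mask=causal_mask,
                                          layer_past=None,
                                          alibi=None,
                                          scale=1.0 / head_size ** 0.5,
                                          head_size=head_size,
                                          triangular=True,
                                          use_flash=False)
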
@@ -0,0 +1,27 @@
import pytest
from packaging import version
import torch
from torch import nn

from colossalai.kernel.triton.ops import softmax

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
    data_samples = [
        torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
        torch.randn((320, 320, 78), device="cuda", dtype=torch.float32),
        torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16)
    ]

    for data in data_samples:
        module = nn.Softmax(dim=-1)
        data_torch_out = module(data)
        data_triton_out = softmax(data)
        check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
        assert check is True, "softmax outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_softmax_op()
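
The high-level softmax wrapper exported from ops.py also accepts an additive mask, which the test above does not cover. A small sketch of that call (illustrative only; the shapes and mask values are assumptions, and the mask must have the same shape as the input so it can be flattened the same way):

import torch

from colossalai.kernel.triton.ops import softmax

# scores and an additive mask of the same shape; masked positions get -inf
scores = torch.randn((4, 8, 128), device="cuda", dtype=torch.float32)
mask = torch.zeros_like(scores)
mask[..., 64:] = float('-inf')

out = softmax(scores, mask=mask, dim=-1)
reference = torch.softmax(scores + mask, dim=-1)
assert torch.allclose(out.cpu(), reference.cpu(), rtol=1e-3, atol=1e-3)
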