[Kernels] added triton-implemented of self attention for colossal-ai (#4241)

* added softmax kernel

* added qkv_kernel

* added ops

* adding tests

* upload tets

* fix tests

* debugging

* debugging tests

* debugging

* added

* fixed errors

* added softmax kernel

* clean codes

* added tests

* update tests

* update tests

* added attention

* add

* fixed pytest checking

* add cuda check

* fix cuda version

* fix typo
pull/4302/head
Cuiqing Li 2023-07-18 23:53:38 +08:00 committed by GitHub
parent 7ff11b5537
commit 4b977541a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 525 additions and 0 deletions

View File

@ -0,0 +1,209 @@
import torch
from torch import nn
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax_kernel import softmax_kernel
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len)
scale: the float scale value which is used to multiply with Q*K^T before doing softmax
Return:
output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
"""
assert len(q.shape) == 4, "the shape of q val must be 4"
batches, M, H, K = q.shape
assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"
N = k.shape[1]
# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]
score_output = torch.empty(
(batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
q, k, score_output,
M, N, K,
q.stride(0), q.stride(2), q.stride(1), q.stride(3),
k.stride(0), k.stride(2), k.stride(3), k.stride(1),
score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
softmax_output = torch.empty(
score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape
score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4
softmax_kernel[(n_rows, )](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
mask_ptr = input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)
else:
#TODO: change softmax kernel functions to make it suitable for large size dimension
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)
batches, H, M, K = softmax_output.shape
N = v.shape[-1]
output = torch.empty(
(batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
softmax_output, v, output,
M, N, K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
softmax_output.stride(3),
v.stride(0),
v.stride(2),
v.stride(1),
v.stride(3),
output.stride(0),
output.stride(2),
output.stride(1),
output.stride(3),
BLOCK_SIZE_M=128,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=8,
scale=-1,
)
return output.view(batches, -1, d_model)
def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
alibi,
scale,
head_size,
triangular=False,
use_flash=False):
assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
v = qkv[:, :, d_model * 2:]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
data_output_triton = self_attention_forward_without_fusion(
q, k, v, input_mask, scale)
return data_output_triton
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None:
assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"
hidden_dim = input.shape[-1]
output = torch.empty_like(input)
input = input.view(-1, hidden_dim)
if mask is not None:
mask = mask.view(-1, hidden_dim)
assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
num_rows, num_cols = input.shape
block_size = max(triton.next_power_of_2(num_cols), 2)
num_warps = 16
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4
if num_rows <= 350000:
grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
else:
grid = lambda meta: ()
grid = lambda meta: (
triton.cdiv(num_rows, meta["BLOCK_M"]),
)
BLOCK_M = 32
if block_size >= 4096:
BLOCK_M = 4
elif block_size >= 2048:
BLOCK_M = 8
softmax_kernel_2[grid](output_ptr = output,
input_ptr = input,
row_stride = input.stride(0),
n_rows = num_rows,
n_cols = num_cols,
mask_ptr = mask,
# currently manually setting up size
BLOCK_M = 32,
BLOCK_SIZE = block_size)
return output

View File

@ -0,0 +1,109 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
'''
@triton.jit
def qkv_gemm_4d_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_ab,
stride_ah,
stride_am,
stride_ak,
stride_bb,
stride_bh,
stride_bk,
stride_bn,
stride_cb,
stride_ch,
stride_cm,
stride_cn,
scale,
# Meta-parameters
BLOCK_SIZE_M : tl.constexpr = 64,
BLOCK_SIZE_N : tl.constexpr = 32,
BLOCK_SIZE_K : tl.constexpr = 32,
GROUP_SIZE_M : tl.constexpr = 8,
):
r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
where score_matrix is softmax(Q*V^T/sqrt(hidden_size))
Args:
a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K)
c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N)
stride_ab(tl.constexpr): stride for bs-dimention for tensor array A
stride_ah(tl.constexpr): stride for h-dimention for tensor array A
stride_am(tl.constexpr): stride for m-dimention for tensor array A
stride_ak(tl.constexpr): stride for k-dimention for tensor array A
stride_bb(tl.constexpr): stride for bs-dimention for tensor array B
stride_bh(tl.constexpr): stride for h-dimention for tensor array B
stride_bk(tl.constexpr): stride for k-dimention for tensor array B
stride_bn(tl.constexpr): stride for n-dimention for tensor array B
stride_cb(tl.constexpr): stride for bs-dimention for tensor array output
stride_ch(tl.constexpr): stride for h-dimention for tensor array output
stride_cm(tl.constexpr): stride for m-dimention for tensor array output
stride_cn(tl.constexpr): stride for n-dimention for tensor array output
BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
BLOCK_SIZE_K : tiling size for K-dimension of a and b
GROUP_SIZE_M : group size for reducing cache miss, more details:
"""
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
batch = tl.program_id(axis = 0)
head = tl.program_id(axis = 1)
pid = tl.program_id(axis = 2)
# the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.)
b = tl.load(b_ptrs, mask=b_mask, other=0.)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0:
accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
stride_cn * offs_accumu_n[None, :])
accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
tl.store(c_ptrs, accumulator, mask=accumulator_mask)

View File

@ -0,0 +1,44 @@
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator
Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
"""
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
if mask_ptr is not None:
# load mask into SRAM
mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
# update
row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * row_stride
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

View File

@ -0,0 +1,136 @@
import pytest
from packaging import version
import torch
from torch import nn
import torch.nn.functional as F
from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
scale = 1.2
head_size = 32
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
q_copy = q.clone()
k_copy = k.clone()
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
k = torch.transpose(k, 2, 3).contiguous()
torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k)
torch_ouput *= 1.2
q, k = q_copy, k_copy
batches, M, H, K = q.shape
N = k.shape[1]
score_output = torch.empty(
(batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
K = q.shape[3]
qkv_gemm_4d_kernel[grid](
q, k, score_output,
M, N, K,
q.stride(0), q.stride(2), q.stride(1), q.stride(3),
k.stride(0), k.stride(2), k.stride(3), k.stride(1),
score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
assert check is True, "the outputs of triton and torch are not matched"
def self_attention_compute_using_torch(qkv,
input_mask,
scale,
head_size
):
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
v = qkv[:, :, d_model * 2:]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
v = torch.transpose(v, 1, 2).contiguous()
k = torch.transpose(k, -1, -2).contiguous()
score_output = torch.einsum('bnij,bnjk->bnik', q, k)
score_output *= scale
softmax_output = F.softmax(score_output, dim = -1)
res = torch.einsum('bnij,bnjk->bnik', softmax_output, v)
res = torch.transpose(res, 1, 2)
res = res.contiguous()
return res.view(batches, -1, d_model), score_output, softmax_output
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
qkv.clone(),
input_mask = None,
scale = 1.2,
head_size = 32
)
data_output_triton = self_attention_compute_using_triton(
qkv.clone(),
alibi=None,
head_size=32,
scale=1.2,
input_mask=None,
layer_past=None,
use_flash=False,
triangular=True)
check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
assert check is True, "the triton output is not matched with torch output"
if __name__ == "__main__":
test_qkv_matmul()
test_self_atttention_test()

View File

@ -0,0 +1,27 @@
import pytest
from packaging import version
import torch
from torch import nn
from colossalai.kernel.triton.ops import softmax
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
data_samples = [
torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),
torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32),
torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16)
]
for data in data_samples:
module = nn.Softmax(dim = -1)
data_torch_out = module(data)
data_triton_out = softmax(data)
check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
assert check is True, "softmax outputs from triton and torch are not matched"
if __name__ == "__main__":
test_softmax_op()