mirror of https://github.com/hpcaitech/ColossalAI
[workflow] fixed oom tests (#5275)
* [workflow] fixed oom tests
* polish
* polish
* polish

parent 04244aaaf1
commit d69cd2eb89
@@ -160,9 +160,7 @@ jobs:
            --ignore tests/test_gptq \
            --ignore tests/test_infer_ops \
            --ignore tests/test_legacy \
            --ignore tests/test_moe \
            --ignore tests/test_smoothquant \
            --ignore tests/test_checkpoint_io \
            tests/
        env:
          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
@@ -61,7 +61,9 @@ class ModelZooRegistry(dict):
        """
        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

    def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
    def get_sub_registry(
        self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None, allow_empty: bool = False
    ):
        """
        Get a sub registry with models that contain the keyword.

@@ -95,7 +97,8 @@ class ModelZooRegistry(dict):
            if not should_exclude:
                new_dict[k] = v

        assert len(new_dict) > 0, f"No model found with keyword {keyword}"
        if not allow_empty:
            assert len(new_dict) > 0, f"No model found with keyword {keyword}"
        return new_dict
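The new `allow_empty` flag turns a failed keyword lookup into an empty result instead of an assertion error. A minimal usage sketch, with a deliberately non-matching, hypothetical keyword:

```python
from tests.kit.model_zoo import model_zoo

# Default behaviour: an unmatched keyword raises an AssertionError
# ("No model found with keyword ...").
try:
    model_zoo.get_sub_registry("nonexistent_model")  # hypothetical keyword
except AssertionError as exc:
    print(exc)

# With allow_empty=True the call returns an empty sub-registry instead,
# so callers can iterate over it without special-casing misses.
empty = model_zoo.get_sub_registry("nonexistent_model", allow_empty=True)
assert len(empty) == 0
```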
@@ -63,6 +63,9 @@ config = transformers.GPTJConfig(
    n_layer=2,
    n_head=4,
    vocab_size=50258,
    n_embd=256,
    hidden_size=256,
    n_positions=512,
    attn_pdrop=0,
    embd_pdrop=0,
    resid_pdrop=0,
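This hunk shrinks the GPTJ config used by the tests (2 layers, 4 heads, 256-dim embeddings) so the model fits in CI GPU memory. A hedged sketch of building such a reduced model; the config values come from the hunk, but instantiating `GPTJForCausalLM` this way is an assumption, not something the diff shows (`hidden_size` is omitted here because it aliases `n_embd` in transformers):

```python
import transformers

# Values mirror the hunk above; the surrounding usage is assumed.
config = transformers.GPTJConfig(
    n_layer=2,
    n_head=4,
    vocab_size=50258,
    n_embd=256,
    n_positions=512,
    attn_pdrop=0,
    embd_pdrop=0,
    resid_pdrop=0,
)
model = transformers.GPTJForCausalLM(config)
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")  # small enough for CI GPUs
```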
@@ -12,7 +12,13 @@ from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import (
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    skip_if_not_enough_gpus,
    spawn,
)
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo

@@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True):

@pytest.mark.largedist
@skip_if_not_enough_gpus(8)
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    spawn(run_dist, 8, early_stop=early_stop)
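`test_gemini_plugin_3d` is now gated behind `pytest.mark.largedist` and `skip_if_not_enough_gpus(8)`. The commit only applies the decorator, so the sketch below assumes it reduces to a `skipif` on the visible GPU count; the helper name with the `_sketch` suffix is made up for illustration:

```python
import pytest
import torch

def skip_if_not_enough_gpus_sketch(required: int):
    # Stand-in for colossalai.testing.skip_if_not_enough_gpus: assumed to skip
    # the test whenever fewer than `required` CUDA devices are visible.
    return pytest.mark.skipif(
        torch.cuda.device_count() < required,
        reason=f"requires at least {required} GPUs",
    )

@skip_if_not_enough_gpus_sketch(8)
def test_runs_only_on_eight_gpu_nodes():
    assert torch.cuda.device_count() >= 8
```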
@@ -10,8 +10,8 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo

# These models are not compatible with AMP
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]

@@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


@clear_cache_before_run()
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    device = device_utils.get_current_device()
    try:
@@ -10,10 +10,11 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
@@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


# test basic fsdp function
@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)

@@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

    del model
    del optimizer
    del criterion
    del booster
    del plugin


def check_torch_fsdp_plugin():
    if IS_FAST_TEST:
        registry = model_zoo.get_sub_registry(COMMON_MODELS)
    else:
        registry = model_zoo
        registry = model_zoo.get_sub_registry("transformers_gptj")

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
        if any(

@@ -59,6 +66,7 @@ def check_torch_fsdp_plugin():
            ]
        ):
            continue
        print(name)
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()

@@ -73,3 +81,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_torch_fsdp_plugin()
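The `del` statements plus `torch.cuda.empty_cache()` above are the OOM workaround this commit is about: all Python references are dropped before the next model is built. A standalone sketch of that pattern, with illustrative names not taken from the commit:

```python
import gc
import torch

def run_models_sequentially(model_builders):
    """Illustrative helper: run several CUDA models back to back without
    letting allocations from one run linger into the next."""
    for build in model_builders:
        model = build().cuda()
        # ... forward / backward / optimizer.step() would go here ...
        # Drop the Python references so the tensors become collectable, then
        # return cached blocks to the allocator before the next model starts.
        del model
        gc.collect()
        torch.cuda.empty_cache()
```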
@@ -38,11 +38,11 @@ else:
    ]


@clear_cache_before_run()
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
    (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
        iter(model_zoo.get_sub_registry(model_name).values())

@@ -145,3 +145,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_hybrid_ckpIO(4)
@@ -1,27 +0,0 @@
import math

import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    """
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.0] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
    scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)

    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output
@@ -1,52 +0,0 @@
import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import bloom_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_bloom_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_bloom_context_attention()
@@ -1,39 +0,0 @@
import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_kv_cache_copy_op():
    B_NTX = 32 * 2048
    head_num = 8
    head_dim = 64

    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

    copy_kv_cache_to_dest(cache, dest_index, dest_data)

    assert torch.allclose(
        cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
    ), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_kv_cache_copy_op()
@@ -1,43 +0,0 @@
import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize

try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
    bias = torch.rand(w_shape, dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_layer_norm()
@@ -1,56 +0,0 @@
import pytest
import torch
from packaging import version
from torch import nn

from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

try:
    import triton
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

BATCH_SIZE = 4
SEQ_LEN = 16
HIDDEN_SIZE = 32


def SwiGLU(x):
    """Gated linear unit activation function.
    Args:
        x : input array
        axis: the axis along which the split should be computed (default: -1)
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_llama_act_combine(dtype: str):
    x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
    x_gate_torch = nn.Parameter(x_gate.detach().clone())
    x_gate_kernel = nn.Parameter(x_gate.detach().clone())
    x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
    x_up_torch = nn.Parameter(x_up.detach().clone())
    x_up_kernel = nn.Parameter(x_up.detach().clone())

    torch_out = SwiGLU(x_gate_torch) * x_up_torch
    kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
    atol = 1e-5 if dtype == torch.float32 else 5e-2
    assert torch.allclose(torch_out, kernel_out, atol=atol)

    torch_out.mean().backward()
    kernel_out.mean().backward()
    assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
    assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
    assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)


if __name__ == '__main__':
    test_llama_act_combine(torch.float16)
@@ -1,50 +0,0 @@
import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import llama_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_llama_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_llama_context_attention()
@@ -1,143 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
from packaging import version

try:
    import triton

    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    q_copy = q.clone()
    k_copy = k.clone()
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()

    torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
    torch_ouput *= 1.2

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

    grid = lambda meta: (
        batches,
        H,
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

    K = q.shape[3]
    qkv_gemm_4d_kernel[grid](
        q,
        k,
        score_output,
        M,
        N,
        K,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),
        k.stride(0),
        k.stride(2),
        k.stride(3),
        k.stride(1),
        score_output.stride(0),
        score_output.stride(1),
        score_output.stride(2),
        score_output.stride(3),
        scale=scale,
        # currently manually setting, later on we can use auto-tune config to match best setting
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "the outputs of triton and torch are not matched"


def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]
    v = qkv[:, :, d_model * 2 :]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)

    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    v = torch.transpose(v, 1, 2).contiguous()

    k = torch.transpose(k, -1, -2).contiguous()

    score_output = torch.einsum("bnij,bnjk->bnik", q, k)
    score_output *= scale

    softmax_output = F.softmax(score_output, dim=-1)
    res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
    res = torch.transpose(res, 1, 2)
    res = res.contiguous()

    return res.view(batches, -1, d_model), score_output, softmax_output


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
        qkv.clone(), input_mask=None, scale=1.2, head_size=32
    )

    data_output_triton = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=32,
        scale=1.2,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True,
    )

    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
    assert check is True, "the triton output is not matched with torch output"


if __name__ == "__main__":
    test_qkv_matmul()
    test_self_atttention_test()
@@ -1,36 +0,0 @@
import pytest
import torch
from packaging import version
from torch import nn

try:
    from colossalai.kernel.triton.softmax import softmax

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax_op():
    data_samples = [
        torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
        torch.randn((320, 320, 78), device="cuda", dtype=torch.float32),
        torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16),
    ]

    for data in data_samples:
        module = nn.Softmax(dim=-1)
        data_torch_out = module(data)
        data_triton_out = softmax(data)
        check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
        assert check is True, "softmax outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_softmax_op()
@@ -1,72 +0,0 @@
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")


import importlib.util

HAS_LIGHTLLM_KERNEL = True

if importlib.util.find_spec("lightllm") is None:
    HAS_LIGHTLLM_KERNEL = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)

    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)

    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
    reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
    Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
    dtype = torch.float16
    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")

    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")

    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len
    kv_cache_start_loc[2] = 2 * seq_len
    kv_cache_start_loc[3] = 3 * seq_len

    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)

    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test()
@@ -1,48 +0,0 @@
import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax():
    import torch

    batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128

    dtype = torch.float16

    Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len

    token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)

    torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
    o = ProbOut
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_softmax()
@@ -1,14 +1,19 @@
import pytest
from lazy_init_utils import SUPPORT_LAZY, check_lazy_init

from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
@pytest.mark.parametrize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"])
@pytest.mark.parametrize(
    "subset",
    [COMMON_MODELS]
    if IS_FAST_TEST
    else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"],
)
@pytest.mark.parametrize("default_device", ["cpu", "cuda"])
def test_torchvision_models_lazy_init(subset, default_device):
    sub_model_zoo = model_zoo.get_sub_registry(subset)
    sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
    for name, entry in sub_model_zoo.items():
        # TODO(ver217): lazy init does not support weight norm, skip these models
        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
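A sketch (not part of the commit) of how the fast-test parametrization above fits together with `allow_empty=True`: under the fast path only the common models are registered, so the other subset keywords must be allowed to match nothing. The environment variable and model keys below are assumptions:

```python
import os

# Hypothetical stand-ins for the flags exported by tests/kit/model_zoo; the real
# definitions are not shown in this diff and may differ.
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
COMMON_MODELS = ["transformers_gpt2", "transformers_llama"]  # illustrative keys

# Mirrors the parametrization above: fast CI runs only exercise the common
# models, while full runs sweep every framework subset.
SUBSETS = [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "timm", "transformers"]

def collect_models(registry, subset):
    # allow_empty=True keeps the fast path from failing when a subset keyword
    # matches nothing in the reduced registry.
    return list(registry.get_sub_registry(subset, allow_empty=True).items())
```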