[workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish
pull/5276/head^2^2
Frank Lee 10 months ago committed by GitHub
parent 04244aaaf1
commit d69cd2eb89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -160,9 +160,7 @@ jobs:
--ignore tests/test_gptq \
--ignore tests/test_infer_ops \
--ignore tests/test_legacy \
--ignore tests/test_moe \
--ignore tests/test_smoothquant \
--ignore tests/test_checkpoint_io \
tests/
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

@ -61,7 +61,9 @@ class ModelZooRegistry(dict):
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
def get_sub_registry(
self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None, allow_empty: bool = False
):
"""
Get a sub registry with models that contain the keyword.
@ -95,7 +97,8 @@ class ModelZooRegistry(dict):
if not should_exclude:
new_dict[k] = v
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
if not allow_empty:
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
return new_dict

@ -63,6 +63,9 @@ config = transformers.GPTJConfig(
n_layer=2,
n_head=4,
vocab_size=50258,
n_embd=256,
hidden_size=256,
n_positions=512,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,

@ -12,7 +12,13 @@ from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import (
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
skip_if_not_enough_gpus,
spawn,
)
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True):
@pytest.mark.largedist
@skip_if_not_enough_gpus(8)
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)

@ -10,8 +10,8 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
# These models are not compatible with AMP
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
@clear_cache_before_run()
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
device = device_utils.get_current_device()
try:

@ -10,10 +10,11 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
# test basic fsdp function
@clear_cache_before_run()
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
del model
del optimizer
del criterion
del booster
del plugin
def check_torch_fsdp_plugin():
if IS_FAST_TEST:
registry = model_zoo.get_sub_registry(COMMON_MODELS)
else:
registry = model_zoo
registry = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
if any(
@ -59,6 +66,7 @@ def check_torch_fsdp_plugin():
]
):
continue
print(name)
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
@ -73,3 +81,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
spawn(run_dist, 2)
if __name__ == "__main__":
test_torch_fsdp_plugin()

@ -38,11 +38,11 @@ else:
]
@clear_cache_before_run()
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
@ -145,3 +145,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_hybrid_ckpIO(4)

@ -1,27 +0,0 @@
import math
import torch
from torch.nn import functional as F
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
"""
adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
"""
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.0] = -100000000.0
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
sm_scale = 1 / math.sqrt(head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output

@ -1,52 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton import bloom_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_bloom_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_bloom_context_attention()

@ -1,39 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_kv_cache_copy_op():
B_NTX = 32 * 2048
head_num = 8
head_dim = 64
cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
copy_kv_cache_to_dest(cache, dest_index, dest_data)
assert torch.allclose(
cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
if __name__ == "__main__":
test_kv_cache_copy_op()

@ -1,43 +0,0 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
try:
pass
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.rand(w_shape, dtype=dtype, device="cuda")
bias = torch.rand(w_shape, dtype=dtype, device="cuda")
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
y_triton = layer_norm(x, weight, bias, eps)
y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
assert y_triton.shape == y_torch.shape
assert y_triton.dtype == y_torch.dtype
print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_layer_norm()

@ -1,56 +0,0 @@
import pytest
import torch
from packaging import version
from torch import nn
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
BATCH_SIZE = 4
SEQ_LEN = 16
HIDDEN_SIZE = 32
def SwiGLU(x):
"""Gated linear unit activation function.
Args:
x : input array
axis: the axis along which the split should be computed (default: -1)
"""
size = x.shape[-1]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = torch.split(x, size // 2, -1)
return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_llama_act_combine(dtype: str):
x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
x_gate_torch = nn.Parameter(x_gate.detach().clone())
x_gate_kernel = nn.Parameter(x_gate.detach().clone())
x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
x_up_torch = nn.Parameter(x_up.detach().clone())
x_up_kernel = nn.Parameter(x_up.detach().clone())
torch_out = SwiGLU(x_gate_torch) * x_up_torch
kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
atol = 1e-5 if dtype == torch.float32 else 5e-2
assert torch.allclose(torch_out, kernel_out, atol=atol)
torch_out.mean().backward()
kernel_out.mean().backward()
assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)
if __name__ == '__main__':
test_llama_act_combine(torch.float16)

@ -1,50 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton import llama_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_llama_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_llama_context_attention()

@ -1,143 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
from packaging import version
try:
import triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
scale = 1.2
head_size = 32
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model : d_model * 2]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
q_copy = q.clone()
k_copy = k.clone()
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
k = torch.transpose(k, 2, 3).contiguous()
torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
torch_ouput *= 1.2
q, k = q_copy, k_copy
batches, M, H, K = q.shape
N = k.shape[1]
score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
K = q.shape[3]
qkv_gemm_4d_kernel[grid](
q,
k,
score_output,
M,
N,
K,
q.stride(0),
q.stride(2),
q.stride(1),
q.stride(3),
k.stride(0),
k.stride(2),
k.stride(3),
k.stride(1),
score_output.stride(0),
score_output.stride(1),
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
assert check is True, "the outputs of triton and torch are not matched"
def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
v = torch.transpose(v, 1, 2).contiguous()
k = torch.transpose(k, -1, -2).contiguous()
score_output = torch.einsum("bnij,bnjk->bnik", q, k)
score_output *= scale
softmax_output = F.softmax(score_output, dim=-1)
res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
res = torch.transpose(res, 1, 2)
res = res.contiguous()
return res.view(batches, -1, d_model), score_output, softmax_output
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
qkv.clone(), input_mask=None, scale=1.2, head_size=32
)
data_output_triton = self_attention_compute_using_triton(
qkv.clone(),
alibi=None,
head_size=32,
scale=1.2,
input_mask=None,
layer_past=None,
use_flash=False,
triangular=True,
)
check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
assert check is True, "the triton output is not matched with torch output"
if __name__ == "__main__":
test_qkv_matmul()
test_self_atttention_test()

@ -1,36 +0,0 @@
import pytest
import torch
from packaging import version
from torch import nn
try:
from colossalai.kernel.triton.softmax import softmax
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax_op():
data_samples = [
torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
torch.randn((320, 320, 78), device="cuda", dtype=torch.float32),
torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16),
]
for data in data_samples:
module = nn.Softmax(dim=-1)
data_torch_out = module(data)
data_triton_out = softmax(data)
check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
assert check is True, "softmax outputs from triton and torch are not matched"
if __name__ == "__main__":
test_softmax_op()

@ -1,72 +0,0 @@
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
import importlib.util
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")
def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
xq = xq.view(bs, 1, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
prob = torch.softmax(logics, dim=1)
prob = prob.view(bs, seqlen, num_head, 1)
return torch.sum(prob * xv, dim=1, keepdim=False)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
dtype = torch.float16
q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
max_kv_cache_len = seq_len
kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
kv_cache_seq_len[:] = seq_len
kv_cache_start_loc[0] = 0
kv_cache_start_loc[1] = seq_len
kv_cache_start_loc[2] = 2 * seq_len
kv_cache_start_loc[3] = 3 * seq_len
for i in range(Z):
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test()

@ -1,48 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax():
import torch
batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128
dtype = torch.float16
Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
for i in range(batch_size):
kv_cache_start_loc[i] = i * seq_len
kv_cache_seq_len[i] = seq_len
token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)
torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
o = ProbOut
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_softmax()

@ -1,14 +1,19 @@
import pytest
from lazy_init_utils import SUPPORT_LAZY, check_lazy_init
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
@pytest.mark.parametrize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"])
@pytest.mark.parametrize(
"subset",
[COMMON_MODELS]
if IS_FAST_TEST
else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"],
)
@pytest.mark.parametrize("default_device", ["cpu", "cuda"])
def test_torchvision_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset)
sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(

Loading…
Cancel
Save