From d69cd2eb89e06e8cf165873e0e61c2dfacec9cb9 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 16 Jan 2024 18:55:13 +0800 Subject: [PATCH] [workflow] fixed oom tests (#5275) * [workflow] fixed oom tests * polish * polish * polish --- .github/workflows/build_on_pr.yml | 2 - tests/kit/model_zoo/registry.py | 7 +- tests/kit/model_zoo/transformers/gptj.py | 3 + .../test_plugin/test_gemini_plugin.py | 9 +- .../test_plugin/test_low_level_zero_plugin.py | 5 +- .../test_plugin/test_torch_ddp_plugin.py | 5 +- .../test_plugin/test_torch_fsdp_plugin.py | 18 ++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 6 +- tests/test_infer_ops/triton/kernel_utils.py | 27 ---- .../triton/test_bloom_context_attention.py | 52 ------- .../triton/test_copy_kv_dest.py | 39 ----- .../triton/test_layernorm_triton.py | 43 ------ .../triton/test_llama_act_combine.py | 56 ------- .../triton/test_llama_context_attention.py | 50 ------ .../triton/test_self_attention_nonfusion.py | 143 ------------------ tests/test_infer_ops/triton/test_softmax.py | 36 ----- .../triton/test_token_attn_fwd.py | 72 --------- .../triton/test_token_softmax.py | 48 ------ tests/test_lazy/test_models.py | 11 +- 19 files changed, 50 insertions(+), 582 deletions(-) delete mode 100644 tests/test_infer_ops/triton/kernel_utils.py delete mode 100644 tests/test_infer_ops/triton/test_bloom_context_attention.py delete mode 100644 tests/test_infer_ops/triton/test_copy_kv_dest.py delete mode 100644 tests/test_infer_ops/triton/test_layernorm_triton.py delete mode 100644 tests/test_infer_ops/triton/test_llama_act_combine.py delete mode 100644 tests/test_infer_ops/triton/test_llama_context_attention.py delete mode 100644 tests/test_infer_ops/triton/test_self_attention_nonfusion.py delete mode 100644 tests/test_infer_ops/triton/test_softmax.py delete mode 100644 tests/test_infer_ops/triton/test_token_attn_fwd.py delete mode 100644 tests/test_infer_ops/triton/test_token_softmax.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 54e8a6d93..a34a60669 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -160,9 +160,7 @@ jobs: --ignore tests/test_gptq \ --ignore tests/test_infer_ops \ --ignore tests/test_legacy \ - --ignore tests/test_moe \ --ignore tests/test_smoothquant \ - --ignore tests/test_checkpoint_io \ tests/ env: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 5e8e0b382..a16b16ad6 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -61,7 +61,9 @@ class ModelZooRegistry(dict): """ self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) - def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None): + def get_sub_registry( + self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None, allow_empty: bool = False + ): """ Get a sub registry with models that contain the keyword. @@ -95,7 +97,8 @@ class ModelZooRegistry(dict): if not should_exclude: new_dict[k] = v - assert len(new_dict) > 0, f"No model found with keyword {keyword}" + if not allow_empty: + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index 9eefbb43d..c89124f01 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -63,6 +63,9 @@ config = transformers.GPTJConfig( n_layer=2, n_head=4, vocab_size=50258, + n_embd=256, + hidden_size=256, + n_positions=512, attn_pdrop=0, embd_pdrop=0, resid_pdrop=0, diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 9952e41e5..17dfa3a18 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -12,7 +12,13 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import ( + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + skip_if_not_enough_gpus, + spawn, +) from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True): @pytest.mark.largedist +@skip_if_not_enough_gpus(8) @rerun_if_address_is_in_use() def test_gemini_plugin_3d(early_stop: bool = True): spawn(run_dist, 8, early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index bcdcc1470..286f431d5 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -10,8 +10,8 @@ from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin # from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # These models are not compatible with AMP _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] @@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] +@clear_cache_before_run() def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: device = device_utils.get_current_device() try: diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fa32feb2f..e785843fb 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -10,10 +10,11 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 8a14d7cf8..f69807046 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # test basic fsdp function +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchFSDPPlugin() booster = Booster(plugin=plugin) @@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): optimizer.clip_grad_by_norm(1.0) optimizer.step() + del model + del optimizer + del criterion + del booster + del plugin + def check_torch_fsdp_plugin(): if IS_FAST_TEST: registry = model_zoo.get_sub_registry(COMMON_MODELS) else: - registry = model_zoo + registry = model_zoo.get_sub_registry("transformers_gptj") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items(): if any( @@ -59,6 +66,7 @@ def check_torch_fsdp_plugin(): ] ): continue + print(name) run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -73,3 +81,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_fsdp_plugin(): spawn(run_dist, 2) + + +if __name__ == "__main__": + test_torch_fsdp_plugin() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index db3c56da8..865262cae 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -38,11 +38,11 @@ else: ] -@clear_cache_before_run() @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) +@clear_cache_before_run() def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -145,3 +145,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_hybrid_ckpIO(world_size): spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_hybrid_ckpIO(4) diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py deleted file mode 100644 index 0732ace1e..000000000 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import math - -import torch -from torch.nn import functional as F - - -def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): - """ - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 - """ - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - sm_scale = 1 / math.sqrt(head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale - scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) - - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py deleted file mode 100644 index 7a6c218a6..000000000 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ /dev/null @@ -1,52 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton import bloom_context_attn_fwd - from tests.test_infer_ops.triton.kernel_utils import torch_context_attention - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_bloom_context_attention(): - bs = 4 - head_num = 8 - seq_len = 1024 - head_dim = 64 - - query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - - max_input_len = seq_len - b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) - b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) - - for i in range(bs): - b_start[i] = i * seq_len - b_len[i] = seq_len - - o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") - bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi) - - torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - - assert torch.allclose( - torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 - ), "outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_bloom_context_attention() diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py deleted file mode 100644 index 34e453f78..000000000 --- a/tests/test_infer_ops/triton/test_copy_kv_dest.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_kv_cache_copy_op(): - B_NTX = 32 * 2048 - head_num = 8 - head_dim = 64 - - cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) - dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32) - - dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) - - copy_kv_cache_to_dest(cache, dest_index, dest_data) - - assert torch.allclose( - cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3 - ), "copy_kv_cache_to_dest outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_kv_cache_copy_op() diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py deleted file mode 100644 index 7f814e8c9..000000000 --- a/tests/test_infer_ops/triton/test_layernorm_triton.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -import torch -from packaging import version - -from colossalai.kernel.triton import layer_norm -from colossalai.testing.utils import parameterize - -try: - pass - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -@parameterize("M", [2, 4, 8, 16]) -@parameterize("N", [64, 128]) -def test_layer_norm(M, N): - dtype = torch.float16 - eps = 1e-5 - x_shape = (M, N) - w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - bias = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - - y_triton = layer_norm(x, weight, bias, eps) - y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) - - assert y_triton.shape == y_torch.shape - assert y_triton.dtype == y_torch.dtype - print("max delta: ", torch.max(torch.abs(y_triton - y_torch))) - assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_layer_norm() diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py deleted file mode 100644 index 5341aa35a..000000000 --- a/tests/test_infer_ops/triton/test_llama_act_combine.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine - -try: - import triton - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') - -BATCH_SIZE = 4 -SEQ_LEN = 16 -HIDDEN_SIZE = 32 - - -def SwiGLU(x): - """Gated linear unit activation function. - Args: - x : input array - axis: the axis along which the split should be computed (default: -1) - """ - size = x.shape[-1] - assert size % 2 == 0, "axis size must be divisible by 2" - x1, x2 = torch.split(x, size // 2, -1) - return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) - - -@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_llama_act_combine(dtype: str): - x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() - x_gate_torch = nn.Parameter(x_gate.detach().clone()) - x_gate_kernel = nn.Parameter(x_gate.detach().clone()) - x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() - x_up_torch = nn.Parameter(x_up.detach().clone()) - x_up_kernel = nn.Parameter(x_up.detach().clone()) - - torch_out = SwiGLU(x_gate_torch) * x_up_torch - kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) - atol = 1e-5 if dtype == torch.float32 else 5e-2 - assert torch.allclose(torch_out, kernel_out, atol=atol) - - torch_out.mean().backward() - kernel_out.mean().backward() - assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) - assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) - assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) - - -if __name__ == '__main__': - test_llama_act_combine(torch.float16) diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py deleted file mode 100644 index 95fe50cf1..000000000 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton import llama_context_attn_fwd - from tests.test_infer_ops.triton.kernel_utils import torch_context_attention - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_llama_context_attention(): - bs = 4 - head_num = 8 - seq_len = 1024 - head_dim = 64 - - query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - - max_input_len = seq_len - b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) - b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) - - for i in range(bs): - b_start[i] = i * seq_len - b_len[i] = seq_len - - o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) - - torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose( - torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 - ), "outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_llama_context_attention() diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py deleted file mode 100644 index 9bdec8664..000000000 --- a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py +++ /dev/null @@ -1,143 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F -from packaging import version - -try: - import triton - - from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel - from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_qkv_matmul(): - qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) - scale = 1.2 - head_size = 32 - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model : d_model * 2] - - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - q_copy = q.clone() - k_copy = k.clone() - q = torch.transpose(q, 1, 2).contiguous() - k = torch.transpose(k, 1, 2).contiguous() - k = torch.transpose(k, 2, 3).contiguous() - - torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k) - torch_ouput *= 1.2 - - q, k = q_copy, k_copy - batches, M, H, K = q.shape - N = k.shape[1] - score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - K = q.shape[3] - qkv_gemm_4d_kernel[grid]( - q, - k, - score_output, - M, - N, - K, - q.stride(0), - q.stride(2), - q.stride(1), - q.stride(3), - k.stride(0), - k.stride(2), - k.stride(3), - k.stride(1), - score_output.stride(0), - score_output.stride(1), - score_output.stride(2), - score_output.stride(3), - scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=32, - BLOCK_SIZE_K=32, - GROUP_SIZE_M=8, - ) - - check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) - assert check is True, "the outputs of triton and torch are not matched" - - -def self_attention_compute_using_torch(qkv, input_mask, scale, head_size): - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model : d_model * 2] - v = qkv[:, :, d_model * 2 :] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - q = torch.transpose(q, 1, 2).contiguous() - k = torch.transpose(k, 1, 2).contiguous() - v = torch.transpose(v, 1, 2).contiguous() - - k = torch.transpose(k, -1, -2).contiguous() - - score_output = torch.einsum("bnij,bnjk->bnik", q, k) - score_output *= scale - - softmax_output = F.softmax(score_output, dim=-1) - res = torch.einsum("bnij,bnjk->bnik", softmax_output, v) - res = torch.transpose(res, 1, 2) - res = res.contiguous() - - return res.view(batches, -1, d_model), score_output, softmax_output - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_self_atttention_test(): - qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) - data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( - qkv.clone(), input_mask=None, scale=1.2, head_size=32 - ) - - data_output_triton = self_attention_compute_using_triton( - qkv.clone(), - alibi=None, - head_size=32, - scale=1.2, - input_mask=None, - layer_past=None, - use_flash=False, - triangular=True, - ) - - check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) - assert check is True, "the triton output is not matched with torch output" - - -if __name__ == "__main__": - test_qkv_matmul() - test_self_atttention_test() diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py deleted file mode 100644 index 43b9c0929..000000000 --- a/tests/test_infer_ops/triton/test_softmax.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -try: - from colossalai.kernel.triton.softmax import softmax - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_softmax_op(): - data_samples = [ - torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), - torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), - torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), - ] - - for data in data_samples: - module = nn.Softmax(dim=-1) - data_torch_out = module(data) - data_triton_out = softmax(data) - check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) - assert check is True, "softmax outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_softmax_op() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py deleted file mode 100644 index 4ee1a5fb1..000000000 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -import importlib.util - -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - - logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) - prob = torch.softmax(logics, dim=1) - prob = prob.view(bs, seqlen, num_head, 1) - - return torch.sum(prob * xv, dim=1, keepdim=False) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL, - reason="triton requires cuda version to be higher than 11.4 or not install lightllm", -) -def test(): - Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 - dtype = torch.float16 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") - - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - kv_cache_start_loc[2] = 2 * seq_len - kv_cache_start_loc[3] = 3 * seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py deleted file mode 100644 index 1f97f1674..000000000 --- a/tests/test_infer_ops/triton/test_token_softmax.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_softmax(): - import torch - - batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 - - dtype = torch.float16 - - Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - - token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len) - - torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len) - o = ProbOut - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_softmax() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index ee50e5b61..d0c4cd0a7 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -1,14 +1,19 @@ import pytest from lazy_init_utils import SUPPORT_LAZY, check_lazy_init -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") -@pytest.mark.parametrize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize( + "subset", + [COMMON_MODELS] + if IS_FAST_TEST + else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"], +) @pytest.mark.parametrize("default_device", ["cpu", "cuda"]) def test_torchvision_models_lazy_init(subset, default_device): - sub_model_zoo = model_zoo.get_sub_registry(subset) + sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(