From 01ea68b2e6ba7b0489207b8a8233e040e21f8411 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 12 Jul 2022 23:25:30 +0800 Subject: [PATCH] [tests] remove T5 test skip decorator (#1271) --- tests/test_fx/test_pipeline/test_hf_model/test_t5.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index d78883c3d..ea32b87cf 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -2,12 +2,20 @@ import pytest import transformers import torch from hf_utils import split_model_and_compare_output +from colossalai.fx.tracer.meta_patch import meta_patched_module +try: + import apex + + @meta_patched_module.register(apex.normalization.FusedRMSNorm) + def apex_fused_layernorm(self, input): + return torch.empty(input.shape, device='meta') +except ImportError: + pass BATCH_SIZE = 1 SEQ_LENGHT = 16 -@pytest.mark.skip('tracing failed') def test_t5(): MODEL_LIST = [ transformers.T5Model, @@ -15,7 +23,7 @@ def test_t5(): transformers.T5EncoderModel, ] - config = transformers.T5Config(d_model=128, num_layers=2) + config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)