From b2475d8c5ce1ecfe50c70ba5dc40fb4287d16939 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 15 Jul 2022 18:22:15 +0800
Subject: [PATCH] [fx] fixed unit tests for torch 1.12 (#1327)

---
 colossalai/fx/tracer/_tracer_utils.py | 1 -
 requirements/requirements-test.txt | 1 +
 tests/test_fx/test_pipeline/test_hf_model/test_albert.py | 1 -
 tests/test_fx/test_pipeline/test_hf_model/test_bert.py | 1 -
 tests/test_fx/test_pipeline/test_hf_model/test_gpt.py | 1 -
 tests/test_fx/test_pipeline/test_hf_model/test_opt.py | 1 -
 tests/test_fx/test_pipeline/test_hf_model/test_t5.py | 1 -
 tests/test_fx/test_pipeline/test_timm_model/test_timm.py | 2 --
 .../test_fx/test_pipeline/test_torchvision/test_torchvision.py | 1 -
 tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py | 1 -
 tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py | 1 -
 tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py | 1 -
 tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py | 1 -
 tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py | 1 -
 tests/test_fx/test_tracer/test_patched_module.py | 2 +-
 tests/test_fx/test_tracer/test_timm_model/test_timm_model.py | 2 --
 .../test_torchvision_model/test_torchvision_model.py | 1 -
 17 files changed, 2 insertions(+), 18 deletions(-)

diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index 528c4a8e9..c1d21e67e 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -22,7 +22,6 @@ def extract_meta(*args, **kwargs):
         if isinstance(val, MetaDeviceAttribute):
             return 'meta'
         elif isinstance(val, ColoProxy):
-            assert val.meta_data is not None
             return val.meta_data
         return val
 
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 03101d69f..221c82ef7 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -1,4 +1,5 @@
 pytest
 torchvision
 transformers
+timm
 titans
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py
index 0bdc9a1aa..08d20c894 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_single_sentence_albert():
     MODEL_LIST = [
         transformers.AlbertModel,
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py
index c7af6e4d0..a3699b660 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
index 6b982dda4..b973ac854 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
@@ -9,7 +9,6 @@ NUM_EPOCHS = 2
 NUM_CHUNKS = 1
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
index 00c16d201..a55ea54fe 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py
index f24dd705c..d20d18842 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
index bf11cb30a..da3843a27 100644
--- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
@@ -7,7 +7,6 @@ except:
     from timm_utils import split_model_and_compare_output
 
 
-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_without_control_flow():
 
     MODEL_LIST = [
@@ -28,7 +27,6 @@ def test_timm_models_without_control_flow():
         split_model_and_compare_output(model, data)
 
 
-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_with_control_flow():
 
     torch.backends.cudnn.deterministic = True
diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
index e52889e3b..c03121063 100644
--- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
+++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
@@ -19,7 +19,6 @@ torch.manual_seed(MANUAL_SEED)
 torch.backends.cudnn.deterministic = True
 
 
-@pytest.mark.skip('skip as torchvision is required')
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
index 2b01eabd3..cf809e13a 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -34,7 +34,6 @@ def test_single_sentence_albert():
         trace_model_and_compare_output(model, data_gen)
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_multi_sentence_albert():
     config = transformers.AlbertConfig(hidden_size=128,
                                        num_hidden_layers=2,
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index e60e4aa7c..63ad4badc 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -31,7 +31,6 @@ def test_single_sentence_bert():
         trace_model_and_compare_output(model, data_gen)
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_multi_sentence_bert():
     config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
     tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index 9c8971a75..1c20e9bfd 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 0075d1f2b..5ac051887 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 4e2614056..645951de9 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip("error with pytorch 1.10")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py
index d7ceba1a5..9b4f7c516 100644
--- a/tests/test_fx/test_tracer/test_patched_module.py
+++ b/tests/test_fx/test_tracer/test_patched_module.py
@@ -40,7 +40,7 @@ def test_embedding():
     _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)
 
     # test group norm
-    gn = torch.nn.GroupNorm(4, num_channels=2)
+    gn = torch.nn.GroupNorm(4, num_channels=8)
     _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)
 
     # test batch norm 1d
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
index 7df0b2e6c..5e2c40cac 100644
--- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
         fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
 
 
-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_without_control_flow():
 
     torch.backends.cudnn.deterministic = True
@@ -58,7 +57,6 @@ def test_timm_models_without_control_flow():
         trace_and_compare(model_cls, tracer, data)
 
 
-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_with_control_flow():
 
     torch.backends.cudnn.deterministic = True
diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
index 11c3d7ea5..7360bd885 100644
--- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
+++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
@@ -8,7 +8,6 @@ from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 
 
-@pytest.mark.skip('skip as torchvision is required')
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
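
Note on the test_patched_module.py hunk (illustrative, not part of the patch): torch.nn.GroupNorm requires num_channels to be divisible by num_groups, so GroupNorm(4, num_channels=2) is rejected when the module is constructed, while GroupNorm(4, num_channels=8) is valid. A minimal sketch of the constraint, using an arbitrary input shape rather than the one from the test:

    import torch

    # GroupNorm(num_groups, num_channels): num_channels must be divisible by num_groups.
    try:
        torch.nn.GroupNorm(4, num_channels=2)    # 2 channels cannot be split into 4 groups
    except ValueError as err:
        print(f"rejected at construction: {err}")

    gn = torch.nn.GroupNorm(4, num_channels=8)   # 8 channels -> 4 groups of 2 channels each
    out = gn(torch.randn(2, 8, 16))              # (batch, channels, length), chosen for illustration
    print(out.shape)                             # torch.Size([2, 8, 16])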
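
Note on the timm/torchvision skips (illustrative, not part of the patch): the unconditional @pytest.mark.skip markers are dropped because timm is now listed in requirements-test.txt (torchvision already was). In an environment where such an optional dependency might still be absent, one alternative is pytest.importorskip, which skips only when the import actually fails; the test below is a hypothetical sketch, not code from this repository:

    import pytest
    import torch

    # Skip this module's tests only when timm is genuinely unavailable,
    # instead of skipping them unconditionally.
    timm = pytest.importorskip("timm")


    def test_timm_resnet18_forward():
        # Hypothetical smoke test: build a small timm model and run one forward pass.
        model = timm.create_model("resnet18", pretrained=False)
        data = torch.rand(2, 3, 224, 224)
        out = model(data)
        assert out.shape == (2, 1000)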
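
Note on the removed @pytest.mark.skip("error with pytorch 1.10") markers (illustrative, not part of the patch): these Hugging Face model tests are re-enabled, presumably because the issues noted for PyTorch 1.10 no longer apply on the torch 1.12 toolchain this commit targets. If the suite still had to run on older torch builds, a version-gated skipif would be one option; the gate and test name below are hypothetical:

    import pytest
    import torch
    import transformers
    from packaging import version    # assumes the 'packaging' distribution is installed

    # Run the test only on torch >= 1.12 rather than skipping it unconditionally.
    TORCH_BELOW_1_12 = version.parse(torch.__version__) < version.parse("1.12.0")


    @pytest.mark.skipif(TORCH_BELOW_1_12, reason="fx tracing of these HF models needs torch >= 1.12")
    def test_gpt2_builds_for_tracing():
        # Tiny config keeps the test cheap; the real tests trace the model and compare outputs.
        config = transformers.GPT2Config(n_embd=64, n_layer=2, n_head=4)
        model = transformers.GPT2Model(config)
        assert sum(p.numel() for p in model.parameters()) > 0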