[fx] fixed unit tests for torch 1.12 (#1327)

pull/1329/head
Frank Lee 2022-07-15 18:22:15 +08:00 committed by GitHub
parent d49708ae43
commit b2475d8c5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 2 additions and 18 deletions

View File

@ -22,7 +22,6 @@ def extract_meta(*args, **kwargs):
if isinstance(val, MetaDeviceAttribute): if isinstance(val, MetaDeviceAttribute):
return 'meta' return 'meta'
elif isinstance(val, ColoProxy): elif isinstance(val, ColoProxy):
assert val.meta_data is not None
return val.meta_data return val.meta_data
return val return val

View File

@ -1,4 +1,5 @@
pytest pytest
torchvision torchvision
transformers transformers
timm
titans titans

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 2
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_single_sentence_albert(): def test_single_sentence_albert():
MODEL_LIST = [ MODEL_LIST = [
transformers.AlbertModel, transformers.AlbertModel,

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 2
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_single_sentence_bert(): def test_single_sentence_bert():
MODEL_LIST = [ MODEL_LIST = [
transformers.BertModel, transformers.BertModel,

View File

@ -9,7 +9,6 @@ NUM_EPOCHS = 2
NUM_CHUNKS = 1 NUM_CHUNKS = 1
@pytest.mark.skip("error with pytorch 1.10")
def test_gpt(): def test_gpt():
MODEL_LIST = [ MODEL_LIST = [
transformers.GPT2Model, transformers.GPT2Model,

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_opt(): def test_opt():
MODEL_LIST = [ MODEL_LIST = [
transformers.OPTModel, transformers.OPTModel,

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_t5(): def test_t5():
MODEL_LIST = [ MODEL_LIST = [
transformers.T5Model, transformers.T5Model,

View File

@ -7,7 +7,6 @@ except:
from timm_utils import split_model_and_compare_output from timm_utils import split_model_and_compare_output
@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow(): def test_timm_models_without_control_flow():
MODEL_LIST = [ MODEL_LIST = [
@ -28,7 +27,6 @@ def test_timm_models_without_control_flow():
split_model_and_compare_output(model, data) split_model_and_compare_output(model, data)
@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow(): def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True

View File

@ -19,7 +19,6 @@ torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,

View File

@ -34,7 +34,6 @@ def test_single_sentence_albert():
trace_model_and_compare_output(model, data_gen) trace_model_and_compare_output(model, data_gen)
@pytest.mark.skip("error with pytorch 1.10")
def test_multi_sentence_albert(): def test_multi_sentence_albert():
config = transformers.AlbertConfig(hidden_size=128, config = transformers.AlbertConfig(hidden_size=128,
num_hidden_layers=2, num_hidden_layers=2,

View File

@ -31,7 +31,6 @@ def test_single_sentence_bert():
trace_model_and_compare_output(model, data_gen) trace_model_and_compare_output(model, data_gen)
@pytest.mark.skip("error with pytorch 1.10")
def test_multi_sentence_bert(): def test_multi_sentence_bert():
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_gpt(): def test_gpt():
MODEL_LIST = [ MODEL_LIST = [
transformers.GPT2Model, transformers.GPT2Model,

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_opt(): def test_opt():
MODEL_LIST = [ MODEL_LIST = [
transformers.OPTModel, transformers.OPTModel,

View File

@ -7,7 +7,6 @@ BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip("error with pytorch 1.10")
def test_t5(): def test_t5():
MODEL_LIST = [ MODEL_LIST = [
transformers.T5Model, transformers.T5Model,

View File

@ -40,7 +40,7 @@ def test_embedding():
_assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape) _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)
# test group norm # test group norm
gn = torch.nn.GroupNorm(4, num_channels=2) gn = torch.nn.GroupNorm(4, num_channels=8)
_assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape) _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)
# test batch norm 1d # test batch norm 1d

View File

@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow(): def test_timm_models_without_control_flow():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
@ -58,7 +57,6 @@ def test_timm_models_without_control_flow():
trace_and_compare(model_cls, tracer, data) trace_and_compare(model_cls, tracer, data)
@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow(): def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True

View File

@ -8,7 +8,6 @@ from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,