mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed unit tests for torch 1.12 (#1327)
parent d49708ae43
commit b2475d8c5c
@@ -22,7 +22,6 @@ def extract_meta(*args, **kwargs):
         if isinstance(val, MetaDeviceAttribute):
             return 'meta'
         elif isinstance(val, ColoProxy):
-            assert val.meta_data is not None
             return val.meta_data
         return val

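Note on this hunk: with the assertion removed, a ColoProxy whose meta_data is None now flows through extract_meta and is returned as-is instead of raising an AssertionError; presumably some ops traced under torch 1.12 yield proxies without populated meta data.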
@@ -1,4 +1,5 @@
 pytest
 torchvision
 transformers
+timm
 titans
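The timm entry added to the test requirements here is what permits dropping the @pytest.mark.skip('skip as timm is required') markers from the timm tests further down in this commit.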
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_single_sentence_albert():
     MODEL_LIST = [
         transformers.AlbertModel,
@@ -7,7 +7,6 @@ BATCH_SIZE = 2
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
@@ -9,7 +9,6 @@ NUM_EPOCHS = 2
 NUM_CHUNKS = 1


-@pytest.mark.skip("error with pytorch 1.10")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
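For reference, the decorator removed throughout these files is the standard pytest skip marker; while present, it kept each test from running at all and reported the given reason. A minimal illustration (the test name is hypothetical):

import pytest

@pytest.mark.skip("error with pytorch 1.10")    # pytest reports 'skipped' with this reason
def test_example():
    assert True

Deleting the marker simply re-enables the test under the torch 1.12 CI.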
@@ -7,7 +7,6 @@ except:
     from timm_utils import split_model_and_compare_output


-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_without_control_flow():

     MODEL_LIST = [
@@ -28,7 +27,6 @@ def test_timm_models_without_control_flow():
         split_model_and_compare_output(model, data)


-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True

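The re-enabled timm tests pass models from MODEL_LIST through split_model_and_compare_output; below is a minimal sketch of the underlying pattern, with an assumed model name and input shape (neither is specified in this hunk):

import timm
import torch

# Build a timm model and run an eager forward pass; the repo's helper
# additionally splits the traced graph and compares outputs, elided here.
model = timm.create_model('resnet18', pretrained=False)
model.eval()
data = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    out = model(data)
print(out.shape)    # torch.Size([2, 1000])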
@@ -19,7 +19,6 @@ torch.manual_seed(MANUAL_SEED)
 torch.backends.cudnn.deterministic = True


-@pytest.mark.skip('skip as torchvision is required')
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
@@ -34,7 +34,6 @@ def test_single_sentence_albert():
         trace_model_and_compare_output(model, data_gen)


-@pytest.mark.skip("error with pytorch 1.10")
 def test_multi_sentence_albert():
     config = transformers.AlbertConfig(hidden_size=128,
                                        num_hidden_layers=2,
@@ -31,7 +31,6 @@ def test_single_sentence_bert():
         trace_model_and_compare_output(model, data_gen)


-@pytest.mark.skip("error with pytorch 1.10")
 def test_multi_sentence_bert():
     config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
     tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_gpt():
     MODEL_LIST = [
         transformers.GPT2Model,
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_opt():
     MODEL_LIST = [
         transformers.OPTModel,
@@ -7,7 +7,6 @@ BATCH_SIZE = 1
 SEQ_LENGHT = 16


-@pytest.mark.skip("error with pytorch 1.10")
 def test_t5():
     MODEL_LIST = [
         transformers.T5Model,
@@ -40,7 +40,7 @@ def test_embedding():
     _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)

     # test group norm
-    gn = torch.nn.GroupNorm(4, num_channels=2)
+    gn = torch.nn.GroupNorm(4, num_channels=8)
     _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)

     # test batch norm 1d
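This hunk is a genuine bug fix rather than a skip removal: torch.nn.GroupNorm requires num_channels to be divisible by num_groups, so GroupNorm(4, num_channels=2) fails at construction time. A quick check:

import torch

try:
    torch.nn.GroupNorm(4, num_channels=2)       # 2 channels cannot form 4 groups
except ValueError as e:
    print(e)                                    # num_channels must be divisible by num_groups

gn = torch.nn.GroupNorm(4, num_channels=8)      # valid: 8 channels -> 4 groups of 2
print(gn(torch.rand(1, 8, 4, 4)).shape)         # torch.Size([1, 8, 4, 4])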
@@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
         fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_without_control_flow():
     torch.backends.cudnn.deterministic = True

@@ -58,7 +57,6 @@ def test_timm_models_without_control_flow():
         trace_and_compare(model_cls, tracer, data)


-@pytest.mark.skip('skip as timm is required')
 def test_timm_models_with_control_flow():
     torch.backends.cudnn.deterministic = True

@@ -8,7 +8,6 @@ from colossalai.fx import ColoTracer
 from torch.fx import GraphModule


-@pytest.mark.skip('skip as torchvision is required')
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
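These torchvision tests drive ColoTracer over each model and compare traced against eager outputs. ColoTracer's exact API is not shown in this diff, so the sketch below substitutes vanilla torch.fx.symbolic_trace to illustrate the same trace-and-compare idea:

import torch
import torchvision.models as tm
from torch.fx import symbolic_trace

model = tm.resnet18()
model.eval()
gm = symbolic_trace(model)                  # yields a torch.fx.GraphModule
data = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    fx_out = gm(data)
    non_fx_out = model(data)
assert torch.allclose(fx_out, non_fx_out)   # traced module must match eager output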