diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
index abe18ebfa..7f14d04c0 100644
--- a/tests/kit/model_zoo/__init__.py
+++ b/tests/kit/model_zoo/__init__.py
@@ -1,4 +1,4 @@
-from . import diffusers, timm, torchvision
+from . import diffusers, timm, torchvision, transformers
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
new file mode 100644
index 000000000..f56ff7ad8
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -0,0 +1,5 @@
+from .albert import *
+from .bert import *
+from .gpt import *
+from .opt import *
+from .t5 import *
diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py
new file mode 100644
index 000000000..e85f564e3
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/albert.py
@@ -0,0 +1,88 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence ALBERT
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen_fn():
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.AlbertConfig(embedding_size=128,
+                                   hidden_size=128,
+                                   num_hidden_layers=2,
+                                   num_attention_heads=4,
+                                   intermediate_size=256)
+
+model_zoo.register(name='transformers_albert',
+                   model_fn=lambda: transformers.AlbertModel(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_pretraining',
+                   model_fn=lambda: transformers.AlbertForPreTraining(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_masked_lm',
+                   model_fn=lambda: transformers.AlbertForMaskedLM(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_sequence_classification',
+                   model_fn=lambda: transformers.AlbertForSequenceClassification(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_token_classification',
+                   model_fn=lambda: transformers.AlbertForTokenClassification(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+# ===============================
+# Register multi-sentence ALBERT
+# ===============================
+
+
+def data_gen_for_qa():
+    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+    inputs = tokenizer(question, text, return_tensors="pt")
+    return inputs
+
+
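+# Note: the helpers here reuse the BERT tokenizer checkpoint, as the original tests
+# did. In data_gen_for_mcq, unsqueeze(0) turns each (num_choices, seq_len) tensor
+# into (batch_size=1, num_choices, seq_len), the layout AlbertForMultipleChoice expects.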
+def data_gen_for_mcq():
+    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+    choice0 = "It is eaten with a fork and a knife."
+    choice1 = "It is eaten while held in the hand."
+    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+    encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+    return encoding
+
+
+model_zoo.register(name='transformers_albert_for_question_answering',
+                   model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_qa,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_multiple_choice',
+                   model_fn=lambda: transformers.AlbertForMultipleChoice(config),
+                   data_gen_fn=data_gen_for_mcq,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
new file mode 100644
index 000000000..99135704d
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -0,0 +1,90 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence BERT
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen_fn():
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
+
+# register the BERT variants
+model_zoo.register(name='transformers_bert',
+                   model_fn=lambda: transformers.BertModel(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_pretraining',
+                   model_fn=lambda: transformers.BertForPreTraining(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_lm_head_model',
+                   model_fn=lambda: transformers.BertLMHeadModel(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_masked_lm',
+                   model_fn=lambda: transformers.BertForMaskedLM(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_sequence_classification',
+                   model_fn=lambda: transformers.BertForSequenceClassification(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_token_classification',
+                   model_fn=lambda: transformers.BertForTokenClassification(config),
+                   data_gen_fn=data_gen_fn,
+                   output_transform_fn=output_transform_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+
+# ===============================
+# Register multi-sentence BERT
+# ===============================
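+# Note: these generators call BertTokenizer.from_pretrained, so running them
+# needs the "bert-base-uncased" checkpoint available (downloaded or cached).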
"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + next_sentence = "The sky is blue due to the shorter wavelength of blue light." + encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + return encoding + + +def data_gen_for_mcq(): + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + +# register the following models +model_zoo.register(name='transformers_bert_for_next_sentence', + model_fn=lambda: transformers.BertForNextSentencePrediction(config), + data_gen_fn=data_gen_for_next_sentence, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_mcq', + model_fn=lambda: transformers.BertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py new file mode 100644 index 000000000..a92a46e36 --- /dev/null +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -0,0 +1,49 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence GPT +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) + +# register the following models +model_zoo.register(name='transformers_gpt', + model_fn=lambda: transformers.GPT2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_lm', + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_double_heads', + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_token_classification', + model_fn=lambda: transformers.GPT2ForTokenClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_sequence_classification', + model_fn=lambda: transformers.GPT2ForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/opt.py 
b/tests/kit/model_zoo/transformers/opt.py new file mode 100644 index 000000000..d9c4a0b3c --- /dev/null +++ b/tests/kit/model_zoo/transformers/opt.py @@ -0,0 +1,35 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence OPT +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +output_transform_fn = lambda x: x + +config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) + +# register the following models +# transformers.OPTModel, +# transformers.OPTForCausalLM, +model_zoo.register(name='transformers_opt', + model_fn=lambda: transformers.OPTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_causal_lm', + model_fn=lambda: transformers.OPTForCausalLM(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py new file mode 100644 index 000000000..b81bcad90 --- /dev/null +++ b/tests/kit/model_zoo/transformers/t5.py @@ -0,0 +1,46 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence T5 +# =============================== +BATCH_SIZE = 2 +SEQ_LENGTH = 16 + + +def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + +def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + return dict(input_ids=input_ids) + + +output_transform_fn = lambda x: x + +config = transformers.T5Config(d_model=128, num_layers=2) + +# register the following models +# transformers.T5Model, +# transformers.T5ForConditionalGeneration, +# transformers.T5EncoderModel, +model_zoo.register(name='transformers_t5', + model_fn=lambda: transformers.T5Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_for_conditional_generation', + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_encoder_model', + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 9c36b0c9c..b1c9c211a 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -1,66 +1,18 @@ -import pytest -import torch -import transformers from hf_tracer_utils import trace_model_and_compare_output +from tests.kit.model_zoo 
import model_zoo + BATCH_SIZE = 2 SEQ_LENGTH = 16 -def test_single_sentence_albert(): - MODEL_LIST = [ - transformers.AlbertModel, - transformers.AlbertForPreTraining, - transformers.AlbertForMaskedLM, - transformers.AlbertForSequenceClassification, - transformers.AlbertForTokenClassification, - ] +def test_albert(): + sub_registry = model_zoo.get_sub_registry('transformers_albert') - config = transformers.AlbertConfig(embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) - - def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - return meta_args - - for model_cls in MODEL_LIST: - model = model_cls(config=config) - trace_model_and_compare_output(model, data_gen) - - -def test_multi_sentence_albert(): - config = transformers.AlbertConfig(hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - - def data_gen_for_qa(): - question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - inputs = tokenizer(question, text, return_tensors="pt") - return inputs - - model = transformers.AlbertForQuestionAnswering(config) - trace_model_and_compare_output(model, data_gen_for_qa) - - def data_gen_for_mcq(): - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - choice0 = "It is eaten with a fork and a knife." - choice1 = "It is eaten while held in the hand." 
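+    # Each registry entry presumably unpacks as (model_fn, data_gen_fn,
+    # output_transform_fn, model_attribute); only the first two are used here.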
-        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
-        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
-        return encoding
-
-    model = transformers.AlbertForMultipleChoice(config)
-    trace_model_and_compare_output(model, data_gen_for_mcq)
+    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+        model = model_fn()
+        trace_model_and_compare_output(model, data_gen_fn)
 
 
 if __name__ == '__main__':
-    test_single_sentence_albert()
-    test_multi_sentence_albert()
+    test_albert()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index 62273e2d5..1bf4947c3 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,69 +1,17 @@
-import pytest
-import torch
-import transformers
 from hf_tracer_utils import trace_model_and_compare_output
 
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
 
 
-def test_single_sentence_bert():
-    MODEL_LIST = [
-        transformers.BertModel,
-        transformers.BertForPreTraining,
-        transformers.BertLMHeadModel,
-        transformers.BertForMaskedLM,
-        transformers.BertForSequenceClassification,
-        transformers.BertForTokenClassification,
-    ]
+def test_bert():
+    sub_registry = model_zoo.get_sub_registry('transformers_bert')
 
-    config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
-
-    def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        return meta_args
-
-    for model_cls in MODEL_LIST:
-        model = model_cls(config=config)
-        trace_model_and_compare_output(model, data_gen)
-
-
-def test_multi_sentence_bert():
-    config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
-    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
-
-    def data_gen_for_next_sentence():
-        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
-        next_sentence = "The sky is blue due to the shorter wavelength of blue light."
-        encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
-        return encoding
-
-    model = transformers.BertForNextSentencePrediction(config)
-    trace_model_and_compare_output(model, data_gen_for_next_sentence)
-
-    def data_gen_for_qa():
-        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-        inputs = tokenizer(question, text, return_tensors="pt")
-        return inputs
-
-    model = transformers.BertForQuestionAnswering(config)
-    trace_model_and_compare_output(model, data_gen_for_qa)
-
-    def data_gen_for_mcq():
-        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
-        choice0 = "It is eaten with a fork and a knife."
-        choice1 = "It is eaten while held in the hand."
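+    # get_sub_registry appears to match by name prefix, so this covers every
+    # registered 'transformers_bert*' variant, multi-sentence models included.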
-        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
-        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
-        return encoding
-
-    model = transformers.BertForMultipleChoice(config)
-    trace_model_and_compare_output(model, data_gen_for_mcq)
+    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+        model = model_fn()
+        trace_model_and_compare_output(model, data_gen_fn)
 
 
 if __name__ == '__main__':
-    test_single_sentence_bert()
-    test_multi_sentence_bert()
+    test_bert()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index ad4c9684d..67a3178fa 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -1,35 +1,17 @@
 import pytest
-import torch
-import transformers
 from hf_tracer_utils import trace_model_and_compare_output
 
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
 
 
 # TODO: remove this skip once we handle the latest gpt model
 @pytest.mark.skip
 def test_gpt():
-    MODEL_LIST = [
-        transformers.GPT2Model,
-        transformers.GPT2LMHeadModel,
-        transformers.GPT2DoubleHeadsModel,
-        transformers.GPT2ForTokenClassification,
-        # transformers.GPT2ForSequenceClassification,    # not supported yet
-    ]
+    sub_registry = model_zoo.get_sub_registry('transformers_gpt')
 
-    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
-
-    def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        return kwargs
-
-    for model_cls in MODEL_LIST:
-        model = model_cls(config=config)
-        trace_model_and_compare_output(model, data_gen)
+    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+        model = model_fn()
+        trace_model_and_compare_output(model, data_gen_fn)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 06260176e..740f5a9f0 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -1,29 +1,14 @@
-import pytest
-import torch
-import transformers
 from hf_tracer_utils import trace_model_and_compare_output
 
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
 
 
 def test_opt():
-    MODEL_LIST = [
-        transformers.OPTModel,
-        transformers.OPTForCausalLM,
-    ]
+    sub_registry = model_zoo.get_sub_registry('transformers_opt')
 
-    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
-
-    def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
-        return kwargs
-
-    for model_cls in MODEL_LIST:
-        model = model_cls(config=config)
-        trace_model_and_compare_output(model, data_gen)
+    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+        model = model_fn()
+        trace_model_and_compare_output(model, data_gen_fn)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 71e782fdd..7073fd634 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,41 +1,14 @@
-import pytest
-import torch
-import transformers
 from hf_tracer_utils import trace_model_and_compare_output
 
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
 
 
def test_t5():
-    MODEL_LIST = [
-        transformers.T5Model,
-        transformers.T5ForConditionalGeneration,
-        transformers.T5EncoderModel,
-    ]
+    sub_registry = model_zoo.get_sub_registry('transformers_t5')
 
-    config = transformers.T5Config(d_model=128, num_layers=2)
-
-    def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        return kwargs
-
-    def data_gen_for_encoder_only():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids)
-        return kwargs
-
-    for model_cls in MODEL_LIST:
-        model = model_cls(config=config)
-
-        if isinstance(model, transformers.T5EncoderModel):
-            data_gen_func = data_gen_for_encoder_only
-        else:
-            data_gen_func = data_gen
-
-        trace_model_and_compare_output(model, data_gen_func)
+    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+        model = model_fn()
+        trace_model_and_compare_output(model, data_gen_fn)
 
 
 if __name__ == '__main__':