From 55811708902df7d1c3e2557bb7c1117fe1ba07e5 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 7 Jul 2022 16:29:58 +0800 Subject: [PATCH] [fx] fixed huggingface OPT and T5 results misalignment (#1227) --- .../test_tracer/test_hf_model/test_hf_t5.py | 13 ++++++++++++- tests/test_fx/test_tracer/test_hf_model/utils.py | 15 +++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 001ada2db..0bf765174 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -23,9 +23,20 @@ def test_t5(): kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) return kwargs + def data_gen_for_encoder_only(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + kwargs = dict(input_ids=input_ids) + return kwargs + for model_cls in MODEL_LIST: model = model_cls(config=config) - trace_model_and_compare_output(model, data_gen) + + if isinstance(model, transformers.T5EncoderModel): + data_gen_func = data_gen_for_encoder_only + else: + data_gen_func = data_gen + + trace_model_and_compare_output(model, data_gen_func) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/utils.py index 382c87ad5..038548209 100644 --- a/tests/test_fx/test_tracer/test_hf_model/utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/utils.py @@ -6,8 +6,12 @@ from torch.utils._pytree import tree_flatten def trace_model_and_compare_output(model, data_gen): - tracer = ColoTracer() + # must turn on eval mode to ensure the output is consistent + model.eval() + # make sure that the model is traceable + tracer = ColoTracer() + try: kwargs = data_gen() meta_args = {k: v.to('meta') for k, v in kwargs.items()} @@ -17,17 +21,12 @@ def trace_model_and_compare_output(model, data_gen): gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - # check output - inputs = data_gen() - - # must turn on eval mode to ensure the output is consistent - gm.eval() - model.eval() - # run forward + inputs = data_gen() non_fx_out = model(**inputs) fx_out = gm(**inputs) + # check output for k in non_fx_out.keys(): if torch.is_tensor(fx_out[k]): assert torch.equal(