[fx] fixed huggingface OPT and T5 results misalignment (#1227)

pull/1222/head^2
Frank Lee 2022-07-07 16:29:58 +08:00 committed by GitHub
parent 2b7dca44b5
commit 5581170890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 9 deletions

View File

@@ -23,9 +23,20 @@ def test_t5():
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
return kwargs
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
trace_model_and_compare_output(model, data_gen)
if isinstance(model, transformers.T5EncoderModel):
data_gen_func = data_gen_for_encoder_only
else:
data_gen_func = data_gen
trace_model_and_compare_output(model, data_gen_func)
if __name__ == '__main__':

View File

@@ -6,8 +6,12 @@ from torch.utils._pytree import tree_flatten
def trace_model_and_compare_output(model, data_gen):
tracer = ColoTracer()
# must turn on eval mode to ensure the output is consistent
model.eval()
# make sure that the model is traceable
tracer = ColoTracer()
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
@@ -17,17 +21,12 @@ def trace_model_and_compare_output(model, data_gen):
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# check output
inputs = data_gen()
# must turn on eval mode to ensure the output is consistent
gm.eval()
model.eval()
# run forward
inputs = data_gen()
non_fx_out = model(**inputs)
fx_out = gm(**inputs)
# check output
for k in non_fx_out.keys():
if torch.is_tensor(fx_out[k]):
assert torch.equal(