mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed huggingface OPT and T5 results misalignment (#1227)
parent
2b7dca44b5
commit
5581170890
|
@ -23,9 +23,20 @@ def test_t5():
|
|||
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
return kwargs
|
||||
|
||||
def data_gen_for_encoder_only():
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids)
|
||||
return kwargs
|
||||
|
||||
for model_cls in MODEL_LIST:
|
||||
model = model_cls(config=config)
|
||||
trace_model_and_compare_output(model, data_gen)
|
||||
|
||||
if isinstance(model, transformers.T5EncoderModel):
|
||||
data_gen_func = data_gen_for_encoder_only
|
||||
else:
|
||||
data_gen_func = data_gen
|
||||
|
||||
trace_model_and_compare_output(model, data_gen_func)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -6,8 +6,12 @@ from torch.utils._pytree import tree_flatten
|
|||
|
||||
|
||||
def trace_model_and_compare_output(model, data_gen):
|
||||
tracer = ColoTracer()
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
model.eval()
|
||||
|
||||
# make sure that the model is traceable
|
||||
tracer = ColoTracer()
|
||||
|
||||
try:
|
||||
kwargs = data_gen()
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
|
@ -17,17 +21,12 @@ def trace_model_and_compare_output(model, data_gen):
|
|||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
# check output
|
||||
inputs = data_gen()
|
||||
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
gm.eval()
|
||||
model.eval()
|
||||
|
||||
# run forward
|
||||
inputs = data_gen()
|
||||
non_fx_out = model(**inputs)
|
||||
fx_out = gm(**inputs)
|
||||
|
||||
# check output
|
||||
for k in non_fx_out.keys():
|
||||
if torch.is_tensor(fx_out[k]):
|
||||
assert torch.equal(
|
||||
|
|
Loading…
Reference in New Issue