mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed huggingface OPT and T5 results misalignment (#1227)
parent 2b7dca44b5
commit 5581170890
@@ -23,9 +23,20 @@ def test_t5():
         kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
         return kwargs
 
+    def data_gen_for_encoder_only():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids)
+        return kwargs
+
     for model_cls in MODEL_LIST:
         model = model_cls(config=config)
-        trace_model_and_compare_output(model, data_gen)
+
+        if isinstance(model, transformers.T5EncoderModel):
+            data_gen_func = data_gen_for_encoder_only
+        else:
+            data_gen_func = data_gen
+
+        trace_model_and_compare_output(model, data_gen_func)
 
 
 if __name__ == '__main__':
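For readers without the full test file at hand, here is a minimal, self-contained sketch of the flow the hunk above establishes: the encoder-only T5 variant gets its own data generator because its forward signature accepts no decoder_input_ids. The constants, the config, and MODEL_LIST below are assumptions for illustration only, and the final forward call stands in for trace_model_and_compare_output (the helper modified in the second file).

import torch
import transformers

# Hypothetical sizes for illustration; the real test file defines its own constants.
BATCH_SIZE = 1
SEQ_LENGHT = 16  # spelling follows the identifier used in the hunk above

# Assumed model list; the actual MODEL_LIST in the test may differ.
MODEL_LIST = [
    transformers.T5Model,
    transformers.T5ForConditionalGeneration,
    transformers.T5EncoderModel,
]

# Deliberately tiny config so the sketch runs quickly.
config = transformers.T5Config(d_model=64, d_ff=128, num_layers=1, num_heads=2)


def data_gen():
    # encoder-decoder models need both input_ids and decoder_input_ids
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


def data_gen_for_encoder_only():
    # T5EncoderModel has no decoder, so passing decoder_input_ids would raise a TypeError
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
    return dict(input_ids=input_ids)


for model_cls in MODEL_LIST:
    model = model_cls(config=config).eval()

    # pick the data generator that matches the model's forward signature
    if isinstance(model, transformers.T5EncoderModel):
        data_gen_func = data_gen_for_encoder_only
    else:
        data_gen_func = data_gen

    # stand-in for trace_model_and_compare_output: just confirm the kwargs are accepted
    outputs = model(**data_gen_func())
    print(model_cls.__name__, type(outputs).__name__)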
@@ -6,8 +6,12 @@ from torch.utils._pytree import tree_flatten
 
 
 def trace_model_and_compare_output(model, data_gen):
-    tracer = ColoTracer()
+    # must turn on eval mode to ensure the output is consistent
+    model.eval()
+
     # make sure that the model is traceable
+    tracer = ColoTracer()
 
     try:
         kwargs = data_gen()
         meta_args = {k: v.to('meta') for k, v in kwargs.items()}
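The hunk above moves model.eval() to the very start of the helper, and the next hunk removes the later gm.eval()/model.eval() pair, so the model is already in eval mode when it is traced and when the outputs are compared. Eval mode matters here because dropout is active in train mode, so two forward passes over identical inputs generally disagree. The snippet below is a standalone illustration of that point (not part of the commit, using a tiny assumed config).

import torch
import transformers

# Tiny assumed config purely for illustration.
config = transformers.T5Config(d_model=64, d_ff=128, num_layers=1, num_heads=2)
model = transformers.T5EncoderModel(config)
inputs = dict(input_ids=torch.zeros((1, 8), dtype=torch.int64))

model.train()  # dropout active: repeated forward passes almost surely differ
out_a = model(**inputs).last_hidden_state
out_b = model(**inputs).last_hidden_state
print(torch.equal(out_a, out_b))  # expected: False

model.eval()   # dropout disabled: the output becomes deterministic
out_a = model(**inputs).last_hidden_state
out_b = model(**inputs).last_hidden_state
print(torch.equal(out_a, out_b))  # expected: True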
@@ -17,17 +21,12 @@ def trace_model_and_compare_output(model, data_gen):
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
-    # check output
-    inputs = data_gen()
-
-    # must turn on eval mode to ensure the output is consistent
-    gm.eval()
-    model.eval()
-
     # run forward
+    inputs = data_gen()
     non_fx_out = model(**inputs)
     fx_out = gm(**inputs)
 
+    # check output
     for k in non_fx_out.keys():
         if torch.is_tensor(fx_out[k]):
             assert torch.equal(
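Putting the two hunks of this helper together, the updated function reads roughly as follows. The imports, the tracer.trace(...) call, and the exception handling are not shown in the hunks, so they are reconstructed from context and from the ColoTracer interface as commonly exposed by colossalai.fx; treat this as an approximate sketch rather than the exact file contents.

import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer  # assumed import path


def trace_model_and_compare_output(model, data_gen):
    # must turn on eval mode to ensure the output is consistent
    model.eval()

    # make sure that the model is traceable
    tracer = ColoTracer()

    try:
        kwargs = data_gen()
        # move the sample tensors to the meta device so tracing runs no real kernels
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)  # assumed call signature
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}: {e}")

    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # run forward
    inputs = data_gen()
    non_fx_out = model(**inputs)
    fx_out = gm(**inputs)

    # check output
    for k in non_fx_out.keys():
        if torch.is_tensor(fx_out[k]):
            assert torch.equal(fx_out[k], non_fx_out[k]), \
                f'{model.__class__.__name__} produces mismatched outputs for key {k}'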