Fix position_ids in prediction

pull/350/head
duzx16 2 years ago
parent 11b5a54484
commit c508f62b70

@ -185,6 +185,8 @@ class Seq2SeqTrainer(Trainer):
if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
if "position_ids" in inputs:
gen_kwargs["position_ids"] = inputs.get("position_ids", None)
if "global_attention_mask" in inputs:
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

Loading…
Cancel
Save