Fix position_ids in prediction

pull/350/head
duzx16 2023-04-02 01:59:07 +08:00
parent 11b5a54484
commit c508f62b70
1 changed files with 2 additions and 0 deletions

View File

@ -185,6 +185,8 @@ class Seq2SeqTrainer(Trainer):
if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
if "position_ids" in inputs:
gen_kwargs["position_ids"] = inputs.get("position_ids", None)
if "global_attention_mask" in inputs:
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)