diff --git a/ptuning/trainer_seq2seq.py b/ptuning/trainer_seq2seq.py index 0087786..518daa0 100644 --- a/ptuning/trainer_seq2seq.py +++ b/ptuning/trainer_seq2seq.py @@ -185,6 +185,8 @@ class Seq2SeqTrainer(Trainer): if "attention_mask" in inputs: gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) + if "position_ids" in inputs: + gen_kwargs["position_ids"] = inputs.get("position_ids", None) if "global_attention_mask" in inputs: gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)