mirror of https://github.com/THUDM/ChatGLM-6B
Fix position_ids in prediction
parent
11b5a54484
commit
c508f62b70
|
@ -185,6 +185,8 @@ class Seq2SeqTrainer(Trainer):
|
|||
|
||||
if "attention_mask" in inputs:
|
||||
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
|
||||
if "position_ids" in inputs:
|
||||
gen_kwargs["position_ids"] = inputs.get("position_ids", None)
|
||||
if "global_attention_mask" in inputs:
|
||||
gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
|
||||
|
||||
|
|
Loading…
Reference in New Issue