diff --git a/ptuning/main.py b/ptuning/main.py index d1c4977..1776055 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -147,7 +147,7 @@ def main(): targets.append(examples[response_column][i]) inputs = [prefix + inp for inp in inputs] - model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) if data_args.ignore_pad_token_for_loss: