diff --git a/ptuning/main.py b/ptuning/main.py index 1776055..e93b689 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -185,8 +185,8 @@ def main(): labels = [-100] * context_length + input_ids[mask_position+1:] pad_len = max_seq_length - len(input_ids) - input_ids = input_ids + [tokenizer.pad_token_id] * pad_len - labels = labels + [tokenizer.pad_token_id] * pad_len + input_ids = [tokenizer.pad_token_id] * pad_len + input_ids + labels = [tokenizer.pad_token_id] * pad_len + labels model_inputs["input_ids"].append(input_ids) model_inputs["labels"].append(labels)