diff --git a/ptuning/main.py b/ptuning/main.py index d82fccc..abfd9ef 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -185,8 +185,12 @@ def main(): labels = [-100] * context_length + input_ids[mask_position+1:] pad_len = max_seq_length - len(input_ids) - input_ids = input_ids + [tokenizer.pad_token_id] * pad_len - labels = labels + [tokenizer.pad_token_id] * pad_len + if pad_len < 0: + input_ids = input_ids[:pad_len-1]+ [150005] + labels = labels[:pad_len-1] + [150005] + else: + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + labels = labels + [tokenizer.pad_token_id] * pad_len model_inputs["input_ids"].append(input_ids) model_inputs["labels"].append(labels) @@ -386,4 +390,4 @@ def _mp_fn(index): if __name__ == "__main__": - main() \ No newline at end of file + main()