From 67bc05e852ac3790741781a439d5aa57f7b6a501 Mon Sep 17 00:00:00 2001 From: dumpmemory <64742282+dumpmemory@users.noreply.github.com> Date: Fri, 31 Mar 2023 15:05:18 +0800 Subject: [PATCH] Update main.py handle corner case when input+target > max_seq_length. we truncate left target tokens. --- ptuning/main.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ptuning/main.py b/ptuning/main.py index d82fccc..abfd9ef 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -185,8 +185,12 @@ def main(): labels = [-100] * context_length + input_ids[mask_position+1:] pad_len = max_seq_length - len(input_ids) - input_ids = input_ids + [tokenizer.pad_token_id] * pad_len - labels = labels + [tokenizer.pad_token_id] * pad_len + if pad_len < 0: + input_ids = input_ids[:pad_len-1]+ [150005] + labels = labels[:pad_len-1] + [150005] + else: + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + labels = labels + [tokenizer.pad_token_id] * pad_len model_inputs["input_ids"].append(input_ids) model_inputs["labels"].append(labels) @@ -386,4 +390,4 @@ def _mp_fn(index): if __name__ == "__main__": - main() \ No newline at end of file + main()