Update main.py

handle corner case when input+target > max_seq_length. we truncate left target tokens.
pull/293/head
dumpmemory 2023-03-31 15:05:18 +08:00 committed by GitHub
parent fdc2c7f70d
commit 67bc05e852
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 3 deletions

View File

@ -185,8 +185,12 @@ def main():
labels = [-100] * context_length + input_ids[mask_position+1:]
pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
if pad_len < 0:
input_ids = input_ids[:pad_len-1]+ [150005]
labels = labels[:pad_len-1] + [150005]
else:
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
@ -386,4 +390,4 @@ def _mp_fn(index):
if __name__ == "__main__":
main()
main()