mirror of https://github.com/THUDM/ChatGLM-6B
Update main.py
Handle the corner case where input + target > max_seq_length: the target tokens that do not fit are truncated.
pull/293/head
parent
fdc2c7f70d
commit
67bc05e852
|
@ -185,6 +185,10 @@ def main():
|
|||
labels = [-100] * context_length + input_ids[mask_position+1:]
|
||||
|
||||
pad_len = max_seq_length - len(input_ids)
|
||||
if pad_len < 0:
|
||||
input_ids = input_ids[:pad_len-1]+ [150005]
|
||||
labels = labels[:pad_len-1] + [150005]
|
||||
else:
|
||||
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
|
||||
labels = labels + [tokenizer.pad_token_id] * pad_len
|
||||
|
||||
|
|
Loading…
Reference in New Issue