mirror of https://github.com/THUDM/ChatGLM-6B
parent a1ecafd91f
commit ea682a6f51

@@ -1,5 +1,5 @@
-PRE_SEQ_LEN=8
-CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
+PRE_SEQ_LEN=128
+CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
 STEP=3000
 
 CUDA_VISIBLE_DEVICES=0 python3 main.py \

@@ -167,7 +167,7 @@ def main():
         model_inputs["labels"] = labels["input_ids"]
 
         return model_inputs
 
     def preprocess_function_train(examples):
         max_seq_length = data_args.max_source_length + data_args.max_target_length
 

@@ -198,9 +198,9 @@ def main():
                 if len(b_ids) > data_args.max_target_length - 2:
                     b_ids = b_ids[: data_args.max_target_length - 2]
 
-                input_ids = a_ids + [150001, 150004] + b_ids + [150005]
+                input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
 
-                context_length = input_ids.index(150004)
+                context_length = input_ids.index(tokenizer.bos_token_id)
                 mask_position = context_length - 1
                 labels = [-100] * context_length + input_ids[mask_position+1:]
 
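The substance of the commit is the pair of replaced lines above: hard-coded special-token ids give way to tokenizer APIs. Below is a minimal sketch of the equivalence, assuming the id reading that the removed lines and the new bos_token_id lookup imply (150001 = gMASK, 150004 = BOS, 150005 = EOS); build_inputs_hardcoded and make_labels are hypothetical names for illustration, not functions in main.py.

# Hedged sketch, not the repo's code. The id-to-token mapping below is an
# assumption inferred from the removed lines and the bos_token_id lookup.
GMASK_ID = 150001
BOS_ID = 150004   # what the new code reads back as tokenizer.bos_token_id
EOS_ID = 150005

def build_inputs_hardcoded(a_ids, b_ids):
    # Old left-hand side of the diff: prompt, [gMASK, BOS], response, [EOS].
    # The new code gets the same layout from
    # tokenizer.build_inputs_with_special_tokens(a_ids, b_ids).
    return a_ids + [GMASK_ID, BOS_ID] + b_ids + [EOS_ID]

def make_labels(input_ids):
    # Positions before BOS are masked with -100 so the loss ignores the
    # prompt; from BOS onward the input ids themselves are the targets.
    context_length = input_ids.index(BOS_ID)
    mask_position = context_length - 1
    return [-100] * context_length + input_ids[mask_position + 1:]

# Smoke test with fake prompt/response ids:
ids = build_inputs_hardcoded([11, 12], [21, 22])
assert ids == [11, 12, 150001, 150004, 21, 22, 150005]
assert make_labels(ids) == [-100, -100, -100, 150004, 21, 22, 150005]

Reading the boundary back as tokenizer.bos_token_id instead of the literal 150004 keeps the preprocessing independent of ChatGLM-6B's concrete vocabulary ids.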

@@ -1,5 +1,5 @@
-PRE_SEQ_LEN=8
-LR=1e-2
+PRE_SEQ_LEN=128
+LR=2e-2
 
 CUDA_VISIBLE_DEVICES=0 python3 main.py \
     --do_train \
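Both scripts change the same two knobs: PRE_SEQ_LEN, the P-Tuning v2 prefix length (the number of trainable soft-prompt positions), and LR, the learning rate. The new 128/2e-2 pair also matches the renamed checkpoint adgen-chatglm-6b-pt-128-2e-2 in the first hunk. A rough sketch of what the prefix length controls, assuming the standard P-Tuning v2 setup and ChatGLM-6B's published sizes (28 layers, hidden size 4096; both are assumptions, not read from this diff):

import torch

# Assumed ChatGLM-6B dimensions; not taken from this commit.
num_layers, hidden_size = 28, 4096
pre_seq_len = 128  # the new PRE_SEQ_LEN

# P-Tuning v2 learns one embedding per prefix position and projects it into
# per-layer key/value states, hence the factor 2 * num_layers.
prefix_encoder = torch.nn.Embedding(pre_seq_len, num_layers * 2 * hidden_size)
prefix_states = prefix_encoder(torch.arange(pre_seq_len))
print(prefix_states.shape)  # torch.Size([128, 229376])

# The tuned parameter count scales linearly with PRE_SEQ_LEN, so the jump
# from 8 to 128 makes the trainable prefix 16x larger.
print(sum(p.numel() for p in prefix_encoder.parameters()))  # 29,360,128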