mirror of https://github.com/THUDM/ChatGLM-6B
parent
a1ecafd91f
commit
ea682a6f51
|
@ -1,5 +1,5 @@
|
||||||
PRE_SEQ_LEN=8
|
PRE_SEQ_LEN=128
|
||||||
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
|
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
|
||||||
STEP=3000
|
STEP=3000
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||||
|
|
|
@ -167,7 +167,7 @@ def main():
|
||||||
model_inputs["labels"] = labels["input_ids"]
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def preprocess_function_train(examples):
|
def preprocess_function_train(examples):
|
||||||
max_seq_length = data_args.max_source_length + data_args.max_target_length
|
max_seq_length = data_args.max_source_length + data_args.max_target_length
|
||||||
|
|
||||||
|
@ -198,9 +198,9 @@ def main():
|
||||||
if len(b_ids) > data_args.max_target_length - 2:
|
if len(b_ids) > data_args.max_target_length - 2:
|
||||||
b_ids = b_ids[: data_args.max_target_length - 2]
|
b_ids = b_ids[: data_args.max_target_length - 2]
|
||||||
|
|
||||||
input_ids = a_ids + [150001, 150004] + b_ids + [150005]
|
input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
|
||||||
|
|
||||||
context_length = input_ids.index(150004)
|
context_length = input_ids.index(tokenizer.bos_token_id)
|
||||||
mask_position = context_length - 1
|
mask_position = context_length - 1
|
||||||
labels = [-100] * context_length + input_ids[mask_position+1:]
|
labels = [-100] * context_length + input_ids[mask_position+1:]
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
PRE_SEQ_LEN=8
|
PRE_SEQ_LEN=128
|
||||||
LR=1e-2
|
LR=2e-2
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
CUDA_VISIBLE_DEVICES=0 python3 main.py \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
|
Loading…
Reference in New Issue