From 5865924cc61b780d4fdc6c8ec499b38ce87f9280 Mon Sep 17 00:00:00 2001
From: rainatam
Date: Thu, 6 Apr 2023 20:21:29 +0800
Subject: [PATCH] Add training for chat data

---
 ptuning/README.md     | 69 ++++++++++++++++++++++++++++++++++++++-----
 ptuning/arguments.py  |  4 +++
 ptuning/main.py       | 24 ++++++++++++++--
 ptuning/train_chat.sh | 27 +++++++++++++++++
 4 files changed, 115 insertions(+), 9 deletions(-)
 create mode 100644 ptuning/train_chat.sh

diff --git a/ptuning/README.md b/ptuning/README.md
index ca1fc73..c9497ad 100644
--- a/ptuning/README.md
+++ b/ptuning/README.md
@@ -57,12 +57,15 @@ bash evaluate.sh
 ### Evaluation results
 
-|         | P-tuning v2 | LoRA  |
-| ------- | ----------- | ----- |
-| BLEU-4  | 7.71        | 6.13  |
-| Rouge-1 | 31.35       | 28.36 |
-| Rouge-2 | 7.19        | 4.38  |
-| Rouge-l | 25.17       | 17.54 |
+|               | P-tuning v2 | LoRA  |
+| ------------- | ----------- | ----- |
+| BLEU-4        | 7.78        | 6.13  |
+| Rouge-1       | 31.34       | 28.36 |
+| Rouge-2       | 7.34        | 4.38  |
+| Rouge-l       | 25.26       | 17.54 |
+| Training Loss | 3.8016      | 3.36  |
+
+
 
 #### Experiment settings
 
@@ -98,8 +101,60 @@ learning_rate=5e-4
 ## Using your own dataset
 
 Set `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format dataset, and set `prompt_column` and `response_column` to the keys that hold the input and output text in the JSON file.
 
+## Chat dataset
+
+To fine-tune the model on multi-turn conversation data, provide the chat history alongside each sample, for example:
+
+```json
+{
+    "prompt": "是的。上下水管都好的",
+    "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
+    "history": [
+        [
+            "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
+            "用电脑能读数据流吗?水温多少"
+        ],
+        [
+            "95",
+            "上下水管温差怎么样啊?空气是不是都排干净了呢?"
+        ]
+    ]
+}
+```
+
+During training, set `--history_column` to the key of the chat history in the data (`history` in this example); the history will then be concatenated into the prompt automatically, for example:
+
+- Input
+
+  ```
+  [Round 0]
+  问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线
+  答:用电脑能读数据流吗?水温多少
+  [Round 1]
+  问:95
+  答:上下水管温差怎么样啊?空气是不是都排干净了呢?
+  [Round 2]
+  问:是的。上下水管都好的
+  ```
+
+- Label
+
+  ```
+  那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!
+  ```
+
+Note that anything beyond the input length `max_source_length` will be truncated.
+
+Training can then be started with:
+
+```shell
+bash train_chat.sh
+```
+
+
 ## TODO
-* [ ] Support for chat data
+* [x] Support for chat data
 * [ ] Support for full finetuning
 
 ## Citation
diff --git a/ptuning/arguments.py b/ptuning/arguments.py
index 95d766f..f9310da 100644
--- a/ptuning/arguments.py
+++ b/ptuning/arguments.py
@@ -80,6 +80,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
     )
+    history_column: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the column in the datasets containing the chat history."},
+    )
     train_file: Optional[str] = field(
         default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
     )
diff --git a/ptuning/main.py b/ptuning/main.py
index fbf3924..a7837ac 100644
--- a/ptuning/main.py
+++ b/ptuning/main.py
@@ -135,6 +135,7 @@ def main():
     # Get the column names for input/target.
     prompt_column = data_args.prompt_column
     response_column = data_args.response_column
+    history_column = data_args.history_column
 
     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
@@ -143,7 +144,16 @@ def main():
         inputs, targets = [], []
         for i in range(len(examples[prompt_column])):
             if examples[prompt_column][i] and examples[response_column][i]:
-                inputs.append(examples[prompt_column][i])
+                query = examples[prompt_column][i]
+                if history_column is None or len(examples[history_column][i]) == 0:
+                    prompt = query
+                else:
+                    prompt = ""
+                    history = examples[history_column][i]
+                    for turn_idx, (old_query, response) in enumerate(history):
+                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
+                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+                inputs.append(prompt)
                 targets.append(examples[response_column][i])
 
         inputs = [prefix + inp for inp in inputs]
@@ -167,7 +177,17 @@ def main():
         }
         for i in range(len(examples[prompt_column])):
             if examples[prompt_column][i] and examples[response_column][i]:
-                prompt, answer = examples[prompt_column][i], examples[response_column][i]
+                query, answer = examples[prompt_column][i], examples[response_column][i]
+
+                if history_column is None or len(examples[history_column][i]) == 0:
+                    prompt = query
+                else:
+                    prompt = ""
+                    history = examples[history_column][i]
+                    for turn_idx, (old_query, response) in enumerate(history):
+                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
+                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+
                 prompt = prefix + prompt
                 a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
                 b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
diff --git a/ptuning/train_chat.sh b/ptuning/train_chat.sh
new file mode 100644
index 0000000..b0f5cdc
--- /dev/null
+++ b/ptuning/train_chat.sh
@@ -0,0 +1,27 @@
+PRE_SEQ_LEN=8
+LR=1e-2
+
+CUDA_VISIBLE_DEVICES=0 python3 main.py \
+    --do_train \
+    --train_file $CHAT_TRAIN_DATA \
+    --validation_file $CHAT_VAL_DATA \
+    --prompt_column prompt \
+    --response_column response \
+    --history_column history \
+    --overwrite_cache \
+    --model_name_or_path THUDM/chatglm-6b \
+    --output_dir $CHECKPOINT_NAME \
+    --overwrite_output_dir \
+    --max_source_length 256 \
+    --max_target_length 256 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 16 \
+    --predict_with_generate \
+    --max_steps 3000 \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate $LR \
+    --pre_seq_len $PRE_SEQ_LEN \
+    --quantization_bit 4
+
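
The `[Round i]` flattening that this patch adds to both `preprocess_function_train` and `preprocess_function_eval` can be checked in isolation. Below is a minimal, self-contained sketch of that logic; `build_prompt` is a helper name introduced here for illustration and does not appear in the patch itself.

```python
from typing import List, Optional, Tuple


def build_prompt(query: str, history: Optional[List[Tuple[str, str]]]) -> str:
    """Flatten a multi-turn chat history into the [Round i] prompt format
    that the patched preprocessing in ptuning/main.py feeds to the model."""
    if history is None or len(history) == 0:
        return query
    prompt = ""
    for turn_idx, (old_query, response) in enumerate(history):
        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
    # The current query opens the final round; its answer becomes the label.
    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
    return prompt


if __name__ == "__main__":
    history = [
        ("长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
         "用电脑能读数据流吗?水温多少"),
        ("95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"),
    ]
    # Prints the "Input" example shown in the README section of this patch.
    print(build_prompt("是的。上下水管都好的", history))
```

Since the flattened history is tokenized as a single source sequence, every round counts toward `max_source_length` (256 in `train_chat.sh`), so long conversations will be truncated.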