Add training for chat data

pull/427/head
rainatam 2 years ago
parent a9fc018444
commit 5865924cc6

@ -58,11 +58,14 @@ bash evaluate.sh
### 评估结果 ### 评估结果
| | P-tuning v2 | LoRA | | | P-tuning v2 | LoRA |
| ------- | ----------- | ----- | | ------------- | ----------- | ----- |
| BLEU-4 | 7.71 | 6.13 | | BLEU-4 | 7.78 | 6.13 |
| Rouge-1 | 31.35 | 28.36 | | Rouge-1 | 31.34 | 28.36 |
| Rouge-2 | 7.19 | 4.38 | | Rouge-2 | 7.34 | 4.38 |
| Rouge-l | 25.17 | 17.54 | | Rouge-l | 25.26 | 17.54 |
| Training Loss | 3.8016 | 3.36 |
#### 实验设置 #### 实验设置
@ -98,8 +101,60 @@ learning_rate=5e-4
## 使用自己的数据集 ## 使用自己的数据集
修改 `train.sh``evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column``response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。 修改 `train.sh``evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column``response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。
## 对话数据集
如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如
```json
{
"prompt": "是的。上下水管都好的",
"response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
"history": [
[
"长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
"用电脑能读数据流吗?水温多少"
],
[
"95",
"上下水管温差怎么样啊?空气是不是都排干净了呢?"
]
]
}
```
训练时需要指定 `--history_column` 为数据中聊天历史的 key在此例子中是 `history`),将自动把聊天历史拼接,例如:
- Input
```
[Round 0]
问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线
答:用电脑能读数据流吗?水温多少
[Round 1]
问:95
答:上下水管温差怎么样啊?空气是不是都排干净了呢?
[Round 2]
问:是的。上下水管都好的
```
- Label
```
那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!
```
要注意超过输入长度 `max_source_length` 的内容会被截。
可以参考以下指令:
```shell
bash train_chat.sh
```
## TODO ## TODO
* [ ] Support for chat data * [x] Support for chat data
* [ ] Support for full finetuning * [ ] Support for full finetuning
## 引用 ## 引用

@ -80,6 +80,10 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
) )
history_column: Optional[str] = field(
default=None,
metadata={"help": "The name of the column in the datasets containing the history of chat."},
)
train_file: Optional[str] = field( train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
) )

@ -135,6 +135,7 @@ def main():
# Get the column names for input/target. # Get the column names for input/target.
prompt_column = data_args.prompt_column prompt_column = data_args.prompt_column
response_column = data_args.response_column response_column = data_args.response_column
history_column = data_args.history_column
# Temporarily set max_target_length for training. # Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length max_target_length = data_args.max_target_length
@ -143,7 +144,16 @@ def main():
inputs, targets = [], [] inputs, targets = [], []
for i in range(len(examples[prompt_column])): for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]: if examples[prompt_column][i] and examples[response_column][i]:
inputs.append(examples[prompt_column][i]) query = examples[prompt_column][i]
if history_column is None or len(examples[history_column][i]) == 0:
prompt = query
else:
prompt = ""
history = examples[history_column][i]
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
inputs.append(prompt)
targets.append(examples[response_column][i]) targets.append(examples[response_column][i])
inputs = [prefix + inp for inp in inputs] inputs = [prefix + inp for inp in inputs]
@ -167,7 +177,17 @@ def main():
} }
for i in range(len(examples[prompt_column])): for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]: if examples[prompt_column][i] and examples[response_column][i]:
prompt, answer = examples[prompt_column][i], examples[response_column][i] query, answer = examples[prompt_column][i], examples[response_column][i]
if history_column is None:
prompt = query
else:
prompt = ""
history = examples[history_column][i]
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
prompt = prefix + prompt prompt = prefix + prompt
a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
b_ids = tokenizer.encode(text=answer, add_special_tokens=False) b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
@ -218,6 +238,8 @@ def main():
desc="Running tokenizer on train dataset", desc="Running tokenizer on train dataset",
) )
print_dataset_example(train_dataset[0]) print_dataset_example(train_dataset[0])
print_dataset_example(train_dataset[2])
exit()
if training_args.do_eval: if training_args.do_eval:
max_target_length = data_args.val_max_target_length max_target_length = data_args.val_max_target_length

@ -0,0 +1,27 @@
PRE_SEQ_LEN=8
LR=1e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_train \
--train_file $CHAT_TRAIN_DATA \
--validation_file $CHAT_VAL_DATA \
--prompt_column prompt \
--response_column response \
--history_column history \
--overwrite_cache \
--model_name_or_path THUDM/chatglm-6b \
--output_dir $CHECKPOINT_NAME \
--overwrite_output_dir \
--max_source_length 256 \
--max_target_length 256 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
Loading…
Cancel
Save