mirror of https://github.com/THUDM/ChatGLM-6B
Add training for chat data
parent a9fc018444
commit 5865924cc6
@@ -58,11 +58,14 @@ bash evaluate.sh
 ### Evaluation Results

 |               | P-tuning v2 | LoRA  |
-| ------- | ----------- | ----- |
-| BLEU-4  | 7.71  | 6.13  |
-| Rouge-1 | 31.35 | 28.36 |
-| Rouge-2 | 7.19  | 4.38  |
-| Rouge-l | 25.17 | 17.54 |
+| ------------- | ----------- | ----- |
+| BLEU-4        | 7.78        | 6.13  |
+| Rouge-1       | 31.34       | 28.36 |
+| Rouge-2       | 7.34        | 4.38  |
+| Rouge-l       | 25.26       | 17.54 |
+| Training Loss | 3.8016      | 3.36  |

 #### Experiment Setup

@@ -98,8 +101,60 @@ learning_rate=5e-4
 ## Using Your Own Dataset

 Change `train_file`, `validation_file`, and `test_file` in `train.sh` and `evaluate.sh` to the paths of your own JSON-format dataset, and change `prompt_column` and `response_column` to the keys corresponding to the input and output text in the JSON file.

+## Chat Dataset
+
+To fine-tune the model on multi-turn chat data, you can provide the chat history, for example:
+
+```json
+{
+    "prompt": "是的。上下水管都好的",
+    "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
+    "history": [
+        [
+            "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
+            "用电脑能读数据流吗?水温多少"
+        ],
+        [
+            "95",
+            "上下水管温差怎么样啊?空气是不是都排干净了呢?"
+        ]
+    ]
+}
+```
+
+During training, specify `--history_column` as the key of the chat history in the data (`history` in this example); the chat history will then be concatenated automatically, e.g.:
+
+- Input
+
+```
+[Round 0]
+问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线
+答:用电脑能读数据流吗?水温多少
+[Round 1]
+问:95
+答:上下水管温差怎么样啊?空气是不是都排干净了呢?
+[Round 2]
+问:是的。上下水管都好的
+```
+
+- Label
+
+```
+那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!
+```
+
+Note that content exceeding the input length `max_source_length` will be truncated.
+
+You can refer to the following command:
+
+```shell
+bash train_chat.sh
+```
+
 ## TODO
-* [ ] Support for chat data
+* [x] Support for chat data
 * [ ] Support for full finetuning

 ## Citation
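As a supplement to the "Using Your Own Dataset" section above, here is a minimal sketch of producing a compatible file, assuming the jsonlines layout mentioned in main.py's help text ("a jsonlines or csv file"). The keys `content` and `summary` are hypothetical placeholders, not names mandated by the repo:

```python
import json

# Hypothetical single-turn dataset: one JSON object per line (jsonlines).
# "content" and "summary" are placeholder keys; pass whichever keys your
# file actually uses as --prompt_column and --response_column.
samples = [
    {"content": "input text 1", "summary": "output text 1"},
    {"content": "input text 2", "summary": "output text 2"},
]

with open("train.json", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```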
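The history concatenation described in the "Chat Dataset" section can also be reproduced standalone. This is a minimal sketch mirroring the `[Round i]` format that the main.py diff below introduces, not the actual training code:

```python
def build_prompt(query, history):
    # Each past turn becomes "[Round i]\n问:…\n答:…\n", and the current
    # query opens the final round, left unanswered for the model to complete.
    if not history:
        return query  # no history: the raw query is used as the prompt
    prompt = ""
    for turn_idx, (old_query, response) in enumerate(history):
        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
    return prompt

# Abbreviated version of the example record above (one history turn kept):
history = [["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]
print(build_prompt("是的。上下水管都好的", history))
```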
@@ -80,6 +80,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
     )
+    history_column: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the column in the datasets containing the history of chat."},
+    )
     train_file: Optional[str] = field(
         default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
     )
@@ -135,6 +135,7 @@ def main():
     # Get the column names for input/target.
     prompt_column = data_args.prompt_column
     response_column = data_args.response_column
+    history_column = data_args.history_column

     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
@@ -143,7 +144,16 @@ def main():
         inputs, targets = [], []
         for i in range(len(examples[prompt_column])):
             if examples[prompt_column][i] and examples[response_column][i]:
-                inputs.append(examples[prompt_column][i])
+                query = examples[prompt_column][i]
+                if history_column is None or len(examples[history_column][i]) == 0:
+                    prompt = query
+                else:
+                    prompt = ""
+                    history = examples[history_column][i]
+                    # use turn_idx rather than reusing i, which would clobber the
+                    # outer example index read by targets.append below
+                    for turn_idx, (old_query, response) in enumerate(history):
+                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
+                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+                inputs.append(prompt)
                 targets.append(examples[response_column][i])

         inputs = [prefix + inp for inp in inputs]
@@ -167,7 +177,17 @@ def main():
         }
         for i in range(len(examples[prompt_column])):
             if examples[prompt_column][i] and examples[response_column][i]:
-                prompt, answer = examples[prompt_column][i], examples[response_column][i]
+                query, answer = examples[prompt_column][i], examples[response_column][i]
+
+                if history_column is None:
+                    prompt = query
+                else:
+                    prompt = ""
+                    history = examples[history_column][i]
+                    # turn_idx avoids shadowing the outer example index i
+                    for turn_idx, (old_query, response) in enumerate(history):
+                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
+                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+
                 prompt = prefix + prompt
                 a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
                 b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
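The README's truncation note ("content exceeding `max_source_length` will be truncated") corresponds to clipping the encoded ids after `tokenizer.encode`. The clipping itself falls outside the lines shown in this hunk, so the following is a generic sketch of the convention, not code lifted from main.py:

```python
def clip_ids(a_ids, b_ids, max_source_length, max_target_length):
    # Keep at most max_source_length prompt tokens and max_target_length
    # answer tokens; over-long chat history is silently dropped.
    return a_ids[:max_source_length], b_ids[:max_target_length]

# Stand-ins for encoded prompt/answer token ids:
a_ids, b_ids = clip_ids(list(range(300)), list(range(300)), 256, 256)
print(len(a_ids), len(b_ids))  # 256 256, the limits set in train_chat.sh
```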
@@ -218,6 +238,8 @@ def main():
             desc="Running tokenizer on train dataset",
         )
         print_dataset_example(train_dataset[0])
+        print_dataset_example(train_dataset[2])
+        exit()  # debug leftover: stops the run right after printing the samples

     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
@@ -0,0 +1,27 @@
+PRE_SEQ_LEN=8
+LR=1e-2
+
+CUDA_VISIBLE_DEVICES=0 python3 main.py \
+    --do_train \
+    --train_file $CHAT_TRAIN_DATA \
+    --validation_file $CHAT_VAL_DATA \
+    --prompt_column prompt \
+    --response_column response \
+    --history_column history \
+    --overwrite_cache \
+    --model_name_or_path THUDM/chatglm-6b \
+    --output_dir $CHECKPOINT_NAME \
+    --overwrite_output_dir \
+    --max_source_length 256 \
+    --max_target_length 256 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 16 \
+    --predict_with_generate \
+    --max_steps 3000 \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate $LR \
+    --pre_seq_len $PRE_SEQ_LEN \
+    --quantization_bit 4

Note that the new script references `CHAT_TRAIN_DATA`, `CHAT_VAL_DATA`, and `CHECKPOINT_NAME` without defining them; set these environment variables (or edit the paths in) before running `bash train_chat.sh`.