diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c3dd476 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +history/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Mac system file +model/ \ No newline at end of file diff --git a/README.md b/README.md index c63913c..5f9689c 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,10 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进 *Read this in [English](README_en.md).* ## 更新信息 +**如果你遇到了任何问题并且是从本地加载模型的,请先尝试从 [HF Repo](https://huggingface.co/THUDM/chatglm-6b) 或 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/) 重新下载模型文件**。 + +**[2023/04/06]** 优化web demo的界面。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。 + **[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。 **[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。 @@ -31,6 +35,7 @@ ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进 * [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。 * [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能 * [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署 +* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。 以下是部分针对本项目的教程/文档: * [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md) diff --git a/README_en.md b/README_en.md index d5c05bb..da2b8dc 100644 --- a/README_en.md +++ b/README_en.md @@ -9,6 +9,8 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dial Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces. ## Update +**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details. + **[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon. **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4). @@ -168,6 +170,8 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal ``` Then you can use GPU-accelerated model inference on Mac. +## Parameter-efficient Tuning +Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it. ## ChatGLM-6B Examples diff --git a/ptuning/README.md b/ptuning/README.md index 4978dc4..753d895 100644 --- a/ptuning/README.md +++ b/ptuning/README.md @@ -3,6 +3,8 @@ 下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。 +*Read this in [English](README_en.md).* + ## 软件依赖 运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要按照以下依赖 ``` @@ -52,18 +54,110 @@ bash evaluate.sh * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 - * Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 * Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 +### 评估结果 + +| | P-tuning v2 | LoRA | +| ------------- | ----------- | ----- | +| BLEU-4 | 7.78 | 6.25 | +| Rouge-1 | 31.34 | 28.58 | +| Rouge-2 | 7.34 | 4.42 | +| Rouge-l | 25.26 | 17.56 | +| Training Loss | 3.80 | 3.36 | + + + +#### 实验设置 + + ``` +max_source_length=64 +max_target_length=64 +per_device_train_batch_size=1 +gradient_accumulation_steps=16 +max_steps=3000 + ``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +``` + +##### LoRA + +``` +learning_rate=5e-4 +``` + +实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + + + ## 模型部署 将对应的demo或代码中的`THUDM/chatglm-6b`换成经过 P-Tuning 微调之后 checkpoint 的地址(在示例中为 `./output/adgen-chatglm-6b-pt-8-1e-2/checkpoint-3000`)。注意,目前的微调还不支持多轮数据,所以只有对话第一轮的回复是经过微调的。 ## 使用自己的数据集 修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。 +## 对话数据集 + +如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如 + +```json +{ + "prompt": "是的。上下水管都好的", + "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", + "history": [ + [ + "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", + "用电脑能读数据流吗?水温多少" + ], + [ + "95", + "上下水管温差怎么样啊?空气是不是都排干净了呢?" + ] + ] +} +``` + +训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接,例如: + +- Input + + ``` + [Round 0] + 问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线 + 答:用电脑能读数据流吗?水温多少 + [Round 1] + 问:95 + 答:上下水管温差怎么样啊?空气是不是都排干净了呢? + [Round 2] + 问:是的。上下水管都好的 + 答: + ``` + +- Label + + ``` + 那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况! + ``` + +要注意超过输入长度 `max_source_length` 的内容会被截。 + +可以参考以下指令: + +```shell +bash train_chat.sh +``` + + + ## TODO -* [ ] Support for chat data +* [x] Support for chat data * [ ] Support for full finetuning ## 引用 @@ -77,3 +171,4 @@ bash evaluate.sh year={2022} } ``` + diff --git a/ptuning/README_en.md b/ptuning/README_en.md new file mode 100644 index 0000000..9282da3 --- /dev/null +++ b/ptuning/README_en.md @@ -0,0 +1,115 @@ +# ChatGLM-6B-PT +This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run. + +The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code. + +## Software dependencies +Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required +``` +pip install rouge_chinese nltk jieba datasets +``` +## Instructions + +### Download the dataset +The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content). + +```json +{ + "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", + "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。" +} +``` + +From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) Download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory. + +### Training +Run the following commands for training: +```shell +bash train.sh +``` +`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision. + +Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation. + +### Inference + +Change `CHECKPOINT` in `evaluate.sh` to the checkpoint name saved during training, and run the following commands for model inference and evaluation: +```shell +bash evaluate.sh +``` + +The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in +`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`. + +### Example +#### Example 1 +* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞 +* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。 +* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。 +* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。 + +#### Example 2 + +* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领 +* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。 +* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。 +* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。 + +### evaluation result + +| | P-tuning v2 | LoRA | +| ------- | ----------- | ----- | +| BLEU-4 | 7.71 | 6.13 | +| Rouge-1 | 31.35 | 28.36 | +| Rouge-2 | 7.19 | 4.38 | +| Rouge-l | 25.17 | 17.54 | + +#### Experiment Settings + + ``` +max_source_length=64 +max_target_length=64 +per_device_train_batch_size=1 +gradient_accumulation_steps=16 +max_steps=3000 + ``` + +##### P-tuning v2 + +``` +pre_seq_len=128 +learning_rate=2e-2 +quantization_bit=4 +``` + +##### LoRA + +``` +learning_rate=5e-4 +``` + +The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b) + + + +## Model Deployment +Replace `THUDM/chatglm-6b` in the corresponding demo or code with the path of the checkpoint after P-Tuning(in the example, `./output/adgen-chatglm-6b-pt-8-1e-2/ checkpoint-3000`). Note that the current fine-tuning does not support multiple rounds of data, so only the responses from the first round of the conversation are fine-tuned. + +## Use your own dataset +Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text. + +## TODO +* [ ] Support for chat data +* [ ] Support for full finetuning + +## quoting + +``` +@inproceedings{liu2022p, + title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, + author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages={61--68}, + year={2022} +} +``` \ No newline at end of file diff --git a/ptuning/arguments.py b/ptuning/arguments.py index 95d766f..f9310da 100644 --- a/ptuning/arguments.py +++ b/ptuning/arguments.py @@ -80,6 +80,10 @@ class DataTrainingArguments: default=None, metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, ) + history_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the history of chat."}, + ) train_file: Optional[str] = field( default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} ) diff --git a/ptuning/evaluate.sh b/ptuning/evaluate.sh index 1217ceb..120a8c2 100644 --- a/ptuning/evaluate.sh +++ b/ptuning/evaluate.sh @@ -1,5 +1,5 @@ -PRE_SEQ_LEN=8 -CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2 +PRE_SEQ_LEN=128 +CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 STEP=3000 CUDA_VISIBLE_DEVICES=0 python3 main.py \ diff --git a/ptuning/main.py b/ptuning/main.py index 112c9ca..e34e95e 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -27,7 +27,7 @@ import numpy as np from datasets import load_dataset import jieba from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction import transformers from transformers import ( @@ -135,6 +135,7 @@ def main(): # Get the column names for input/target. prompt_column = data_args.prompt_column response_column = data_args.response_column + history_column = data_args.history_column # Temporarily set max_target_length for training. max_target_length = data_args.max_target_length @@ -143,7 +144,16 @@ def main(): inputs, targets = [], [] for i in range(len(examples[prompt_column])): if examples[prompt_column][i] and examples[response_column][i]: - inputs.append(examples[prompt_column][i]) + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs.append(prompt) targets.append(examples[response_column][i]) inputs = [prefix + inp for inp in inputs] @@ -157,7 +167,7 @@ def main(): model_inputs["labels"] = labels["input_ids"] return model_inputs - + def preprocess_function_train(examples): max_seq_length = data_args.max_source_length + data_args.max_target_length @@ -167,7 +177,17 @@ def main(): } for i in range(len(examples[prompt_column])): if examples[prompt_column][i] and examples[response_column][i]: - prompt, answer = examples[prompt_column][i], examples[response_column][i] + query, answer = examples[prompt_column][i], examples[response_column][i] + + if history_column is None: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + prompt = prefix + prompt a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) b_ids = tokenizer.encode(text=answer, add_special_tokens=False) @@ -178,9 +198,9 @@ def main(): if len(b_ids) > data_args.max_target_length - 2: b_ids = b_ids[: data_args.max_target_length - 2] - input_ids = a_ids + [150001, 150004] + b_ids + [150005] + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) - context_length = input_ids.index(150004) + context_length = input_ids.index(tokenizer.bos_token_id) mask_position = context_length - 1 labels = [-100] * context_length + input_ids[mask_position+1:] @@ -293,7 +313,7 @@ def main(): for k, v in result.items(): score_dict[k].append(round(v["f"] * 100, 4)) - bleu_score = sentence_bleu([list(label)], list(pred)) + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) score_dict["bleu-4"].append(round(bleu_score * 100, 4)) for k, v in score_dict.items(): diff --git a/ptuning/train.sh b/ptuning/train.sh index 3189829..efc9a16 100644 --- a/ptuning/train.sh +++ b/ptuning/train.sh @@ -1,5 +1,5 @@ -PRE_SEQ_LEN=8 -LR=1e-2 +PRE_SEQ_LEN=128 +LR=2e-2 CUDA_VISIBLE_DEVICES=0 python3 main.py \ --do_train \ diff --git a/ptuning/train_chat.sh b/ptuning/train_chat.sh new file mode 100644 index 0000000..b0f5cdc --- /dev/null +++ b/ptuning/train_chat.sh @@ -0,0 +1,27 @@ +PRE_SEQ_LEN=8 +LR=1e-2 + +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --do_train \ + --train_file $CHAT_TRAIN_DATA \ + --validation_file $CHAT_VAL_DATA \ + --prompt_column prompt \ + --response_column response \ + --history_column history \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir $CHECKPOINT_NAME \ + --overwrite_output_dir \ + --max_source_length 256 \ + --max_target_length 256 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --predict_with_generate \ + --max_steps 3000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --quantization_bit 4 + diff --git a/requirements.txt b/requirements.txt index 00707fe..214d68f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ protobuf>=3.19.5,<3.20.1 transformers==4.27.1 -icetk cpm_kernels torch>=1.10 gradio +mdtex2html +sentencepiece \ No newline at end of file diff --git a/web_demo.py b/web_demo.py index 522a4bd..df7f983 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,45 +1,101 @@ from transformers import AutoModel, AutoTokenizer import gradio as gr +import mdtex2html tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() -MAX_TURNS = 20 -MAX_BOXES = MAX_TURNS * 2 +"""Override Chatbot.postprocess""" -def predict(input, max_length, top_p, temperature, history=None): - if history is None: - history = [] +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y + + +gr.Chatbot.postprocess = postprocess + + +def parse_text(text): + """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + lines = text.split("\n") + lines = [line for line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split('`') + if count % 2 == 1: + lines[i] = f'
'
+ else:
+ lines[i] = f'
'
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", "\`")
+ line = line.replace("<", "<")
+ line = line.replace(">", ">")
+ line = line.replace(" ", " ")
+ line = line.replace("*", "*")
+ line = line.replace("_", "_")
+ line = line.replace("-", "-")
+ line = line.replace(".", ".")
+ line = line.replace("!", "!")
+ line = line.replace("(", "(")
+ line = line.replace(")", ")")
+ line = line.replace("$", "$")
+ lines[i] = "