Merge commit '58bfd476e6c66625dc607ee4bcfd478d0a051a87' into develop

* commit '58bfd476e6c66625dc607ee4bcfd478d0a051a87':
  Update README.md
  Update web_demo.py
  Update typewriter-effect examples
  Update README
  Update README
  Add newline in cli output
  Update README
  Add support for streaming output

# Conflicts:
#	cli_demo.py
pull/151/head
kingzeus 2023-03-20 18:34:08 +08:00
commit 23bfb8c139
5 changed files with 90 additions and 90 deletions

README.md

@@ -9,7 +9,12 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue
*Read this in [English](README_en.md).*
## Hardware Requirements
## Updates
**[2023/03/19]** Add the streaming output interface `stream_chat`, applied to both the web and CLI demos. Fix Chinese punctuation in the output. Add the quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
## Usage
### Hardware Requirements
| **Quantization Level** | **Minimum GPU Memory** |
| -------------- | ----------------- |
@@ -17,8 +22,6 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue
| INT8 | 10 GB |
| INT4 | 6 GB |
## Usage
### Environment Setup
Install the dependencies with pip: `pip install -r requirements.txt`. The recommended `transformers` version is `4.26.1`, but in principle any version no lower than `4.23.1` should work.
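As a quick sanity check, the snippet below (not part of the repository; it assumes the `packaging` package, which `transformers` already depends on) verifies that the installed `transformers` meets the stated minimum:
```python
# Sketch: confirm the installed transformers version satisfies the minimum above.
import transformers
from packaging import version

installed = version.parse(transformers.__version__)
assert installed >= version.parse("4.23.1"), f"transformers {installed} is older than 4.23.1"
print(f"transformers {installed} found (4.26.1 is the recommended version)")
```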
@@ -31,6 +34,7 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
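The streaming interface `stream_chat` added in this update can be used in place of `chat` above. A minimal sketch, assuming the same `tokenizer` and `model` objects and the call pattern visible in the `cli_demo.py` and `web_demo.py` changes later in this diff (each iteration yields the response generated so far together with the updated history):
```python
# Sketch: stream the reply instead of waiting for model.chat() to return the full text.
history, printed = [], ""
for response, history in model.stream_chat(tokenizer, "你好", history=history):
    # `response` holds everything generated so far; print only the new part.
    print(response[len(printed):], end="", flush=True)
    printed = response
print()
```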
@@ -60,7 +64,7 @@ cd ChatGLM-6B
#### Web Demo
![web-demo](resources/web-demo.png)
![web-demo](resources/web-demo.gif)
First install Gradio with `pip install gradio`, then run [web_demo.py](web_demo.py) from the repository:
@@ -68,9 +72,9 @@ cd ChatGLM-6B
python web_demo.py
```
The program runs a web server and prints its address. Open the address in a browser to use the demo.
The program runs a web server and prints its address. Open the address in a browser to use the demo. The latest demo adds a typewriter effect, which makes the experience feel much faster. Note that because network access to Gradio is slow in mainland China, launching with `demo.queue().launch(share=True, inbrowser=True)` relays all traffic through the Gradio server and badly degrades the typewriter effect, so the default launch has been changed to `share=False`. If you need public network access, change it back to `share=True`.
Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a Streamlit-based web demo; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for how to run it.
#### CLI Demo
@@ -97,18 +101,21 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).ha
Model quantization causes some loss of performance, but in our tests ChatGLM-6B can still generate naturally and fluently under 4-bit quantization. Quantization schemes such as [GPT-Q](https://arxiv.org/abs/2210.17323) can further compress the quantization precision, or improve model performance at the same precision; corresponding Pull Requests are welcome.
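For reference, the line truncated in the hunk header above loads the model with on-the-fly quantization. A sketch of that pattern, assuming the `quantize()` helper exposed by the model's remote code (4-bit or 8-bit):
```python
# Sketch: load the FP16 weights first, quantize to 4 bits, then move the model to the GPU.
from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()
model = model.eval()
```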
**[2023/03/19]** Quantization first loads the FP16 model into memory, which takes about 13 GB of RAM. If you do not have that much memory, you can load the quantized model directly instead, which needs only about 5.2 GB:
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
```
### CPU Deployment
If you do not have GPU hardware, you can also run inference on the CPU. Usage:
If you do not have GPU hardware, you can also run inference on the CPU, but it will be slower. Usage (requires about 32 GB of RAM):
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
```
Inference on the CPU may be relatively slow.
The method above requires 32 GB of RAM. If you only have 16 GB, you can try:
**[2023/03/19]** If you do not have enough memory, you can load the quantized model directly:
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).bfloat16()
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
```
You need close to 16 GB of free memory, and inference will be very slow.
If you hit the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin` (macOS), please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).

README_en.md

@@ -6,16 +6,19 @@ ChatGLM-6B is an open bilingual language model based on [General Language Model
ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained on about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
## Hardware Requirements
## Update
**[2023/03/19]** Add the streaming output interface `stream_chat`, already applied in the web and CLI demos. Fix Chinese punctuation in the output. Add the quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
## Getting Started
### Hardware Requirements
| **Quantization Level** | **GPU Memory** |
| ---------------------------- | -------------------- |
|------------------------|----------------|
| FP16 (no quantization) | 13 GB |
| INT8 | 10 GB |
| INT4 | 6 GB |
## Getting Started
### Environment Setup
Install the requirements with pip: `pip install -r requirements.txt`. The recommended `transformers` version is `4.26.1`, but theoretically any version no lower than `4.23.1` is acceptable.
@@ -28,6 +31,7 @@ Generate dialogue with the following code
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
@@ -95,24 +99,24 @@ After 2 to 3 rounds of dialogue, the GPU memory usage is about 10GB under 8-bit
Model quantization brings some performance degradation. In our tests, ChatGLM-6B can still generate naturally and smoothly under 4-bit quantization. Quantization schemes such as [GPT-Q](https://arxiv.org/abs/2210.17323) can further compress the quantization precision, or improve model performance at the same precision. You are welcome to submit corresponding Pull Requests.
**[2023/03/19]** The quantization process first loads the FP16 model, which costs about 13GB of CPU memory. If your CPU memory is limited, you can directly load the quantized model instead, which costs only about 5.2GB of CPU memory:
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
```
### CPU Deployment
If your computer is not equipped with a GPU, you can also run inference on the CPU:
If your computer is not equipped with a GPU, you can also run inference on the CPU, but it will be slower (and takes about 32GB of memory):
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
```
The inference speed will be relatively slow on CPU.
The above method requires 32GB of memory. If you only have 16GB of memory, you can try:
**[2023/03/19]** If your CPU memory is limited, you can directly load the quantized model:
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).bfloat16()
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
```
It is necessary to ensure that there is nearly 16GB of free memory, and the inference speed will be very slow.
**For Mac users**: if you encounter the error `RuntimeError: Unknown platform: darwin`, please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
## ChatGLM-6B Examples

cli_demo.py

@@ -1,40 +1,24 @@
import os
import platform
import argparse
import time
from transformers import AutoTokenizer, AutoModel
parser = argparse.ArgumentParser(description='cli demo')
parser.add_argument('--cpu', action='store_true', help='cpu mode')
parser.add_argument('--showTime', action='store_true', help='show elapsed time')
parser.add_argument('--local', action='store_true', help='use a local model, default path: ./models/chatglm-6b')
args = parser.parse_args()
os_name = platform.system()
# mac: force use cpu
if os_name == 'Darwin':
args.cpu = True
model_name = "THUDM/chatglm-6b"
if args.local:
model_name = "./models/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
if args.cpu:
model = model.float()
else:
model = model.half().cuda()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
def build_prompt(history):
prompt = "欢迎使用 ChatGLM-6B 模型输入内容即可进行对话clear 清空对话历史stop 终止程序"
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM-6B{response}"
return prompt
def main():
history = []
print("欢迎使用 ChatGLM-6B 模型输入内容即可进行对话clear 清空对话历史stop 终止程序")
while True:
@@ -43,13 +27,18 @@ while True:
break
if query == "clear":
history = []
command = 'cls' if os_name == 'Windows' else 'clear'
os.system(command)
os.system(clear_command)
print("欢迎使用 ChatGLM-6B 模型输入内容即可进行对话clear 清空对话历史stop 终止程序")
continue
timeStart = time.perf_counter()
response, history = model.chat(tokenizer, query, history=history)
timeEnd = time.perf_counter()
showTime="({timeEnd - timeStart:0.4f}s)" if args.showTime else ""
count = 0
for response, history in model.stream_chat(tokenizer, query, history=history):
count += 1
if count % 8 == 0:
os.system(clear_command)
print(build_prompt(history), flush=True)
os.system(clear_command)
print(build_prompt(history), flush=True)
print(f"ChatGLM-6B {showTime}{response}")
if __name__ == "__main__":
main()

resources/web-demo.gif (new binary file, 1.6 MiB; not shown)

web_demo.py

@@ -12,15 +12,15 @@ MAX_BOXES = MAX_TURNS * 2
def predict(input, max_length, top_p, temperature, history=None):
if history is None:
history = []
response, history = model.chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature)
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature):
updates = []
for query, response in history:
updates.append(gr.update(visible=True, value="用户:" + query))
updates.append(gr.update(visible=True, value="ChatGLM-6B" + response))
if len(updates) < MAX_BOXES:
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
return [history] + updates
yield [history] + updates
with gr.Blocks() as demo:
@@ -42,4 +42,4 @@ with gr.Blocks() as demo:
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
button = gr.Button("Generate")
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
demo.queue().launch(share=True, inbrowser=True)
demo.queue().launch(share=False, inbrowser=True)