Merge branch 'main' into main

pull/216/head
Ji 2023-03-28 18:11:29 -07:00 committed by GitHub
commit 779869927a
4 changed files with 62 additions and 10 deletions


@@ -17,6 +17,17 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese Q&A and dialogue
 **[2023/03/19]** Added the streaming output interface `stream_chat`, now used by both the web and CLI demos. Fixed Chinese punctuation in the output. Added the quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)
+## Friendly Links
+The following are some open-source projects built on this repository:
+* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An MNN-based C++ inference implementation of ChatGLM-6B that automatically splits computation between the GPU and CPU according to available GPU memory
+* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): Fine-tuning ChatGLM-6B with LoRA
+The following are tutorials/documents for this project:
+* [Windows deployment guide](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
+If you have other good projects or tutorials, you are welcome to add them to the README in the above format and open a [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
 ## Usage
 ### Hardware Requirements


@@ -13,6 +13,13 @@ Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_St
 **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
+## Projects
+The following are some open-source projects developed based on this repository:
+* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An [MNN](https://github.com/alibaba/MNN)-based C++ inference implementation of ChatGLM-6B, which automatically allocates computation between the GPU and CPU according to the available GPU memory
+* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): Fine-tuning ChatGLM-6B with LoRA
+If you have other good projects, please add them to the README in the above format and open a [PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork).
 ## Getting Started
 ### Hardware Requirements

api.py

@@ -1,6 +1,19 @@
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
 app = FastAPI()
@@ -13,7 +26,15 @@ async def create_item(request: Request):
     json_post_list = json.loads(json_post)
     prompt = json_post_list.get('prompt')
     history = json_post_list.get('history')
-    response, history = model.chat(tokenizer, prompt, history=history)
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    response, history = model.chat(tokenizer,
+                                   prompt,
+                                   history=history,
+                                   max_length=max_length if max_length else 2048,
+                                   top_p=top_p if top_p else 0.7,
+                                   temperature=temperature if temperature else 0.95)
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
@@ -24,12 +45,12 @@ async def create_item(request: Request):
     }
     log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
     print(log)
+    torch_gc()
     return answer


 if __name__ == '__main__':
-    uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1)
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
     model.eval()
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
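For reference, a minimal client sketch for the updated API. The route ("/") and the shape of the returned answer are assumptions, since neither the route decorator nor the response fields appear in this diff; only the request fields (prompt, history, max_length, top_p, temperature), the port, and the defaults (2048, 0.7, 0.95) are taken from the hunks above.

```python
import requests

# Hypothetical client for the updated api.py.
# Assumption: the endpoint is registered at "/" and the server listens on port 8000
# (the port comes from the uvicorn.run call above; the route is not shown in the diff).
payload = {
    "prompt": "你好",
    "history": [],
    # Optional sampling parameters added by this change; if omitted,
    # the server falls back to max_length=2048, top_p=0.7, temperature=0.95.
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}

resp = requests.post("http://127.0.0.1:8000/", json=payload)
resp.raise_for_status()
print(resp.json())
```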


@@ -1,5 +1,6 @@
 import os
 import platform
+import signal
 from transformers import AutoTokenizer, AutoModel

 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
@@ -8,6 +9,7 @@ model = model.eval()

 os_name = platform.system()
 clear_command = 'cls' if os_name == 'Windows' else 'clear'
+stop_stream = False


 def build_prompt(history):
@@ -18,8 +20,14 @@ def build_prompt(history):
     return prompt


+def signal_handler(signal, frame):
+    global stop_stream
+    stop_stream = True
+
+
 def main():
     history = []
+    global stop_stream
     print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
     while True:
         query = input("\n用户:")
@@ -32,10 +40,15 @@ def main():
             continue
         count = 0
         for response, history in model.stream_chat(tokenizer, query, history=history):
-            count += 1
-            if count % 8 == 0:
-                os.system(clear_command)
-                print(build_prompt(history), flush=True)
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
         os.system(clear_command)
         print(build_prompt(history), flush=True)
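The change above uses a flag-based interrupt pattern: once signal.signal has registered the handler, Ctrl-C sets stop_stream instead of raising KeyboardInterrupt mid-stream, and the streaming loop checks the flag on each chunk and breaks out cleanly. A standalone sketch of the same pattern, with a purely illustrative fake_stream generator standing in for model.stream_chat:

```python
import signal
import time

stop_stream = False


def signal_handler(sig, frame):
    # Ctrl-C only flips the flag; the loop below decides when to stop.
    global stop_stream
    stop_stream = True


def fake_stream():
    # Illustrative stand-in for model.stream_chat: yields partial outputs forever.
    n = 0
    while True:
        n += 1
        time.sleep(0.5)
        yield f"partial response {n}"


signal.signal(signal.SIGINT, signal_handler)
for chunk in fake_stream():
    if stop_stream:
        stop_stream = False
        break
    print(chunk, flush=True)
```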