mirror of https://github.com/THUDM/ChatGLM-6B
Add Stream API deployment
parent ab6bcb4968
commit 02947052ee
@@ -0,0 +1,95 @@
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel
import uvicorn
import torch

'''
This script streams the model's responses, so users do not have to wait for
the complete content before seeing any output.

When the endpoint is accessed it returns an 'event-stream' response, which the
client must receive and process incrementally (a hypothetical client sketch is
included as comments below).

POST http://127.0.0.1:8010
{ "input": "你好ChatGLM" }

input: the input text
max_length: maximum generation length
top_p: nucleus sampling threshold
temperature: sampling randomness
history: two-dimensional array of past messages,
    e.g. [["你好ChatGLM","你好,我是ChatGLM,一个基于语言模型的人工智能助手。很高兴见到你,欢迎问我任何问题。"]]
html_entities: enable HTML character-entity conversion (see parse_text)
'''
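
# A minimal client sketch (hypothetical, not part of the server): it assumes
# the server below is running locally and uses the 'requests' library to read
# the SSE frames, which arrive as lines of the form "data: <text>". Note that
# each frame carries the full response generated so far, not a delta.
#
#   import requests
#
#   with requests.post("http://127.0.0.1:8010",
#                      json={"input": "你好ChatGLM", "html_entities": False},
#                      stream=True) as resp:
#       for raw in resp.iter_lines(decode_unicode=True):
#           if raw and raw.startswith("data:"):
#               print(raw[len("data:"):].strip())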

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    # Free cached CUDA memory once a generation has finished.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


def parse_text(text):
    """Render model output as HTML, escaping special characters as HTML
    entities (adapted from the ChatGLM-6B web demo)."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: start an HTML code block, keeping the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape HTML/markdown special characters.
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
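
# Worked example, for illustration:
#   parse_text("```python\nprint(1)\n```")
#   -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'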

async def predict(input, max_length, top_p, temperature, history, html_entities):
    global model, tokenizer
    # Stream partial responses to the client as they are generated.
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length,
                                               top_p=top_p, temperature=temperature):
        yield parse_text(response) if html_entities else response
    torch_gc()


class ConversationsParams(BaseModel):
    input: str
    max_length: Optional[int] = 2048
    top_p: Optional[float] = 0.7
    temperature: Optional[float] = 0.95
    history: Optional[list] = []
    html_entities: Optional[bool] = True
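
# e.g. ConversationsParams(input="你好ChatGLM") applies the defaults above:
#   max_length=2048, top_p=0.7, temperature=0.95, history=[], html_entities=True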

@app.post('/')
async def conversations(params: ConversationsParams):
    # stream_chat expects history as a list of (query, response) tuples.
    history = list(map(tuple, params.history))
    predictGenerator = predict(params.input, params.max_length, params.top_p,
                               params.temperature, history, params.html_entities)
    return EventSourceResponse(predictGenerator)

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8010, workers=1)