# mirror of https://github.com/THUDM/ChatGLM-6B
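# A small FastAPI service wrapping ChatGLM-6B, keeping per-uid chat history
# in an in-memory dict.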
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
import json
from transformers import AutoModel, AutoTokenizer
from typing import List, Optional, Tuple

# A session is reset once its JSON-serialized history exceeds this many characters.
max_length = 4096


# Fetch the stored conversation history for a given uid.
# Returns None when the session is new, empty, or was just reset.
def get_history(uid: str) -> Optional[List[Tuple[str, str]]]:
    if uid in sessions:
        # Reset the session once its serialized form grows past max_length.
        length = len(json.dumps(sessions[uid], indent=2))
        if length > max_length:
            sessions[uid] = []
            return None
        if not sessions[uid]:
            return None
        return sessions[uid]
    else:
        sessions[uid] = []
        return None


# Clear the conversation history for a given uid.
def clear(uid: str) -> str:
    sessions[uid] = []
    return 'Session has been reset'


# Load the ChatGLM-6B tokenizer and model (fp16 on GPU), as in the upstream repo.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# Unused in this script; apparently carried over from the upstream web demo.
MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

# In-memory session store: uid -> list of (prompt, response) pairs.
sessions = {}
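# e.g. sessions == {"user-1": [("Hi", "Hello! How can I help?")]}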


def predict(prompt: str, uid: str, max_length: int = 2048, top_p: float = 0.7,
            temperature: float = 0.95) -> str:
    # Note: this max_length (a generation limit) shadows the module-level
    # history-size cap of the same name.
    history = get_history(uid)
    print(history)
    response, history = model.chat(tokenizer, prompt, history=history,
                                   max_length=max_length, top_p=top_p,
                                   temperature=temperature)
    # get_history guarantees sessions[uid] exists by this point.
    sessions[uid].append((prompt, response))
    print(get_history(uid))
    return response
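
# Example direct call (requires the model loaded on a CUDA GPU):
#   reply = predict(prompt="Hello", uid="user-1")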

# Simple interactive test loop (left disabled):
# while 1:
#     uid = input("uid:")
#     prompt = input('msg:')
#     msg = predict(prompt=prompt, uid=uid)
#     print(msg)


app = FastAPI()


class Item_chat(BaseModel):
    msg: str
    uid: str


@app.post("/chat")
def chat(item: Item_chat):
    msg = predict(prompt=item.msg, uid=item.uid)
    print(msg)
    return msg
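
# Example request (assumes the server is running locally):
#   curl -X POST http://localhost:10269/chat \
#        -H "Content-Type: application/json" \
#        -d '{"msg": "Hello", "uid": "user-1"}'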


class Item_clear(BaseModel):
    uid: str


@app.post("/clear")
def clear_session(item: Item_clear):
    return clear(item.uid)
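
# Example request:
#   curl -X POST http://localhost:10269/clear \
#        -H "Content-Type: application/json" \
#        -d '{"uid": "user-1"}'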


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=10269)
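# To launch (assuming this file is saved as api.py):
#   python api.py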