ChatGLM-6B/glm-bot/fastapi.py
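
# NOTE: because this file is named fastapi.py, running it directly shadows the
# installed fastapi package (the script's own directory leads sys.path) and
# raises a circular-import ImportError; renaming the script (e.g. to a
# hypothetical glm_api.py) avoids the clash.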

from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
import json
from transformers import AutoModel, AutoTokenizer
from typing import List, Optional, Tuple

# Maximum character length of a session's pretty-printed JSON history before it is reset
max_length = 4096

# Return the conversation history for a session id, or None if there is none yet
def get_history(uid: str) -> Optional[List[Tuple[str, str]]]:
    if uid in sessions:
        # Reset the session once its serialized history grows past max_length characters
        if len(json.dumps(sessions[uid], indent=2)) > max_length:
            sessions[uid] = []
            return None
        if sessions[uid] == []:
            return None
        return sessions[uid]
    else:
        sessions[uid] = []
        return None

# Clear the conversation history for a session id
def clear(uid: str) -> str:
    sessions[uid] = []
    return 'Session reset'
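
# Load ChatGLM-6B with fp16 weights on the GPU (a CUDA device is required)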
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
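
# Constants carried over from the upstream web demo; unused below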
MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
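
# In-memory session store: maps session id -> list of (prompt, response) pairs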
sessions = {}

def predict(prompt: str, uid: str, max_length: int = 2048, top_p: float = 0.7, temperature: float = 0.95) -> str:
    # Note: this max_length is the generation window in tokens and shadows the
    # module-level max_length that caps the serialized history size
    history = get_history(uid)
    print(history)
    response, history = model.chat(tokenizer, prompt, history=history, max_length=max_length,
                                   top_p=top_p, temperature=temperature)
    # get_history guarantees sessions[uid] exists, so this append is safe
    sessions[uid].append((prompt, response))
    print(get_history(uid))
    return response

# Simple REPL for local testing without the HTTP server:
# while True:
#     uid = input("uid:")
#     prompt = input('msg:')
#     msg = predict(prompt=prompt, uid=uid)
#     print(msg)

app = FastAPI()

class Item_chat(BaseModel):
    msg: str
    uid: str

@app.post("/chat")
def chat(item: Item_chat):
    msg = predict(prompt=item.msg, uid=item.uid)
    print(msg)
    return msg

class Item_clear(BaseModel):
    uid: str

@app.post("/clear")
def clear_session(item: Item_clear):
    return clear(item.uid)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=10269)
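
# Example requests (assuming the server is reachable at localhost:10269):
#   curl -X POST http://localhost:10269/chat  -H "Content-Type: application/json" -d '{"msg": "Hello", "uid": "user1"}'
#   curl -X POST http://localhost:10269/clear -H "Content-Type: application/json" -d '{"uid": "user1"}'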