mirror of https://github.com/THUDM/ChatGLM2-6B
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines, 2.1 KiB
import datetime
import json

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from transformers import AutoModel, AutoTokenizer
|
|
|
# GPU targeted by torch_gc() when it frees cached memory.
# DEVICE_ID picks the CUDA ordinal; an empty DEVICE_ID yields plain "cuda".
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = DEVICE if not DEVICE_ID else f"{DEVICE}:{DEVICE_ID}"
|
|
|
|
|
def torch_gc():
    """Free cached CUDA memory on ``CUDA_DEVICE``.

    Does nothing when CUDA is not available, so it is safe to call after
    every request regardless of hardware.
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(CUDA_DEVICE):
        # Release unoccupied blocks held by PyTorch's caching allocator.
        torch.cuda.empty_cache()
        # Collect GPU memory that was released through CUDA IPC.
        torch.cuda.ipc_collect()
|
|
|
|
|
app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    """Generate one chat reply for a JSON request.

    JSON body keys read by this handler:
        prompt (str, required): the user message.
        history: prior conversation, passed unchanged to ``model.chat``
            (``None`` when omitted).
        max_length, top_p, temperature: generation settings. Missing or
            falsy values fall back to 2048, 0.7 and 0.95 respectively.

    Returns:
        dict with ``response``, the updated ``history``, ``status`` (always
        200) and ``time`` (server local time, ``%Y-%m-%d %H:%M:%S``).

    Raises:
        HTTPException: 400 if the body is not a JSON object or ``prompt`` is
            missing / not a string (previously such requests reached
            ``model.chat`` and failed there with a 500).

    Uses the module-level ``model`` and ``tokenizer`` created in the
    ``__main__`` block before the server starts.
    """
    payload = await request.json()
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")
    prompt = payload.get("prompt")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="'prompt' must be a string")

    # `value or default` preserves the original `value if value else default`
    # semantics: both missing and falsy values (e.g. 0) get the default.
    # NOTE(review): model.chat is a synchronous call inside an async handler,
    # so it blocks the event loop for the whole generation. Left as-is because
    # moving it to a thread pool would allow concurrent generations on the
    # shared model, which needs separate verification.
    response, history = model.chat(
        tokenizer,
        prompt,
        history=payload.get("history"),
        max_length=payload.get("max_length") or 2048,
        top_p=payload.get("top_p") or 0.7,
        temperature=payload.get("temperature") or 0.95,
    )

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": timestamp,
    }
    # repr() keeps newlines/quotes in prompt and response on one log line.
    print(f"[{timestamp}] prompt: {prompt!r}, response: {response!r}")
    torch_gc()
    return answer
|
|
|
|
|
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    # Place the model on CUDA_DEVICE instead of bare `.cuda()` so that editing
    # DEVICE_ID moves the model too, and torch_gc() clears the same GPU the
    # model runs on. With the default DEVICE_ID this is still cuda:0.
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).to(CUDA_DEVICE)
    # Multi-GPU support: replace the two lines above with the three lines
    # below, and set num_gpus to the number of GPUs you actually have.
    # NOTE: load_model_on_gpus is not imported in this file; add an import
    # for it before uncommenting these lines.
    # model_path = "THUDM/chatglm2-6b"
    # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # model = load_model_on_gpus(model_path, num_gpus=2)
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
|
|