pull/1066/merge
DealiAxy 2024-07-14 17:54:50 +08:00 committed by GitHub
commit 49a782700b
8 changed files with 137 additions and 56 deletions

.gitignore (4 changes)

@@ -1,3 +1,7 @@
+.vscode
+ptuning/data
+ptuning/output
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

api.py (78 changes)

@@ -1,40 +1,49 @@
-from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
+import json
+import datetime
+import torch
+import uvicorn
+from typing import List
+from fastapi import FastAPI, Request
+from transformers import AutoTokenizer, AutoModel
+from pydantic import BaseModel
+from utils import load_model_on_gpus

-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+devices_list = [
+    'cuda:0',
+    'cuda:1'
+]

-def torch_gc():
+def _torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        for item in devices_list:
+            with torch.cuda.device(item):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()

+class Question(BaseModel):
+    prompt: str
+    history: List[str] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95

 app = FastAPI()

-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+@app.post('/chat/')
+async def chat(question: Question):
+    response, history = model.chat(
+        tokenizer,
+        question.prompt,
+        history=question.history,
+        max_length=question.max_length,
+        top_p=question.top_p,
+        temperature=question.temperature
+    )
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
@@ -43,14 +52,15 @@ async def create_item(request: Request):
         "status": 200,
         "time": time
     }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
+    _torch_gc()
     return answer

-if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b", trust_remote_code=True
+    )
+    model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
     model.eval()
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="127.0.0.1", port=11001, workers=1)
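
The new /chat/ endpoint takes a JSON body matching the Question model above. A minimal client sketch, assuming the requests package is installed and the server is running on 127.0.0.1:11001 as configured in api.py; only the "status" and "time" fields of the response are visible in this diff, so the rest of the payload may differ:

# Hypothetical client for the /chat/ endpoint defined in api.py.
import requests

payload = {
    "prompt": "Hello, ChatGLM",
    "history": [],        # Question.history defaults to an empty list
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}
resp = requests.post("http://127.0.0.1:11001/chat/", json=payload)
print(resp.status_code)
print(resp.json())        # contains at least "status" and "time"

Note that Question.history is declared as List[str], while ChatGLM-6B's model.chat expects a list of (query, response) pairs; an empty history works either way, but multi-turn clients may need to adapt the schema.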

cli_demo_gpus.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import os
+import platform
+import signal
+from transformers import AutoTokenizer, AutoModel
+from utils import load_model_on_gpus
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+stop_stream = False
+
+
+def build_prompt(history):
+    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+    for query, response in history:
+        prompt += f"\n\n用户:{query}"
+        prompt += f"\n\nChatGLM-6B:{response}"
+    return prompt
+
+
+def signal_handler(signal, frame):
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    history = []
+    global stop_stream
+    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    while True:
+        query = input("\n用户:")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            os.system(clear_command)
+            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
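
Both api.py and cli_demo_gpus.py import load_model_on_gpus from utils, which is not part of this diff. A rough sketch of what such a helper might look like, assuming it shards the 28 ChatGLM-6B transformer layers across GPUs with accelerate's dispatch_model; the actual utils.py in the repository may be implemented differently:

# Hypothetical multi-GPU loader; the repository's real utils.load_model_on_gpus
# may differ in how it builds the device map.
from accelerate import dispatch_model
from transformers import AutoModel


def load_model_on_gpus(checkpoint_path: str, num_gpus: int = 2):
    model = AutoModel.from_pretrained(
        checkpoint_path, trust_remote_code=True).half()
    if num_gpus < 2:
        return model.cuda()
    # Keep the embeddings, final layernorm and lm_head on GPU 0 and
    # distribute the 28 transformer layers round-robin across all GPUs.
    device_map = {
        "transformer.word_embeddings": 0,
        "transformer.final_layernorm": 0,
        "lm_head": 0,
    }
    for i in range(28):
        device_map[f"transformer.layers.{i}"] = i % num_gpus
    return dispatch_model(model, device_map=device_map)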

ptuning/evaluate.sh

@@ -2,10 +2,10 @@ PRE_SEQ_LEN=128
 CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
 STEP=3000

-CUDA_VISIBLE_DEVICES=0 python3 main.py \
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
     --do_predict \
-    --validation_file AdvertiseGen/dev.json \
-    --test_file AdvertiseGen/dev.json \
+    --validation_file data/AdvertiseGen/dev.json \
+    --test_file data/AdvertiseGen/dev.json \
     --overwrite_cache \
     --prompt_column content \
     --response_column summary \

ptuning/train.sh

@@ -1,10 +1,12 @@
 PRE_SEQ_LEN=128
 LR=2e-2

-CUDA_VISIBLE_DEVICES=0 python3 main.py \
+export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32'
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
     --do_train \
-    --train_file AdvertiseGen/train.json \
-    --validation_file AdvertiseGen/dev.json \
+    --train_file data/AdvertiseGen/train.json \
+    --validation_file data/AdvertiseGen/dev.json \
     --prompt_column content \
     --response_column summary \
     --overwrite_cache \
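
Quoting the whole assignment in export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32' is valid bash; the option tells PyTorch's CUDA caching allocator not to split blocks larger than 32 MB, which can reduce fragmentation-related out-of-memory errors during multi-GPU P-tuning. The variable has to be visible before CUDA is first used; a quick check from Python, assuming the training process inherits the shell environment:

# Verify the allocator option reaches the training process; set it here as a
# fallback before torch initializes CUDA.
import os

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:32")

import torch  # imported after setting the env var so the allocator sees it

print("PYTORCH_CUDA_ALLOC_CONF =", os.environ["PYTORCH_CUDA_ALLOC_CONF"])
print("visible GPUs:", torch.cuda.device_count())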

ptuning/web_demo.py

@@ -119,8 +119,7 @@ with gr.Blocks() as demo:

 def main():
     global model, tokenizer
-    parser = HfArgumentParser((
-        ModelArguments))
+    parser = HfArgumentParser((ModelArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -158,7 +157,7 @@ def main():
     model.transformer.prefix_encoder.float().cuda()
     model = model.eval()
-    demo.queue().launch(share=False, inbrowser=True)
+    demo.queue().launch(share=False, inbrowser=True, server_port=11001)

ptuning/web_demo.sh

@@ -1,7 +1,8 @@
 PRE_SEQ_LEN=128

-CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
+CUDA_VISIBLE_DEVICES=0,1 python3 web_demo.py \
     --model_name_or_path THUDM/chatglm-6b \
     --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
-    --pre_seq_len $PRE_SEQ_LEN
+    --pre_seq_len $PRE_SEQ_LEN \
+    --quantization_bit 4

web_demo.py

@@ -1,9 +1,12 @@
-from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import mdtex2html
+from transformers import AutoModel, AutoTokenizer
+from utils import load_model_on_gpus

 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
 model = model.eval()

 """Override Chatbot.postprocess"""
@@ -60,7 +63,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                                temperature=temperature):
         chatbot[-1] = (parse_text(input), parse_text(response))

         yield chatbot, history
@@ -74,21 +77,24 @@ def reset_state():

 with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+    gr.HTML("""<h1 align="center">CodeLab</h1>""")

     chatbot = gr.Chatbot()
     with gr.Row():
         with gr.Column(scale=4):
             with gr.Column(scale=12):
-                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                user_input = gr.Textbox(show_label=False, placeholder="输入聊天内容", lines=10).style(
                     container=False)
             with gr.Column(min_width=32, scale=1):
-                submitBtn = gr.Button("Submit", variant="primary")
+                submitBtn = gr.Button("发送", variant="primary")
         with gr.Column(scale=1):
-            emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            emptyBtn = gr.Button("清除历史记录")
+            max_length = gr.Slider(
+                0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01,
+                              label="Top P", interactive=True)
+            temperature = gr.Slider(
+                0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

     history = gr.State([])
@@ -98,4 +104,4 @@ with gr.Blocks() as demo:
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

-demo.queue().launch(share=False, inbrowser=True)
+demo.queue().launch(share=False, inbrowser=False, server_port=11001)