From bc6695b7f22320c51d991d9b6592580d95136315 Mon Sep 17 00:00:00 2001
From: DealiAxy
Date: Fri, 19 May 2023 17:33:20 +0800
Subject: [PATCH 1/4] ☀ feat: update .gitignore
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.gitignore b/.gitignore
index da568ec..c3fdd87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+.vscode
+ptuning/data
+ptuning/output
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

From 770676fdd51a4c3f25cfb0181019a42e69d6bb61 Mon Sep 17 00:00:00 2001
From: DealiAxy
Date: Fri, 19 May 2023 17:33:40 +0800
Subject: [PATCH 2/4] ☀ feat: rewrite the API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api.py | 78 +++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/api.py b/api.py
index 693c70a..4ecaa27 100644
--- a/api.py
+++ b/api.py
@@ -1,40 +1,49 @@
+import json
+import datetime
+import torch
+import uvicorn
+from typing import List
 from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
-
-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+from pydantic import BaseModel
+from utils import load_model_on_gpus
 
 
-def torch_gc():
+devices_list = [
+    'cuda:0',
+    'cuda:1'
+]
+
+
+def _torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        for item in devices_list:
+            with torch.cuda.device(item):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+
+
+class Question(BaseModel):
+    prompt: str
+    history: List[List[str]] = []  # each item is a [query, response] pair
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
 
 
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+@app.post('/chat/')
+async def chat(question: Question):
+    response, history = model.chat(
+        tokenizer,
+        question.prompt,
+        history=question.history,
+        max_length=question.max_length,
+        top_p=question.top_p,
+        temperature=question.temperature
+    )
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
@@ -43,14 +52,15 @@ async def create_item(request: Request):
         "status": 200,
         "time": time
     }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
+    _torch_gc()
     return answer
 
 
-if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b", trust_remote_code=True
+    )
+    model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="127.0.0.1", port=11001, workers=1)
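Note: a minimal client sketch for the reworked endpoint above (not part of
the patch; it assumes the server was started as configured, listening on
127.0.0.1:11001):

    import requests

    # Fields mirror the Question model; only "prompt" is required, the
    # rest fall back to the defaults declared on the model.
    payload = {
        "prompt": "你好",
        "history": [],  # [query, response] pairs from earlier turns
        "max_length": 2048,
        "top_p": 0.7,
        "temperature": 0.95,
    }
    resp = requests.post("http://127.0.0.1:11001/chat/", json=payload)
    data = resp.json()
    print(data["response"])    # the model's reply
    history = data["history"]  # send back on the next call for multi-turn chat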
From 5b5546789588364feadbd01655282a09e9014f7e Mon Sep 17 00:00:00 2001
From: DealiAxy
Date: Fri, 19 May 2023 17:34:58 +0800
Subject: [PATCH 3/4] ☀ feat: support multi-GPU deployment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 cli_demo_gpus.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++
 web_demo.py      | 28 ++++++++++++++---------
 2 files changed, 76 insertions(+), 11 deletions(-)
 create mode 100644 cli_demo_gpus.py

diff --git a/cli_demo_gpus.py b/cli_demo_gpus.py
new file mode 100644
index 0000000..56c8d83
--- /dev/null
+++ b/cli_demo_gpus.py
@@ -0,0 +1,59 @@
+import os
+import platform
+import signal
+from transformers import AutoTokenizer, AutoModel
+from utils import load_model_on_gpus
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+stop_stream = False
+
+
+def build_prompt(history):
+    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+    for query, response in history:
+        prompt += f"\n\n用户:{query}"
+        prompt += f"\n\nChatGLM-6B:{response}"
+    return prompt
+
+
+def signal_handler(signal, frame):
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    history = []
+    global stop_stream
+    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    while True:
+        query = input("\n用户:")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            os.system(clear_command)
+            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/web_demo.py b/web_demo.py
index 97ea622..8f221bf 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,9 +1,12 @@
-from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import mdtex2html
+from transformers import AutoModel, AutoTokenizer
+from utils import load_model_on_gpus
+
 
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
 model = model.eval()
 
 """Override Chatbot.postprocess"""
@@ -60,7 +63,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                                temperature=temperature):
-        chatbot[-1] = (parse_text(input), parse_text(response)) 
+        chatbot[-1] = (parse_text(input), parse_text(response))
 
         yield chatbot, history
 
@@ -74,21 +77,24 @@ def reset_state():
 
 
 with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+    gr.HTML("""<h1 align="center">CodeLab</h1>""")
 
     chatbot = gr.Chatbot()
     with gr.Row():
         with gr.Column(scale=4):
             with gr.Column(scale=12):
-                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                user_input = gr.Textbox(show_label=False, placeholder="输入聊天内容", lines=10).style(
                     container=False)
             with gr.Column(min_width=32, scale=1):
-                submitBtn = gr.Button("Submit", variant="primary")
+                submitBtn = gr.Button("发送", variant="primary")
         with gr.Column(scale=1):
-            emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            emptyBtn = gr.Button("清除历史记录")
+            max_length = gr.Slider(
+                0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01,
+                              label="Top P", interactive=True)
+            temperature = gr.Slider(
+                0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
 
     history = gr.State([])
 
@@ -98,4 +104,4 @@ with gr.Blocks() as demo:
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
 
-demo.queue().launch(share=False, inbrowser=True)
+demo.queue().launch(share=False, inbrowser=False, server_port=11001)
From 6a43132a393e44e07d2df546603b684cefe6f098 Mon Sep 17 00:00:00 2001
From: DealiAxy
Date: Fri, 19 May 2023 17:35:06 +0800
Subject: [PATCH 4/4] ☀ feat: training-related changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ptuning/evaluate.sh | 6 +++---
 ptuning/train.sh    | 8 +++++---
 ptuning/web_demo.py | 5 ++---
 ptuning/web_demo.sh | 5 +++--
 4 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/ptuning/evaluate.sh b/ptuning/evaluate.sh
index ab85536..f9debe5 100644
--- a/ptuning/evaluate.sh
+++ b/ptuning/evaluate.sh
@@ -2,10 +2,10 @@ PRE_SEQ_LEN=128
 CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
 STEP=3000
 
-CUDA_VISIBLE_DEVICES=0 python3 main.py \
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
     --do_predict \
-    --validation_file AdvertiseGen/dev.json \
-    --test_file AdvertiseGen/dev.json \
+    --validation_file data/AdvertiseGen/dev.json \
+    --test_file data/AdvertiseGen/dev.json \
     --overwrite_cache \
     --prompt_column content \
     --response_column summary \
diff --git a/ptuning/train.sh b/ptuning/train.sh
index efc9a16..5e85284 100644
--- a/ptuning/train.sh
+++ b/ptuning/train.sh
@@ -1,10 +1,12 @@
 PRE_SEQ_LEN=128
 LR=2e-2
 
-CUDA_VISIBLE_DEVICES=0 python3 main.py \
+export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:32'
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
     --do_train \
-    --train_file AdvertiseGen/train.json \
-    --validation_file AdvertiseGen/dev.json \
+    --train_file data/AdvertiseGen/train.json \
+    --validation_file data/AdvertiseGen/dev.json \
     --prompt_column content \
     --response_column summary \
     --overwrite_cache \
diff --git a/ptuning/web_demo.py b/ptuning/web_demo.py
index 43d0c82..40fbe99 100644
--- a/ptuning/web_demo.py
+++ b/ptuning/web_demo.py
@@ -119,8 +119,7 @@ with gr.Blocks() as demo:
 
 def main():
     global model, tokenizer
-    parser = HfArgumentParser((
-        ModelArguments))
+    parser = HfArgumentParser((ModelArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -158,7 +157,7 @@ def main():
     model.transformer.prefix_encoder.float().cuda()
     model = model.eval()
 
-    demo.queue().launch(share=False, inbrowser=True)
+    demo.queue().launch(share=False, inbrowser=True, server_port=11001)
diff --git a/ptuning/web_demo.sh b/ptuning/web_demo.sh
index 87bf9e9..ebe7219 100644
--- a/ptuning/web_demo.sh
+++ b/ptuning/web_demo.sh
@@ -1,7 +1,8 @@
 PRE_SEQ_LEN=128
 
-CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
+CUDA_VISIBLE_DEVICES=0,1 python3 web_demo.py \
     --model_name_or_path THUDM/chatglm-6b \
     --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
-    --pre_seq_len $PRE_SEQ_LEN
+    --pre_seq_len $PRE_SEQ_LEN \
+    --quantization_bit 4
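Note: the PYTORCH_CUDA_ALLOC_CONF export added to train.sh caps the caching
allocator's split block size, which can reduce fragmentation-related OOMs
during training. The same setting can be applied from a Python launcher
(a sketch, not part of the patch); it must be in the environment before CUDA
is first used:

    import os

    # PyTorch reads this variable when the CUDA caching allocator starts up,
    # so set it before importing torch or touching any CUDA API.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

    import torch

    assert torch.cuda.is_available(), "expected at least one visible GPU"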