From 2e8c7092512a0f3c2407681f1a90298f828cf344 Mon Sep 17 00:00:00 2001
From: linteng <10499295+wanglinteng@users.noreply.github.com>
Date: Fri, 5 May 2023 15:08:40 +0800
Subject: [PATCH 1/2] Create api.py

---
 ptuning/api.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 ptuning/api.py

diff --git a/ptuning/api.py b/ptuning/api.py
new file mode 100644
index 0000000..249aa67
--- /dev/null
+++ b/ptuning/api.py
@@ -0,0 +1,101 @@
+import os
+import sys
+import torch
+import uvicorn, json, datetime
+
+from fastapi import FastAPI, Request
+from transformers import AutoTokenizer, AutoModel
+from transformers import (
+    HfArgumentParser,
+    AutoConfig
+)
+from arguments import ModelArguments
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    # Release cached GPU memory after each request so that a long-running
+    # server does not accumulate allocations over time.
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+app = FastAPI()
+
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post = await request.json()
+    prompt = json_post.get('prompt')
+    history = json_post.get('history')
+    max_length = json_post.get('max_length')
+    top_p = json_post.get('top_p')
+    temperature = json_post.get('temperature')
+    response, history = model.chat(tokenizer,
+                                   prompt,
+                                   history=history,
+                                   max_length=max_length if max_length else 2048,
+                                   top_p=top_p if top_p else 0.7,
+                                   temperature=temperature if temperature else 0.95)
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+    answer = {
+        "response": response,
+        "history": history,
+        "status": 200,
+        "time": time
+    }
+    log = f'[{time}] prompt: "{prompt}", response: "{repr(response)}"'
+    print(log)
+    torch_gc()
+    return answer
+
+
+if __name__ == '__main__':
+    parser = HfArgumentParser(ModelArguments)
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
+    else:
+        model_args = parser.parse_args_into_dataclasses()[0]
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=True)
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=True)
+
+    config.pre_seq_len = model_args.pre_seq_len
+    config.prefix_projection = model_args.prefix_projection
+
+    if model_args.ptuning_checkpoint is not None:
+        # Load the base model, then overwrite the prefix encoder weights
+        # with those from the P-tuning checkpoint.
+        print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
+        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
+        new_prefix_state_dict = {}
+        for k, v in prefix_state_dict.items():
+            if k.startswith("transformer.prefix_encoder."):
+                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+    else:
+        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
+
+    if model_args.quantization_bit is not None:
+        print(f"Quantized to {model_args.quantization_bit} bit")
+        model = model.quantize(model_args.quantization_bit)
+
+    if model_args.pre_seq_len is not None:
+        # P-tuning v2: run the base model in fp16, keep the prefix encoder in fp32
+        model = model.half().cuda()
+        model.transformer.prefix_encoder.float().cuda()
+
+    model.eval()
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

From 091f0a4c4ddbb6f2325b7c84d4c30c6be9924e45 Mon Sep 17 00:00:00 2001
From: linteng <10499295+wanglinteng@users.noreply.github.com>
Date: Fri, 5 May 2023 15:10:18 +0800
Subject: [PATCH 2/2] Create api.sh

---
 ptuning/api.sh | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 ptuning/api.sh

diff --git a/ptuning/api.sh b/ptuning/api.sh
new file mode 100644
index 0000000..2fc1663
--- /dev/null
+++ b/ptuning/api.sh
@@ -0,0 +1,6 @@
+PRE_SEQ_LEN=128
+
+CUDA_VISIBLE_DEVICES=0 python3 api.py \
+    --model_name_or_path THUDM/chatglm-6b \
+    --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
+    --pre_seq_len $PRE_SEQ_LEN
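
To exercise the endpoint that PATCH 1/2 adds, here is a minimal client
sketch, assuming the server is reachable at localhost:8000 (the port
configured in uvicorn.run above). The field names and fallback values
mirror the create_item handler; the sampling parameters may be omitted.

    import requests

    # Send one chat turn to the api.py server (assumed at localhost:8000).
    payload = {
        "prompt": "你好",   # the user query
        "history": [],      # (query, response) pairs from earlier turns
        # "max_length", "top_p" and "temperature" may be omitted; the
        # handler then falls back to 2048 / 0.7 / 0.95.
    }
    resp = requests.post("http://localhost:8000/", json=payload)
    data = resp.json()
    print(data["response"])       # the model's reply
    history = data["history"]     # pass back on the next request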