You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

118 lines
6.4 KiB

import argparse
import copy
import json
import os
import random
import secrets
import string
import tempfile
from http.server import BaseHTTPRequestHandler, HTTPServer

from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

from RAG_ChatBot import RAG_ChatBot, DEFAULT_RAG_CFG
# Map each embed_model_name (sent by the front end) to the model path on the back-end server.
# EMB_MODEL_PATH overrides the default embedding model taken from DEFAULT_RAG_CFG.
EMBED_MODEL_DICT = {
    "m3e": os.environ.get("EMB_MODEL_PATH", DEFAULT_RAG_CFG["embed_model_name_or_path"]),
}
# Map each llm_name (sent by the front end) to the checkpoint path on the back-end server.
# "Pangu_API" and "OpenAI_API" are sentinel values: the request handler compares llm_path
# against them to pick a hosted API client instead of loading a local checkpoint.
# NOTE(review): the "chatglm2" key defaults to "THUDM/chatglm-6b" (not chatglm2-6b) -- confirm intended.
LLM_DICT = {
    "chatglm2": os.environ.get("CHAT_LLM_PATH", "THUDM/chatglm-6b"),
    "pangu": "Pangu_API",
    "chatgpt": "OpenAI_API",
}
def randomword(length):
    """Return a random string of ``length`` lowercase ASCII letters.

    The result is used as a per-conversation directory name under a shared
    temp directory, so it is drawn with ``secrets`` (a CSPRNG) rather than
    ``random``, whose Mersenne Twister output is predictable.

    Args:
        length: Number of characters to generate; 0 or a negative value
            yields an empty string.

    Returns:
        The generated string.
    """
    letters = string.ascii_lowercase
    return "".join(secrets.choice(letters) for _ in range(length))
class ColossalQAServerRequestHandler(BaseHTTPRequestHandler):
    """HTTP handler for the ColossalQA web UI back end.

    Every POST body is a JSON object. When ``conversation_ready`` is ``false``
    the request initializes the LLM and RAG chain from the uploaded ``docs``;
    otherwise ``user_input`` is answered by the shared chatbot. Every reply is a
    JSON object with ``response``, ``error`` and ``conversation_ready`` keys.
    """

    # Stored on the class, not the instance: http.server creates a new handler
    # object for every request, so per-instance state would not survive between
    # the initialization request and later chat requests.
    chatbot = None

    def _set_response(self, status=200):
        """Send the status line and a JSON Content-type header.

        Args:
            status: HTTP status code to send (defaults to 200).
        """
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.end_headers()

    def _send_json(self, res, status=200):
        """Send headers, then write ``res`` serialized as UTF-8 JSON as the body."""
        self._set_response(status)
        self.wfile.write(json.dumps(res).encode("utf-8"))

    def do_POST(self):
        """Handle one front-end request: either RAG initialization or a chat turn."""
        try:
            content_length = int(self.headers.get("Content-Length", 0))
            post_data = self.rfile.read(content_length)
            received_json = json.loads(post_data.decode("utf-8"))
            conversation_ready = received_json["conversation_ready"]
        except (TypeError, ValueError, KeyError) as e:
            # Malformed request: reply with an error instead of dropping the connection.
            self._send_json(
                {"response": "", "error": f"Malformed request: {e}", "conversation_ready": False},
                status=400,
            )
            return

        # conversation_ready is False (user's first request): upload files and initialize the RAG chain.
        if conversation_ready is False:
            res = self._init_conversation(received_json)
        # conversation_ready is True: chatbot and docs are all set, ready to chat.
        else:
            res = self._chat(received_json)
        self._send_json(res)

    def _init_conversation(self, received_json):
        """Build the LLM and RAG chain for a new conversation and load the uploaded docs.

        Failures are reported in the returned reply dict rather than raised, so
        the front end always receives a JSON answer.

        Args:
            received_json: Decoded request with ``embed_model_name``, ``llm_name``
                and ``docs`` keys.

        Returns:
            The JSON-serializable reply dict.
        """
        self.rag_config = DEFAULT_RAG_CFG.copy()
        try:
            embed_model_name = received_json["embed_model_name"]
            llm_name = received_json["llm_name"]
            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if embed_model_name not in EMBED_MODEL_DICT:
                raise ValueError(f"Unknown embed_model_name: {embed_model_name!r}")
            if llm_name not in LLM_DICT:
                raise ValueError(f"Unknown llm_name: {llm_name!r}")
            self.docs_files = received_json["docs"]

            # Find the embed_model/llm ckpt path on the back end server.
            embed_model_path, llm_path = EMBED_MODEL_DICT[embed_model_name], LLM_DICT[llm_name]
            if embed_model_path is None or llm_path is None:
                # Previously this case reported success without building a chatbot.
                raise ValueError(
                    f"No model path configured for embed model {embed_model_name!r} or LLM {llm_name!r}"
                )
            self.rag_config["embed_model_name_or_path"] = embed_model_path

            # Create a fresh storage directory for this conversation's knowledge base.
            # TMP is not set on most POSIX systems, so fall back to the platform temp dir.
            tmp_root = os.environ.get("TMP", tempfile.gettempdir())
            kb_path = os.path.join(tmp_root, "colossalqa_kb", randomword(20))
            os.makedirs(kb_path, exist_ok=True)
            self.rag_config["retri_kb_file_path"] = kb_path

            # ---- Initialize LLM, QA_chatbot here ----
            print("Initializing LLM...")
            self.llm = self._build_llm(llm_name, llm_path)
            print("Initializing RAG Chain...")
            print("RAG_CONFIG: ", self.rag_config)
            chatbot = RAG_ChatBot(self.llm, self.rag_config)
            print("Loading Files....\n", self.docs_files)
            # NOTE(review): this call was missing from the scraped source after the
            # "Loading Files" log; confirm the RAG_ChatBot method name.
            chatbot.load_doc_from_files(self.docs_files)
            # Publish only after loading succeeds, so a failed upload does not leave a
            # half-initialized chatbot shared with later requests.
            self.__class__.chatbot = chatbot
        except Exception as e:
            return {
                "response": "文件上传或模型初始化有误,无法开始对话。",
                "error": f"Error in File Uploading and/or RAG initialization. Error details: {e}",
                "conversation_ready": False,
            }
        return {
            "response": f"文件上传完成,模型初始化完成,让我们开始对话吧!(后端模型:{llm_name})",
            "error": "",
            "conversation_ready": True,
        }

    def _build_llm(self, llm_name, llm_path):
        """Create the LLM selected by ``llm_path`` and set its generation kwargs in ``self.rag_config``.

        Args:
            llm_name: Front-end model name (passed to ColossalAPI for local models).
            llm_path: Value from LLM_DICT; "Pangu_API" or "OpenAI_API" select a
                hosted API, anything else is treated as a local checkpoint path.

        Returns:
            The LLM object to hand to RAG_ChatBot.
        """
        # Hosted APIs take no local generation kwargs.
        no_local_kwargs = {"mem_llm_kwargs": None, "disambig_llm_kwargs": None, "gen_llm_kwargs": None}
        if llm_path == "Pangu_API":
            from colossalqa.local.pangu_llm import Pangu

            llm = Pangu(id=1)
            llm.set_auth_config()  # verify user's auth info here
            self.rag_config.update(no_local_kwargs)
        elif llm_path == "OpenAI_API":
            from langchain.llms import OpenAI

            llm = OpenAI()
            self.rag_config.update(no_local_kwargs)
        else:
            # ** (For Testing Only) **
            # In practice, all LLMs will run on the cloud platform and accessed by API, instead of running locally.
            # initialize model from model_path by using ColossalLLM
            self.rag_config.update(
                {
                    "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
                    "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
                    "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
                }
            )
            self.colossal_api = ColossalAPI(llm_name, llm_path)
            llm = ColossalLLM(n=1, api=self.colossal_api)
        return llm

    def _chat(self, received_json):
        """Answer one user turn with the shared chatbot.

        Args:
            received_json: Decoded request with a ``user_input`` key.

        Returns:
            The JSON-serializable reply dict; errors are reported in it, not raised.
        """
        chatbot = self.__class__.chatbot
        if chatbot is None:
            # Chat requested before a successful initialization request.
            return {
                "response": "对话尚未初始化,请先上传文件并完成模型初始化。",
                "error": "Chatbot is not initialized; send an initialization request "
                "(conversation_ready=false) first.",
                "conversation_ready": False,
            }
        try:
            user_input = received_json["user_input"]
            # run() returns the answer plus the updated conversation memory.
            chatbot_response, chatbot.memory = chatbot.run(user_input, chatbot.memory)
        except Exception as e:
            return {
                "response": "生成回答时出错,请稍后重试。",
                "error": f"Error in generating response. Error details: {e}",
                "conversation_ready": True,
            }
        return {"response": chatbot_response, "error": "", "conversation_ready": True}
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system")
    parser.add_argument("--port", type=int, default=13666, help="port on localhost to start the server")
    args = parser.parse_args()
    server_address = ("localhost", args.port)
    httpd = HTTPServer(server_address, ColossalQAServerRequestHandler)
    print(f"Starting server on port {args.port}...")
    try:
        # Block and serve requests until interrupted; without this call the script
        # would construct the server and exit immediately.
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("Shutting down server...")
    finally:
        # Release the listening socket.
        httpd.server_close()