mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
172 lines
7.0 KiB
172 lines
7.0 KiB
import os
|
|
from typing import Dict, Tuple
|
|
|
|
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
|
from colossalqa.data_loader.document_loader import DocumentLoader
|
|
from colossalqa.memory import ConversationBufferWithSummary
|
|
from colossalqa.mylogging import get_logger
|
|
from colossalqa.prompt.prompt import ZH_RETRIEVAL_QA_REJECTION_ANSWER, ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS
|
|
from colossalqa.retriever import CustomRetriever
|
|
from langchain import LLMChain
|
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
|
# Module-level logger obtained from colossalqa's logging helper (not referenced by the code visible in this file).
logger = get_logger()
|
|
|
|
|
|
class RAG_ChatBot:
    """Retrieval-augmented chatbot built from an LLM, an embedding model, a retriever and a summarizing memory.

    All components are created from ``rag_config``, a mapping with the sections
    ``"embed"``, ``"splitter"``, ``"chain"`` and ``"retrieval"``; each section is
    expanded as keyword arguments into the matching ``set_*`` method.
    """

    def __init__(
        self,
        llm,
        rag_config,
    ) -> None:
        """Store the LLM, build every component from ``rag_config`` and start with no documents.

        Args:
            llm: LLM object handed to the memory, the QA chain and the disambiguation chain.
            rag_config: Configuration mapping (see the class docstring for the sections read).
        """
        self.llm = llm
        self._configure(rag_config)

        self.documents = []
        self.docs_names = []

    def _configure(self, rag_config) -> None:
        """Store ``rag_config`` and (re)build all components in dependency order."""
        self.rag_config = rag_config
        chain_config = rag_config["chain"]
        self.set_embed_model(**rag_config["embed"])
        self.set_text_splitter(**rag_config["splitter"])
        # The memory and the retriever must exist before the QA chain, which references both.
        self.set_memory(**chain_config)
        self.set_info_retriever(**rag_config["retrieval"])
        self.set_rag_chain(**chain_config)
        # Query disambiguation is optional and only enabled when a prompt is configured.
        if chain_config.get("disambig_prompt", None):
            self.set_disambig_retriv(**chain_config)

    def _qa_prompt(self, overrides):
        """Return ``gen_qa_prompt`` from ``overrides`` if present, else from the stored chain config."""
        if "gen_qa_prompt" in overrides:
            return overrides["gen_qa_prompt"]
        return self.rag_config["chain"]["gen_qa_prompt"]

    def set_embed_model(self, **kwargs):
        """Create the HuggingFace embedding model.

        Reads ``embed_model_name_or_path`` (model name) and ``embed_model_device``
        (forwarded as ``model_kwargs``) from ``kwargs``; embeddings are not normalized.
        """
        self.embed_model = HuggingFaceEmbeddings(
            model_name=kwargs["embed_model_name_or_path"],
            model_kwargs=kwargs["embed_model_device"],
            encode_kwargs={"normalize_embeddings": False},
        )

    def set_text_splitter(self, **kwargs):
        """Instantiate the text splitter by calling ``kwargs["name"]`` with no arguments."""
        # Initialize text_splitter
        self.text_splitter = kwargs["name"]()

    def set_memory(self, **kwargs):
        """Create the conversation memory with summarization ability.

        Reads ``mem_summary_prompt``, ``mem_human_prefix``, ``mem_ai_prefix`` and
        ``mem_max_tokens`` from ``kwargs``; ``mem_llm_kwargs`` is forwarded as
        ``llm_kwargs`` only when it is present and truthy.
        """
        params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
        # Initialize memory with summarization ability
        self.memory = ConversationBufferWithSummary(
            llm=self.llm,
            prompt=kwargs["mem_summary_prompt"],
            human_prefix=kwargs["mem_human_prefix"],
            ai_prefix=kwargs["mem_ai_prefix"],
            max_tokens=kwargs["mem_max_tokens"],
            **params,
        )

    def set_info_retriever(self, **kwargs):
        """Create the retriever from ``retri_top_k``, ``retri_kb_file_path`` and ``verbose``."""
        self.info_retriever = CustomRetriever(
            k=kwargs["retri_top_k"], sql_file_path=kwargs["retri_kb_file_path"], verbose=kwargs["verbose"]
        )

    def set_rag_chain(self, **kwargs):
        """Create the "stuff" retrieval-QA chain bound to the current retriever and memory.

        Reads ``verbose`` and ``gen_qa_prompt`` from ``kwargs``; ``gen_llm_kwargs`` is
        forwarded as ``llm_kwargs`` only when it is present and truthy.
        """
        params = {"llm_kwargs": kwargs["gen_llm_kwargs"]} if kwargs.get("gen_llm_kwargs", None) else {}
        self.rag_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            verbose=kwargs["verbose"],
            chain_type="stuff",
            retriever=self.info_retriever,
            chain_type_kwargs={"prompt": kwargs["gen_qa_prompt"], "memory": self.memory},
            **params,
        )

    def set_disambig_retriv(self, **kwargs):
        """Create the disambiguation chain and register it as the retriever's rephrase handler.

        Reads ``disambig_prompt`` from ``kwargs``; ``disambig_llm_kwargs`` is forwarded
        as ``llm_kwargs`` only when it is present and truthy.
        """
        params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
        self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)

        # NOTE: the parameter keeps its original name ``input`` (shadowing the builtin inside this
        # closure only) because how the retriever invokes the handler is not visible from this file.
        def disambiguity(input: str):
            out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
            # Only the first line of the chain output is used as the rephrased query.
            return out.split("\n")[0]

        self.info_retriever.set_rephrase_handler(disambiguity)

    def load_doc_from_console(self, json_parse_args: Dict = None):
        """Interactively read file paths from stdin, load them, then index all loaded documents.

        Args:
            json_parse_args: Extra keyword arguments forwarded to ``DocumentLoader``;
                ``None`` (the default) means no extra arguments.
        """
        # A fresh dict per call avoids sharing one mutable default between invocations.
        loader_kwargs = {} if json_parse_args is None else json_parse_args
        print("Select files for constructing the retriever")
        while True:
            file = input("Enter a file path or press Enter directly without input to exit:").strip()
            if file == "":
                break
            data_name = input("Enter a short description of the data:")
            docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **loader_kwargs).all_data
            self.documents.extend(docs)
            self.docs_names.append(data_name)
        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

    def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Dict = None):
        """Load every path in ``files``, then index all loaded documents.

        Args:
            files: Iterable of file paths.
            data_name: Name attached to each file for the loader (spaces become underscores).
            json_parse_args: Extra keyword arguments forwarded to ``DocumentLoader``;
                ``None`` (the default) means no extra arguments.
        """
        # A fresh dict per call avoids sharing one mutable default between invocations.
        loader_kwargs = {} if json_parse_args is None else json_parse_args
        kb_name = data_name.replace(" ", "_")
        for file in files:
            docs = DocumentLoader([[file, kb_name]], **loader_kwargs).all_data
            self.documents.extend(docs)
            self.docs_names.append(os.path.basename(file))
        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

    def split_docs_and_add_to_mem(self, **kwargs):
        """Split all stored documents, add the splits to the retriever and rebuild the memory's retrieval chain.

        ``gen_qa_prompt`` is taken from ``kwargs`` when given, otherwise from the stored chain config.
        """
        doc_splits = self.split_docs(self.documents)
        self.info_retriever.add_documents(
            docs=doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
        )
        self.memory.initiate_document_retrieval_chain(self.llm, self._qa_prompt(kwargs), self.info_retriever)

    def split_docs(self, documents):
        """Return ``documents`` split by the configured text splitter."""
        doc_splits = self.text_splitter.split_documents(documents)
        return doc_splits

    def clear_docs(self, **kwargs):
        """Forget all loaded documents, clear the retriever and rebuild the memory's retrieval chain.

        ``gen_qa_prompt`` is taken from ``kwargs`` when given, otherwise from the stored
        chain config, so ``clear_docs()`` can be called without arguments.
        """
        self.documents = []
        self.docs_names = []
        self.info_retriever.clear_documents()
        self.memory.initiate_document_retrieval_chain(self.llm, self._qa_prompt(kwargs), self.info_retriever)

    def reset_config(self, rag_config):
        """Replace the configuration and rebuild every component; loaded document lists are left untouched."""
        self._configure(rag_config)

    def run(self, user_input: str, memory: "ConversationBufferWithSummary") -> Tuple[str, "ConversationBufferWithSummary"]:
        """Answer ``user_input`` with the retrieval-QA chain.

        Args:
            user_input: The user's query.
            memory: Conversation memory whose history should be used for this turn.
                A falsy value (e.g. ``None``) means "use the chatbot's own memory as is".

        Returns:
            A ``(answer, memory)`` tuple where ``memory`` is the chatbot's own memory
            object, i.e. the one the QA chain was constructed with.
        """
        if memory:
            # The chain was built around self.memory, so the caller's history has to be copied
            # into it; assigning the caller's attributes to themselves would have no effect.
            self.memory.buffered_history.messages = memory.buffered_history.messages
            self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
        # Fall back to the chatbot's own memory so a missing ``memory`` does not raise AttributeError.
        prefix_source = memory if memory else self.memory
        result = self.rag_chain.run(
            query=user_input,
            stop=[prefix_source.human_prefix + ": "],
            rejection_trigger_keywords=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
            rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
        )
        return result, self.memory

    def start_test_session(self):
        """
        Simple session for testing purpose
        """
        while True:
            user_input = input("User: ")
            if "END" == user_input:
                print("Agent: Happy to chat with you :)")
                break
            agent_response, self.memory = self.run(user_input, self.memory)
            print(f"Agent: {agent_response}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: build the chatbot on top of a LangChain LLM (OpenAI is used as the example).
    import config
    from langchain.llms import OpenAI

    # Prerequisite: export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
    llm = OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"))

    # ChatGPT cannot control temperature, do_sample, etc., so every per-stage LLM kwargs entry is cleared.
    all_config = config.ALL_CONFIG
    all_config["chain"].update(
        mem_llm_kwargs=None,
        disambig_llm_kwargs=None,
        gen_llm_kwargs=None,
    )

    # Build the bot, let the user pick documents on the console, then chat until "END" is entered.
    rag = RAG_ChatBot(llm, all_config)
    rag.load_doc_from_console()
    rag.start_test_session()
|