ColossalAI/applications/ColossalQA/colossalqa/retrieval_conversation_univ...

139 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Multilingual retrieval based conversation system
"""
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.mylogging import get_logger
from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from colossalqa.utils import detect_lang_naive
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
logger = get_logger()
class UniversalRetrievalConversation:
"""
Wrapper class for bilingual retrieval conversation system
"""
def __init__(
self,
embedding_model_path: str = "moka-ai/m3e-base",
embedding_model_device: str = "cpu",
zh_model_path: str = None,
zh_model_name: str = None,
en_model_path: str = None,
en_model_name: str = None,
sql_file_path: str = None,
files_zh: List[List[str]] = None,
files_en: List[List[str]] = None,
text_splitter_chunk_size=100,
text_splitter_chunk_overlap=10,
) -> None:
"""
Warpper for multilingual retrieval qa class (Chinese + English)
Args:
embedding_model_path: local or huggingface embedding model
embedding_model_device:
files_zh: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for Chinese retrieval QA
files_en: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for English retrieval QA
"""
self.embedding = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={"device": embedding_model_device},
encode_kwargs={"normalize_embeddings": False},
)
print("Select files for constructing Chinese retriever")
docs_zh = self.load_supporting_docs(
files=files_zh,
text_splitter=ChineseTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_zh = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True
)
self.information_retriever_zh.add_documents(
docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding
)
print("Select files for constructing English retriever")
docs_en = self.load_supporting_docs(
files=files_en,
text_splitter=RecursiveCharacterTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_en = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True
)
self.information_retriever_en.add_documents(
docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding
)
self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever(
self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name
)
self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever(
self.information_retriever_en, model_path=en_model_path, model_name=en_model_name
)
self.memory = None
def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None):
"""
Load supporting documents, currently, all documents will be stored in one vector store
"""
documents = []
if files:
for file in files:
retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
else:
while True:
file = input("Select a file to load or press Enter to exit:")
if file == "":
break
data_name = input("Enter a short description of the data:")
separator = input(
"Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:"
)
separator = separator if separator != "" else "\n\n"
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
# Split
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
return documents
def start_test_session(self):
"""
Simple multilingual session for testing purpose, with naive language selection mechanism
"""
while True:
user_input = input("User: ")
lang = detect_lang_naive(user_input)
if "END" == user_input:
print("Agent: Happy to chat with you )")
break
agent_response = self.run(user_input, which_language=lang)
print(f"Agent: {agent_response}")
def run(self, user_input: str, which_language=str):
"""
Generate the response given the user input and a str indicates the language requirement of the output string
"""
assert which_language in ["zh", "en"]
if which_language == "zh":
agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory)
else:
agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory)
return agent_response.split("\n")[0]