""" Multilingual retrieval based conversation system """ from typing import List from colossalqa.data_loader.document_loader import DocumentLoader from colossalqa.mylogging import get_logger from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation from colossalqa.retriever import CustomRetriever from colossalqa.text_splitter import ChineseTextSplitter from colossalqa.utils import detect_lang_naive from langchain.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter logger = get_logger() class UniversalRetrievalConversation: """ Wrapper class for bilingual retrieval conversation system """ def __init__( self, embedding_model_path: str = "moka-ai/m3e-base", embedding_model_device: str = "cpu", zh_model_path: str = None, zh_model_name: str = None, en_model_path: str = None, en_model_name: str = None, sql_file_path: str = None, files_zh: List[List[str]] = None, files_en: List[List[str]] = None, text_splitter_chunk_size=100, text_splitter_chunk_overlap=10, ) -> None: """ Wrapper for multilingual retrieval qa class (Chinese + English) Args: embedding_model_path: local or huggingface embedding model embedding_model_device: files_zh: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for Chinese retrieval QA files_en: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for English retrieval QA """ self.embedding = HuggingFaceEmbeddings( model_name=embedding_model_path, model_kwargs={"device": embedding_model_device}, encode_kwargs={"normalize_embeddings": False}, ) print("Select files for constructing Chinese retriever") docs_zh = self.load_supporting_docs( files=files_zh, text_splitter=ChineseTextSplitter( chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap ), ) # Create retriever self.information_retriever_zh = CustomRetriever( k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True ) self.information_retriever_zh.add_documents( docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding ) print("Select files for constructing English retriever") docs_en = self.load_supporting_docs( files=files_en, text_splitter=RecursiveCharacterTextSplitter( chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap ), ) # Create retriever self.information_retriever_en = CustomRetriever( k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True ) self.information_retriever_en.add_documents( docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding ) self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever( self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name ) self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever( self.information_retriever_en, model_path=en_model_path, model_name=en_model_name ) self.memory = None def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None): """ Load supporting documents, currently, all documents will be stored in one vector store """ documents = [] if files: for file in files: retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data splits = text_splitter.split_documents(retriever_data) documents.extend(splits) else: while True: file = input("Select a file to load or press Enter to exit:") if file == "": break data_name = input("Enter a short description of the data:") separator = input( "Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n', press ENTER directly to skip:" ) separator = separator if separator != "" else "\n\n" retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data # Split splits = text_splitter.split_documents(retriever_data) documents.extend(splits) return documents def start_test_session(self): """ Simple multilingual session for testing purpose, with naive language selection mechanism """ while True: user_input = input("User: ") lang = detect_lang_naive(user_input) if "END" == user_input: print("Agent: Happy to chat with you :)") break agent_response = self.run(user_input, which_language=lang) print(f"Agent: {agent_response}") def run(self, user_input: str, which_language=str): """ Generate the response given the user input and a str indicates the language requirement of the output string """ assert which_language in ["zh", "en"] if which_language == "zh": agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory) else: agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory) return agent_response.split("\n")[0]