""" Implement a memory class for storing conversation history Support long term and short term memory """ from typing import Any, Dict, List from colossalqa.chain.memory.summary import ConversationSummaryMemory from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.schema import BaseChatMessageHistory from langchain.schema.messages import BaseMessage from langchain.schema.retriever import BaseRetriever from pydantic import Field class ConversationBufferWithSummary(ConversationSummaryMemory): """Memory class for storing information about entities.""" # Define dictionary to store information about entities. # Store the most recent conversation history buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory) # Temp buffer summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory) human_prefix: str = "Human" ai_prefix: str = "Assistant" buffer: str = "" # Formated conversation in str existing_summary: str = "" # Summarization of stale converstion in str # Define key to pass information about entities into prompt. memory_key: str = "chat_history" input_key: str = "question" retriever: BaseRetriever = None max_tokens: int = 2000 chain: BaseCombineDocumentsChain = None input_chain_type_kwargs: List = {} @property def buffer(self) -> Any: """String buffer of memory.""" return self.buffer_as_messages if self.return_messages else self.buffer_as_str @property def buffer_as_str(self) -> str: """Exposes the buffer as a string in case return_messages is True.""" self.buffer = self.format_dialogue() return self.buffer @property def buffer_as_messages(self) -> List[BaseMessage]: """Exposes the buffer as a list of messages in case return_messages is False.""" return self.buffered_history.messages def clear(self): """Clear all the memory""" self.buffered_history.clear() self.summarized_history_temp.clear() def initiate_document_retrieval_chain( self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {} ) -> None: """ Since we need to calculate the length of the prompt, we need to initiate a retrieval chain to calculate the length of the prompt. 
        Args:
            llm: the language model for the retrieval chain (we won't actually return its output)
            prompt_template: the prompt template for constructing the retrieval chain
            retriever: the retriever for the retrieval chain
            chain_type_kwargs: the kwargs for the retrieval chain; any entry matching memory_key is dropped
        (max_tokens, memory_key and input_key are taken from the instance attributes)
        """
        self.retriever = retriever
        input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
        self.input_chain_type_kwargs = input_chain_type_kwargs
        self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)

    @property
    def memory_variables(self) -> List[str]:
        """Define the variables we are providing to the prompt."""
        return [self.memory_key]

    def format_dialogue(self, lang: str = "en") -> str:
        """Format memory into two parts: a summary of the historical conversation and the most recent conversation."""
        if len(self.summarized_history_temp.messages) != 0:
            # Fold each stale (human, assistant) pair into the running summary
            for i in range(len(self.summarized_history_temp.messages) // 2):
                self.existing_summary = (
                    self.predict_new_summary(
                        self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
                    )
                    .strip()
                    .split("\n")[0]
                    .strip()
                )
            # Remove the summarized turns from the temporary buffer
            for i in range(len(self.summarized_history_temp.messages) // 2):
                self.summarized_history_temp.messages.pop(0)
                self.summarized_history_temp.messages.pop(0)
        conversation_buffer = []
        for t in self.buffered_history.messages:
            prefix = self.human_prefix if t.type == "human" else self.ai_prefix
            conversation_buffer.append(prefix + ": " + t.content)
        conversation_buffer = "\n".join(conversation_buffer)
        if len(self.existing_summary) > 0:
            if lang == "en":
                message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
            elif lang == "zh":
                message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
            else:
                raise ValueError("Unsupported language")
            return message
        return conversation_buffer

    def get_conversation_length(self) -> int:
        """Get the token length of the formatted conversation."""
        prompt = self.format_dialogue()
        return self.llm.get_num_tokens(prompt)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """
        Load the memory variables. Summarize an oversized conversation so that it fits
        the length constraint defined by max_tokens.

        Args:
            inputs: the kwargs of the chain of your definition

        Returns:
            A dict that maps the memory key to the formatted dialogue.
            If the conversation is too long, the formatted dialogue has the format:
                A summarization of historical conversation:
                {summarization}
                Most recent conversation:
                Human: XXX
                Assistant: XXX
                ...
            otherwise:
                Human: XXX
                Assistant: XXX
                ...
""" # Calculate remain length if "input_documents" in inputs: # Run in a retrieval qa chain docs = inputs["input_documents"] else: # For test docs = self.retriever.get_relevant_documents(inputs[self.input_key]) inputs[self.memory_key] = "" inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]} prompt_length = self.chain.prompt_length(docs, **inputs) remain = self.max_tokens - prompt_length while self.get_conversation_length() > remain: if len(self.buffered_history.messages) <= 2: raise RuntimeError("Exceed max_tokens, trunk size of retrieved documents is too large") temp = self.buffered_history.messages.pop(0) self.summarized_history_temp.messages.append(temp) temp = self.buffered_history.messages.pop(0) self.summarized_history_temp.messages.append(temp) return {self.memory_key: self.format_dialogue()} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" input_str, output_str = self._get_input_output(inputs, outputs) self.buffered_history.add_user_message(input_str.strip()) self.buffered_history.add_ai_message(output_str.strip())