"""
Chain for question-answering against a vector database.
Modified from Original Source
This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations

import copy
import inspect
from typing import Any, Dict, List, Optional

from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel

class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains."""
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Initialize from LLM."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
        # Pop rather than get, so document_prompt is not also forwarded to the
        # chain constructor below (Chain forbids extra fields).
        document_prompt = kwargs.pop(
            "document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
        )
combine_documents_chain = CustomStuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
callbacks=callbacks,
)
return cls(
combine_documents_chain=combine_documents_chain,
callbacks=callbacks,
**kwargs,
)
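
    # Example (a minimal sketch; `my_llm` and `my_retriever` are placeholders
    # for a LangChain-compatible language model and retriever):
    #
    #     qa = RetrievalQA.from_llm(llm=my_llm, retriever=my_retriever)
    #     result = qa({"query": "What is ColossalQA?"})
    #     answer = result["result"]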
@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Load chain from chain type."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
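
    # Example (sketch): `chain_type="stuff"` loads a stuff-type documents chain
    # via load_qa_chain; `my_llm` and `my_retriever` are placeholders:
    #
    #     qa = RetrievalQA.from_chain_type(llm=my_llm, retriever=my_retriever, chain_type="stuff")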
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
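        # _get_docs may or may not accept a run_manager argument depending on
        # the LangChain version; inspect its signature and call accordingly.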
accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
if accepts_run_manager:
docs = self._get_docs(question, run_manager=_run_manager)
else:
docs = self._get_docs(question) # type: ignore[call-arg]
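        # Forward only generation-control parameters (sampling settings, stop
        # words, etc.) from the inputs to the underlying LLM call.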
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
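        # Back up the conversation memory before running the combine-documents
        # chain and restore it afterwards, so that the exchange is written to
        # memory exactly once via save_context below.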
        if self.combine_documents_chain.memory is not None:
            buffered_history_backup = copy.deepcopy(self.combine_documents_chain.memory.buffered_history)
            summarized_history_temp_backup = copy.deepcopy(
                self.combine_documents_chain.memory.summarized_history_temp
            )
        else:
            buffered_history_backup = None
            summarized_history_temp_backup = None
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
        if buffered_history_backup is not None and summarized_history_temp_backup is not None:
            self.combine_documents_chain.memory.buffered_history = copy.deepcopy(buffered_history_backup)
            self.combine_documents_chain.memory.summarized_history_temp = copy.deepcopy(
                summarized_history_temp_backup
            )
        # Reject the answer if it contains any rejection-trigger keyword; if no
        # keywords are given, the LLM response is returned unchanged.
        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
        answer = answer if all(rej not in answer for rej in rejection_trigger_keywords) else None
        if answer is None:
            # Default rejection answer: "Sorry, the question cannot be answered
            # based on the information provided."
            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
if accepts_run_manager:
docs = await self._aget_docs(question, run_manager=_run_manager)
else:
docs = await self._aget_docs(question) # type: ignore[call-arg]
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
        # Reject the answer if it contains any rejection-trigger keyword; if no
        # keywords are given, the LLM response is returned unchanged
        # (``all()`` over an empty iterable is True, so the extra length check
        # in the original was redundant).
        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
        answer = answer if all(rej not in answer for rej in rejection_trigger_keywords) else None
        if answer is None:
            # Default rejection answer: "Sorry, the question cannot be answered
            # based on the information provided."
            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        if self.combine_documents_chain.memory is not None:
            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}


class RetrievalQA(CustomBaseRetrievalQA):
    """Chain for question-answering against an index.

    Example:
    .. code-block:: python

        from langchain.llms import OpenAI
        from langchain.vectorstores import FAISS
        from langchain.vectorstores.base import VectorStoreRetriever
        from colossalqa.chain.retrieval_qa.base import RetrievalQA

        retriever = VectorStoreRetriever(vectorstore=FAISS(...))
        retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
    """
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "retrieval_qa"