"""
|
|
Chain for question-answering against a vector database.
|
|
|
|
Modified from Original Source
|
|
|
|
This code is based on LangChain Ai's langchain, which can be found at
|
|
https://github.com/langchain-ai/langchain
|
|
The original code is licensed under the MIT license.
|
|
"""
from __future__ import annotations

import copy
import inspect
from typing import Any, Dict, List, Optional

from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel

class CustomBaseRetrievalQA(BaseRetrievalQA):
    """Base class for question-answering chains."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Initialize from LLM."""
        llm_kwargs = kwargs.pop("llm_kwargs", {})
        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
        # Each retrieved document is rendered with this prompt before being
        # stuffed into the final context. Pop it so it is not forwarded to
        # cls(**kwargs), where it would be an unexpected field.
        document_prompt = kwargs.pop(
            "document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
        )
        combine_documents_chain = CustomStuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
            callbacks=callbacks,
        )

        return cls(
            combine_documents_chain=combine_documents_chain,
            callbacks=callbacks,
            **kwargs,
        )
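
    # A minimal sketch of overriding the per-document prompt (``llm`` and
    # ``retriever`` are placeholder names for a LangChain-compatible language
    # model and retriever, not objects defined in this module):
    #
    #     doc_prompt = PromptTemplate(
    #         input_variables=["page_content"],
    #         template="Source:\n{page_content}",
    #     )
    #     qa = RetrievalQA.from_llm(llm=llm, retriever=retriever, document_prompt=doc_prompt)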

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Load chain from chain type."""
        llm_kwargs = kwargs.pop("llm_kwargs", {})
        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
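
    # A minimal usage sketch (``llm`` and ``retriever`` are placeholders for an
    # existing model and retriever; they are not defined in this module):
    #
    #     qa = RetrievalQA.from_chain_type(
    #         llm=llm,
    #         chain_type="stuff",
    #         retriever=retriever,
    #         return_source_documents=True,
    #     )
    #     res = qa({"query": "This is my query"})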

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on input query.

        If the chain has 'return_source_documents' set to 'True', the retrieved
        documents are returned as well under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
        if accepts_run_manager:
            docs = self._get_docs(question, run_manager=_run_manager)
        else:
            docs = self._get_docs(question)  # type: ignore[call-arg]

        # Forward only the whitelisted generation parameters to the LLM call.
        kwargs = {
            k: v
            for k, v in inputs.items()
            if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
        }
        # Back up the conversation memory so a rejected answer can be rolled
        # back; the final question/answer pair is saved explicitly below.
        if self.combine_documents_chain.memory is not None:
            buffered_history_backup = copy.deepcopy(self.combine_documents_chain.memory.buffered_history)
            summarized_history_temp_backup = copy.deepcopy(self.combine_documents_chain.memory.summarized_history_temp)
        else:
            buffered_history_backup = None
            summarized_history_temp_backup = None

        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
        )
        if summarized_history_temp_backup is not None and buffered_history_backup is not None:
            self.combine_documents_chain.memory.buffered_history = copy.deepcopy(buffered_history_backup)
            self.combine_documents_chain.memory.summarized_history_temp = copy.deepcopy(summarized_history_temp_backup)

        # If rejection_trigger_keywords is not given, return the LLM response directly;
        # otherwise treat any answer containing a trigger keyword as a rejection.
        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
        answer = answer if all(rej not in answer for rej in rejection_trigger_keywords) else None
        if answer is None:
            # Default rejection answer: "Sorry, the question cannot be answered based on the information provided."
            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        if self.combine_documents_chain.memory is not None:
            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}
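
    # A sketch of the rejection flow (``qa`` as constructed above; the keyword
    # and answer strings here are illustrative, not defaults of this module):
    #
    #     res = qa({
    #         "query": "What is the refund policy?",
    #         "rejection_trigger_keywords": ["cannot answer", "no relevant information"],
    #         "rejection_answer": "Sorry, the provided documents do not cover this question.",
    #     })
    #     # If the LLM output contains any trigger keyword, res["result"] is the
    #     # rejection_answer instead of the raw model output.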

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on input query.

        If the chain has 'return_source_documents' set to 'True', the retrieved
        documents are returned as well under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = await indexqa.acall({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
        if accepts_run_manager:
            docs = await self._aget_docs(question, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(question)  # type: ignore[call-arg]
        # Forward only the whitelisted generation parameters to the LLM call.
        kwargs = {
            k: v
            for k, v in inputs.items()
            if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
        }
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
        )
        # If rejection_trigger_keywords is not given, return the LLM response
        # directly; all() over an empty list is True, so no extra length check
        # is needed.
        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
        answer = answer if all(rej not in answer for rej in rejection_trigger_keywords) else None
        if answer is None:
            # Default rejection answer: "Sorry, the question cannot be answered based on the information provided."
            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        # Guard the memory access, matching _call; the chain may have no memory.
        if self.combine_documents_chain.memory is not None:
            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}
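
    # A minimal async usage sketch (assumes a ``qa`` chain built as above):
    #
    #     import asyncio
    #
    #     async def main():
    #         res = await qa.acall({"query": "This is my query"})
    #         return res["result"]
    #
    #     asyncio.run(main())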


class RetrievalQA(CustomBaseRetrievalQA):
    """Chain for question-answering against an index.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain.vectorstores import FAISS
            from langchain.vectorstores.base import VectorStoreRetriever
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa"