[NFC] polish applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py code style (#5232)

fix/format
Cunxiao Du 2024-01-25 13:13:57 +08:00 committed by GitHub
parent b0b53a171c
commit 65f21a2556
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 7 deletions

View File

@ -24,6 +24,7 @@ from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
class CustomBaseRetrievalQA(BaseRetrievalQA): class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains.""" """Base class for question-answering chains."""
@ -98,7 +99,6 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
for k, v in inputs.items() for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"] if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
} }
answers = []
if self.combine_documents_chain.memory is not None: if self.combine_documents_chain.memory is not None:
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy( buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
self.combine_documents_chain.memory.buffered_history self.combine_documents_chain.memory.buffered_history
@ -117,10 +117,10 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup) ) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
# if rejection_trigger_keywords is not given, return the response from LLM directly # if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', []) rejection_trigger_keywrods = inputs.get("rejection_trigger_keywrods", [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
if answer is None: if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None: if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer}) self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
@ -161,10 +161,14 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
) )
# if rejection_trigger_keywords is not given, return the response from LLM directly # if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', []) rejection_trigger_keywrods = inputs.get("rejection_trigger_keywrods", [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None answer = (
answer
if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods) == 0
else None
)
if answer is None: if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer}) self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents: if self.return_source_documents: