# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 126 lines, 4.9 KiB)
"""
|
||
Script for the multilingual conversation based experimental AI agent
|
||
We used ChatGPT as the language model
|
||
You need openai api key to run this script
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
|
||
from colossalqa.data_loader.document_loader import DocumentLoader
|
||
from colossalqa.data_loader.table_dataloader import TableLoader
|
||
from langchain import LLMChain, OpenAI
|
||
from langchain.agents import Tool, ZeroShotAgent
|
||
from langchain.agents.agent import AgentExecutor
|
||
from langchain.agents.agent_toolkits import create_retriever_tool
|
||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||
from langchain.llms import OpenAI
|
||
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
||
from langchain.memory.chat_memory import ChatMessageHistory
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain.utilities import SQLDatabase
|
||
from langchain.vectorstores import Chroma
|
||
from langchain_experimental.sql import SQLDatabaseChain
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(description="Experimental AI agent powered by ChatGPT")
|
||
parser.add_argument("--open_ai_key_path", type=str, default=None, help="path to the plain text open_ai_key file")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Setup openai key
|
||
# Set env var OPENAI_API_KEY or load from a file
|
||
openai_key = open(args.open_ai_key_path).read()
|
||
os.environ["OPENAI_API_KEY"] = openai_key
|
||
|
||
# Load data served on sql
|
||
print("Select files for constructing sql database")
|
||
tools = []
|
||
|
||
llm = OpenAI(temperature=0.0)
|
||
|
||
while True:
|
||
file = input("Select a file to load or press Enter to exit:")
|
||
if file == "":
|
||
break
|
||
data_name = input("Enter a short description of the data:")
|
||
|
||
table_loader = TableLoader(
|
||
[[file, data_name.replace(" ", "_")]], sql_path=f"sqlite:///{data_name.replace(' ', '_')}.db"
|
||
)
|
||
sql_path = table_loader.get_sql_path()
|
||
|
||
# Create sql database
|
||
db = SQLDatabase.from_uri(sql_path)
|
||
print(db.get_table_info())
|
||
|
||
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
|
||
name = f"Query the SQL database regarding {data_name}"
|
||
description = (
|
||
f"useful for when you need to answer questions based on data stored on a SQL database regarding {data_name}"
|
||
)
|
||
tools.append(
|
||
Tool(
|
||
name=name,
|
||
func=db_chain.run,
|
||
description=description,
|
||
)
|
||
)
|
||
print(f"Added sql dataset\n\tname={name}\n\tdescription:{description}")
|
||
|
||
# VectorDB
|
||
embedding = OpenAIEmbeddings()
|
||
|
||
# Load data serve on sql
|
||
print("Select files for constructing retriever")
|
||
while True:
|
||
file = input("Select a file to load or press Enter to exit:")
|
||
if file == "":
|
||
break
|
||
data_name = input("Enter a short description of the data:")
|
||
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
|
||
|
||
# Split
|
||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
|
||
splits = text_splitter.split_documents(retriever_data)
|
||
|
||
# Create vector store
|
||
vectordb = Chroma.from_documents(documents=splits, embedding=embedding)
|
||
# Create retriever
|
||
retriever = vectordb.as_retriever(
|
||
search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.5, "k": 5}
|
||
)
|
||
# Add to tool chain
|
||
name = f"Searches and returns documents regarding {data_name}."
|
||
tools.append(create_retriever_tool(retriever, data_name, name))
|
||
|
||
prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools. If none of the tools can be used to answer the question. Do not share uncertain answer unless you think answering the question doesn't need any background information. In that case, try to answer the question directly."""
|
||
suffix = """You are provided with the following background knowledge:
|
||
Begin!"
|
||
|
||
{chat_history}
|
||
Question: {input}
|
||
{agent_scratchpad}"""
|
||
|
||
prompt = ZeroShotAgent.create_prompt(
|
||
tools,
|
||
prefix=prefix,
|
||
suffix=suffix,
|
||
input_variables=["input", "chat_history", "agent_scratchpad"],
|
||
)
|
||
|
||
memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=ChatMessageHistory())
|
||
|
||
llm_chain = LLMChain(llm=OpenAI(temperature=0.7), prompt=prompt)
|
||
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
|
||
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)
|
||
|
||
while True:
|
||
user_input = input("User: ")
|
||
if " end " in user_input:
|
||
print("Agent: Happy to chat with you :)")
|
||
break
|
||
agent_response = agent_chain.run(user_input)
|
||
print(f"Agent: {agent_response}")
|
||
table_loader.sql_engine.dispose()
|