PoC with LangChain wrapper and a webapp to chat with vector store

Added a proof-of-concept LangChain wrapper for the ChatGLM model.

Added a Streamlit-based interface for chatting with a vector store (in this case, OpenAI-related Wikipedia pages within two degrees of separation, stored in FAISS).
pull/216/head
Ji Zhang 2023-03-23 22:21:52 -07:00
parent 28665ade15
commit e90276f340
7 changed files with 578 additions and 0 deletions
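
Note: the app only loads a prebuilt FAISS index (index/wiki_faiss_2023_03_06 in the code below); building that index is not part of this diff. As a rough sketch, an index like it could be assembled with the same LangChain APIs, assuming the crawled pages are already on hand; the pages list and its fields here are hypothetical:

import os

from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.faiss import FAISS

os.environ["OPENAI_API_KEY"] = ""  # embeddings only, as in chat_backend.py

# "pages" stands in for the crawled wiki articles; the structure is assumed.
pages = [
    {"title": "OpenAI", "url": "https://en.wikipedia.org/wiki/OpenAI", "text": "..."},
]

splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = []
for page in pages:
    for chunk in splitter.split_text(page["text"]):
        # title/url/section metadata is what dict_to_github_markdown() later renders as sources
        docs.append(
            Document(
                page_content=chunk,
                metadata={"title": page["title"], "url": page["url"], "section": ""},
            )
        )

index = FAISS.from_documents(docs, OpenAIEmbeddings())
index.save_local("index/wiki_faiss_2023_03_06")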

View File: chat_backend.py

@@ -0,0 +1,92 @@
import os
from typing import List, Dict, Tuple, Any

import streamlit as st
import pandas as pd
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import (
    ChatVectorDBChain,
    QAWithSourcesChain,
    VectorDBQAWithSourcesChain,
)
from langchain.docstore.document import Document
from langchain.vectorstores.faiss import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from transformers import AutoTokenizer, AutoModel

# Set up the OpenAI API key.
# This is used solely for the semantic-search (embedding) part of the LangChain
# vector search. Completion is still done entirely by the ChatGLM model.
os.environ["OPENAI_API_KEY"] = ""


@st.cache_resource()
def get_chat_glm():
    """Load the int4-quantized ChatGLM-6B model once and cache it across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b-int4", trust_remote_code=True
    )
    model = (
        AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
        .half()
        .cuda()
    )
    model = model.eval()
    return model, tokenizer


def chat_with_agent(user_input, temperature=0.2, max_tokens=800, chat_history=[]):
    model, tokenizer = get_chat_glm()
    response, updated_history = model.chat(
        tokenizer,
        user_input,
        history=chat_history,
        temperature=temperature,
        max_length=max_tokens,
    )
    return response, updated_history


# LangChain-related features
def init_wiki_agent(
    index_dir,
    max_token=800,
    temperature=0.3,
):
    embeddings = OpenAIEmbeddings()
    if index_dir:
        vectorstore = FAISS.load_local(index_dir, embeddings=embeddings)
    else:
        raise ValueError("Need saved vector store location")
    system_template = """Use the following Wikipedia excerpts to answer the user's question.
If the answer cannot be found in them, say "I don't know" or "There is not enough relevant information". Do not try to make up an answer.
----------------
{context}"""
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    # max_token/temperature currently only apply to this commented-out ChatOpenAI variant;
    # the ChatGLM_G wrapper below ignores them.
    # qa = ChatVectorDBChain.from_llm(llm=ChatOpenAI(temperature=temperature, max_tokens=max_token),
    #                                 vectorstore=vectorstore,
    #                                 qa_prompt=prompt)
    # Imported lazily so the ChatGLM weights are only loaded when this agent is used.
    from chatglm_llm import ChatGLM_G

    qa = ChatVectorDBChain.from_llm(
        llm=ChatGLM_G(), vectorstore=vectorstore, qa_prompt=prompt
    )
    qa.return_source_documents = True
    qa.top_k_docs_for_context = 2
    return qa


def get_wiki_agent_answer(query, qa, chat_history=[]):
    result = qa({"question": query, "chat_history": chat_history})
    return result
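
For a quick smoke test of this backend outside Streamlit, something like the following should work, assuming the saved FAISS index directory from the app exists, OPENAI_API_KEY is set, and a CUDA GPU is available; the query string is just an example:

# Hypothetical smoke test for chat_backend.py; assumes the saved index exists.
from chat_backend import init_wiki_agent, get_wiki_agent_answer

qa = init_wiki_agent(index_dir="index/wiki_faiss_2023_03_06")
history = []
result = get_wiki_agent_answer("Who founded OpenAI?", qa, chat_history=history)
print(result["answer"])
print([doc.metadata for doc in result["source_documents"]])
history.append(("Who founded OpenAI?", result["answer"]))

The chat history is threaded as a list of (question, answer) tuples, the same shape the Streamlit app keeps in st.session_state["agent_chat_history"].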

View File: chat_style.css

@@ -0,0 +1,146 @@
.image-left {
    display: inline-block;
    vertical-align: middle;
    margin-right: 1em;
}

.conversation-container {
    padding: 1em;
    border-radius: 10px;
    /* margin-bottom: 1em; */
    margin-top: 1em;
}

.conversation-container.user {
    background-color: rgba(217, 217, 227, 0.4);
}

.conversation-container.bot {
    /* background-color: rgba(247, 247, 248, 0.3); */
    margin-bottom: 0em;
}

.text-area-input {
    height: 5em;
    margin-bottom: 1em;
    font-size: 1.2rem;
}

.conversation-scroll {
    height: 70vh;
    overflow-y: scroll;
}

[data-testid="stForm"] {
    /* width: 55vw;
    max-width: 80vh;
    margin-left: -15vw; */
    /* max-height: 70vh;
    overflow-y: scroll; */
    width: 70%;
    margin-left: 15%;
}

[data-testid="stForm"] .stTextArea {
    box-shadow: 0 5px 6px -4px #c7cdce;
    width: 60% !important;
    margin-left: 20% !important;
}

footer {
    /* margin-left: -20vw; */
}

/* [data-testid="stMarkdownContainer"]:has(.bot) {
    margin-bottom: 0px;
    margin-top: 0px;
} */

[data-testid="stForm"]
[data-testid="stVerticalBlock"]
[data-testid="stVerticalBlock"]:has(div.conversation-container) {
    max-height: 55vh !important;
    overflow-y: auto !important;
    overflow-x: hidden;
    /* font-size: 1.2rem; */
    margin-right: 5px;
}

[data-testid="stForm"]
[data-testid="stVerticalBlock"]
[data-testid="stVerticalBlock"]
p {
    /* font-size: 1.2rem; */
}

[data-testid="stForm"] .stImage {
    /* width: 6rem !important; */
    width: 5rem !important;
}

[data-testid="stForm"] img {
    /* width: 6rem !important; */
    max-width: 85px;
}

header,
h1,
h2,
h3,
[class*="css"] {
    font-family: "Josefin Slab", serif;
    font-style: normal;
    font-weight: 300;
}

h1 {
    font-weight: bold;
}

body {
    font-size: medium;
}

/* .hhh {
    color: black;
    background-color: #fff;
}
.show:hover .hhh {
    color: white;
} */

[data-testid="stForm"] .stButton {
    box-shadow: 0 5px 6px -4px #c7cdce;
    /* width: auto;
    font-size: 20pt;
    float: right;
    box-shadow: rgba(44, 43, 43, 0.5) 0px 0px 0px 0.2rem; */
    /* width: 60%; */
}

.stButton button {
    width: 100%;
}

::-webkit-scrollbar:vertical {
    width: 10px;
}

/* Track */
::-webkit-scrollbar-track:vertical {
    background: #f1f1f1;
}

/* Handle */
::-webkit-scrollbar-thumb:vertical {
    background: #888;
}

/* Handle on hover */
::-webkit-scrollbar-thumb:vertical:hover {
    background: #555;
}

View File: chatglm_llm.py

@@ -0,0 +1,35 @@
from typing import Optional, List

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel


class ChatGLM_G(LLM):
    """ChatGLM_G is a wrapper around the ChatGLM model to fit the LangChain
    framework. May not be an optimal implementation.

    The model is loaded at class-definition time, and the base LLM class
    already provides __call__ by delegating to _call.
    """

    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b-int4", trust_remote_code=True
    )
    model = (
        AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
        .half()
        .cuda()
    )
    history = []

    @property
    def _llm_type(self) -> str:
        return "ChatGLM_G"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response, updated_history = self.model.chat(
            self.tokenizer, prompt, history=self.history
        )
        print("ChatGLM: prompt: ", prompt)
        print("ChatGLM: response: ", response)
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = updated_history
        return response
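
Since ChatGLM_G subclasses LangChain's LLM base class, it can be dropped into other chains besides ChatVectorDBChain. A minimal sketch, assuming a CUDA GPU is available (the weights load when the class is defined) and with an illustrative prompt:

# Hypothetical standalone use of the wrapper; not part of this commit.
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

from chatglm_llm import ChatGLM_G

prompt = PromptTemplate(
    input_variables=["topic"],
    template="Explain {topic} in one paragraph.",
)
chain = LLMChain(llm=ChatGLM_G(), prompt=prompt)
print(chain.run("vector databases"))

Note that the wrapper keeps a history attribute that it updates on every call, so successive calls through any chain are stateful.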

logo.png: binary file not shown (42 KiB).

View File

@@ -0,0 +1,305 @@
import os
import html
import uuid

import streamlit as st
from PIL import Image

from chat_backend import chat_with_agent, init_wiki_agent, get_wiki_agent_answer

path = os.path.dirname(__file__)
icon_img = Image.open(os.path.join(path, "logo.png"))

USER_NAME = "Me"
AGENT_NAME = "Helpbot"

st.set_page_config(
    page_title="ChatGLM",
    page_icon=icon_img,
    layout="wide",
    # initial_sidebar_state="collapsed",
)
st.write(
    "<style>div.block-container{padding-top:1rem;}</style>", unsafe_allow_html=True
)


def local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


def remote_css(url):
    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)


def icon(icon_name):
    st.markdown(f'<i class="material-icons">{icon_name}</i>', unsafe_allow_html=True)


def javascript(source: str) -> None:
    """Inject and execute JavaScript in the Streamlit page.

    st.markdown strips <script> tags, so the script is smuggled in through a
    hidden iframe whose javascript: URL creates the script element in the
    parent document.
    """
    div_id = uuid.uuid4()
    st.markdown(
        f"""
    <div style="display:none" id="{div_id}">
        <iframe src="javascript: \
            var script = document.createElement('script'); \
            script.type = 'text/javascript'; \
            script.text = {html.escape(repr(source))}; \
            var div = window.parent.document.getElementById('{div_id}'); \
            div.appendChild(script); \
            div.parentElement.parentElement.parentElement.style.display = 'none'; \
        "/>
    </div>
    """,
        unsafe_allow_html=True,
    )


local_css("chat_style.css")
st.markdown(
    """
    <link href="https://fonts.googleapis.com/css2?family=Josefin+Slab&display=swap" rel="stylesheet">
    """,
    unsafe_allow_html=True,
)

# Profile images for the user and the bot
user_profile_image = "https://img1.baidu.com/it/u=3150659458,3834452201&fm=253&fmt=auto&app=138&f=JPEG?w=369&h=378"
chatgpt_profile_image = "https://avatars.githubusercontent.com/u/44095251?v=4"


def display_chat_log(cur_container):
    for cur_conversation in st.session_state["chat_log"]:
        for msg in cur_conversation:
            if msg["role"] == USER_NAME:
                cur_container.markdown(
                    "<div class='conversation-container user'><img src='{}' class='image-left' width='50'><br> {} </div>".format(
                        user_profile_image, html.escape(msg["content"])
                    ),
                    unsafe_allow_html=True,
                )
            else:
                cur_container.markdown(
                    "<div class='conversation-container bot'><img src='{}' class='image-left' width='50'></div>".format(
                        chatgpt_profile_image
                    ),
                    unsafe_allow_html=True,
                )
                cur_container.markdown(
                    f"{'&nbsp;&nbsp;&nbsp;' + msg['content']}", unsafe_allow_html=True
                )


def dict_to_github_markdown(data, has_section=False):
    wiki_logo = "https://upload.wikimedia.org/wikipedia/en/thumb/8/80/Wikipedia-logo-v2.svg/1200px-Wikipedia-logo-v2.svg.png"
    slack_logo = (
        "https://cdn.freebiesupply.com/logos/large/2x/slack-1-logo-png-transparent.png"
    )
    markdown = ""
    for item in data:
        title = item["title"]
        url = item["url"]
        if has_section:
            section = item["section"]
            title_text_and_section = f"{title} - {section}"
        else:
            title_text_and_section = title
        if "wikipedia" in url:
            logo = wiki_logo
        elif "slack" in url:
            logo = slack_logo
        else:
            logo = ""
        if len(title_text_and_section) > 50:
            title_text_and_section = title_text_and_section[:50] + "..."
        hyperlink = f"[{title_text_and_section}]({url})"
        if logo:
            markdown += (
                f"&nbsp;&nbsp; <img src='{logo}' width='20' height='20'> {hyperlink} "
            )
        else:
            markdown += f"&nbsp;&nbsp; {hyperlink}"
    return markdown


def clean_agent():
    st.session_state["chat_log"] = [[]]
    st.session_state["messages"] = None
    st.session_state["agent"] = None
    st.session_state["agent_chat_history"] = []


if "bot_desc" not in st.session_state:
    st.session_state["bot_desc"] = "ChatGLM with Vectorstore."

# Sidebar
st.sidebar.subheader("Model Settings")
agent_selected = st.sidebar.selectbox(
    label="Agent",
    options=["Chat", "AI Wikipedia Agent"],
    index=0,
    on_change=clean_agent,
    help="Select the agent to chat with.\n\nChat: general conversational chatbot based on ChatGLM.\n\nAI Wikipedia Agent: chat with a knowledge base; in this case, Wikipedia articles within two degrees of separation from OpenAI.",
)
max_token_selected = st.sidebar.slider(
    label="Model Max Output Length",
    min_value=50,
    max_value=4500,
    value=500,
    step=50,
    help="The maximum number of tokens to generate. Requests can use up to 2,048 or 4,000 tokens shared between prompt and completion. The exact limit varies by model. (One token is roughly 4 characters for normal English text.)",
)
temperature_selected = st.sidebar.number_input(
    label="Model Temperature",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    step=0.1,
    help="Controls randomness: lowering results in less random completions. As the temperature approaches zero, the model becomes deterministic and repetitive.",
)

# Dynamic conversation display
if "chat_log" not in st.session_state:
    st.session_state["chat_log"] = [[]]
if "messages" not in st.session_state:
    st.session_state["messages"] = None
if "agent_chat_history" not in st.session_state:
    st.session_state["agent_chat_history"] = []
if "agent" not in st.session_state:
    st.session_state["agent"] = None

with st.form(key="user_question", clear_on_submit=True):
    # Title and app logo on the same line
    c1, c2 = st.columns((9, 1))
    c1.write("# ChatGLM")
    c1.write("### Chatbot for general conversation.")
    help_bot_icon = (
        f'<img src="{chatgpt_profile_image}" width="50" style="vertical-align:middle">'
    )
    app_log_image = Image.open(os.path.join(path, "logo.png"))
    c2.image(app_log_image)
    conversation_main_container = st.container()
    user_input = st.text_area(
        "", key="user_input", height=20, placeholder="Ask me anything!"
    )
    # Place the buttons on the right
    _, c_clean_btn, c_btn, _ = st.columns([5.2, 1, 1.8, 2])
    send_button = c_btn.form_submit_button(label="Send")
    clean_button = c_clean_btn.form_submit_button(label="Clear")
    if clean_button:
        clean_agent()
    conversation = []
    if send_button:
        if user_input:
            with st.spinner("Thinking..."):
                # Determine which agent to call:
                if agent_selected == "Chat":
                    output, cur_chat_history = chat_with_agent(
                        user_input,
                        temperature=temperature_selected,
                        max_tokens=max_token_selected,
                        chat_history=st.session_state["messages"],
                    )
                    # Update chat history
                    st.session_state["messages"] = cur_chat_history
                    # Update the overall displayed conversation
                    conversation.append({"role": USER_NAME, "content": user_input})
                    conversation.append({"role": AGENT_NAME, "content": output})
                elif agent_selected == "AI Wikipedia Agent":
                    if (
                        "agent" not in st.session_state
                        or st.session_state.agent is None
                    ):
                        st.session_state.agent = init_wiki_agent(
                            index_dir="index/wiki_faiss_2023_03_06",
                            max_token=max_token_selected,
                            temperature=temperature_selected,
                        )
                    output_dict = get_wiki_agent_answer(
                        user_input,
                        st.session_state.agent,
                        chat_history=st.session_state["agent_chat_history"],
                    )
                    output = output_dict["answer"]
                    output_sources = [
                        c.metadata for c in list(output_dict["source_documents"])
                    ]
                    st.session_state["agent_chat_history"].append((user_input, output))
                    conversation.append({"role": USER_NAME, "content": user_input})
                    conversation.append(
                        {
                            "role": AGENT_NAME,
                            "content": output
                            + "\n\n&nbsp;&nbsp;&nbsp;**Sources:** "
                            + dict_to_github_markdown(output_sources, has_section=True),
                        }
                    )
            st.session_state["chat_log"].append(conversation)
    col99, col1 = st.columns([999, 1])
    with col99:
        display_chat_log(conversation_main_container)
    with col1:
        # Scroll to the bottom of the conversation
        scroll_to_element = """
        var element = document.getElementsByClassName('conversation-container')[
            document.getElementsByClassName('conversation-container').length - 1
        ];
        element.scrollIntoView({behavior: 'smooth', block: 'start'});
        """
        javascript(scroll_to_element)


def footer():
    style = """
    <style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    </style>
    """
    myargs = [
        "Made with ChatGLM models, check out the models on ",
        '<img src="https://cdn-icons-png.flaticon.com/512/25/25231.png" width="18" height="18" margin="0em">',
        ' <a href="https://github.com/THUDM/ChatGLM-6B" target="_blank">official repo</a>',
        "!",
        "<br>",
    ]
    st.markdown(style, unsafe_allow_html=True)
    st.markdown(
        '<div style="left: 0; bottom: 0; margin: 0px 0px 0px 0px; width: 100%; text-align: center; height: 30px; opacity: 0.8;">'
        + "".join(myargs)
        + "</div>",
        unsafe_allow_html=True,
    )


footer()
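
For reference, display_chat_log walks st.session_state["chat_log"] as a list of conversations, each a list of role/content dicts. A conversation produced by one Send click looks roughly like this (the message text is illustrative):

# Illustrative shape of st.session_state["chat_log"], as consumed by display_chat_log()
chat_log = [
    [  # one conversation: a user turn followed by the agent's reply
        {"role": "Me", "content": "What is ChatGLM?"},  # USER_NAME
        {"role": "Helpbot", "content": "ChatGLM is a bilingual dialogue model..."},  # AGENT_NAME
    ],
]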