mirror of https://github.com/THUDM/ChatGLM-6B
PoC with LangChain wrapper and a webapp to chat with vector store
Added a proof of concept LangChain wrapper for ChatGLM model. Added a streamlit based interface to chat with a vector store. ( in this case, OpenAI related wiki pages within two degree of separation, stored in FAISS)pull/216/head
parent
28665ade15
commit
e90276f340
|
@ -0,0 +1,92 @@
|
|||
import os
|
||||
from typing import List, Dict, Tuple, Any
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import os
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.chains import (
|
||||
ChatVectorDBChain,
|
||||
QAWithSourcesChain,
|
||||
VectorDBQAWithSourcesChain,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
# Set up OpenAI API key
|
||||
# This is solely for the purpose of semantic search part of langchain vector search.
|
||||
# Completion is still purely done using ChatGLM model.
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def get_chat_glm():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"THUDM/chatglm-6b-int4", trust_remote_code=True
|
||||
)
|
||||
model = (
|
||||
AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
model = model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def chat_with_agent(user_input, temperature=0.2, max_tokens=800, chat_history=[]):
|
||||
model, tokenizer = get_chat_glm()
|
||||
response, updated_history = model.chat(
|
||||
tokenizer,
|
||||
user_input,
|
||||
history=chat_history,
|
||||
temperature=temperature,
|
||||
max_length=max_tokens,
|
||||
)
|
||||
return response, updated_history
|
||||
|
||||
|
||||
# Langchian related features
|
||||
def init_wiki_agent(
|
||||
index_dir,
|
||||
max_token=800,
|
||||
temperature=0.3,
|
||||
):
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
if index_dir:
|
||||
vectorstore = FAISS.load_local(index_dir, embeddings=embeddings)
|
||||
else:
|
||||
raise ValueError("Need saved vector store location")
|
||||
system_template = """使用以下wikipedia的片段来回答用户的问题。
|
||||
如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息". 不要试图编造答案。
|
||||
----------------
|
||||
{context}"""
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(system_template),
|
||||
HumanMessagePromptTemplate.from_template("{question}"),
|
||||
]
|
||||
prompt = ChatPromptTemplate.from_messages(messages)
|
||||
# qa = ChatVectorDBChain.from_llm(llm=ChatOpenAI(temperature=temperature, max_tokens=max_token),
|
||||
# vectorstore=vectorstore,
|
||||
# qa_prompt=prompt)
|
||||
from chatglm_llm import ChatGLM_G
|
||||
qa = ChatVectorDBChain.from_llm(
|
||||
llm=ChatGLM_G(), vectorstore=vectorstore, qa_prompt=prompt
|
||||
)
|
||||
qa.return_source_documents = True
|
||||
qa.top_k_docs_for_context = 2
|
||||
return qa
|
||||
|
||||
|
||||
def get_wiki_agent_answer(query, qa, chat_history=[]):
|
||||
result = qa({"question": query, "chat_history": chat_history})
|
||||
return result
|
|
@ -0,0 +1,146 @@
|
|||
.image-left {
|
||||
display: inline-block;
|
||||
vertical-align: middle;
|
||||
margin-right: 1em;
|
||||
}
|
||||
.conversation-container {
|
||||
padding: 1em;
|
||||
border-radius: 10px;
|
||||
/* margin-bottom: 1em; */
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.conversation-container.user {
|
||||
background-color: rgba(217, 217, 227, 0.4);
|
||||
}
|
||||
|
||||
.conversation-container.bot {
|
||||
/* background-color: rgba(247,247,248,0.3); */
|
||||
margin-bottom: 0em;
|
||||
}
|
||||
.text-area-input {
|
||||
height: 5em;
|
||||
margin-bottom: 1em;
|
||||
font-size: 1.2rem;
|
||||
}
|
||||
.conversation-scroll {
|
||||
height: 70vh;
|
||||
overflow-y: scroll;
|
||||
}
|
||||
|
||||
[data-testid="stForm"] {
|
||||
/* width: 55vw;
|
||||
max-width: 80wh;
|
||||
margin-left: -15vw; */
|
||||
/* max-height: 70vh;
|
||||
overflow-y: scroll; */
|
||||
width: 70%;
|
||||
margin-left: 15%;
|
||||
}
|
||||
|
||||
[data-testid="stForm"] .stTextArea {
|
||||
box-shadow: 0 5px 6px -4px #c7cdce;
|
||||
width: 60% !important;
|
||||
margin-left: 20% !important;
|
||||
}
|
||||
|
||||
footer {
|
||||
/* margin-left: -20vw; */
|
||||
}
|
||||
|
||||
/* [data-testid="stMarkdownContainer"]:has(.bot){
|
||||
margin-bottom: 0px;
|
||||
margin-top: 0px;
|
||||
} */
|
||||
|
||||
[data-testid="stForm"]
|
||||
[data-testid="stVerticalBlock"]
|
||||
[data-testid="stVerticalBlock"]:has(div.conversation-container) {
|
||||
max-height: 55vh !important;
|
||||
overflow-y: auto !important;
|
||||
overflow-x: hidden;
|
||||
/* font-size: 1.2rem; */
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
[data-testid="stForm"]
|
||||
[data-testid="stVerticalBlock"]
|
||||
[data-testid="stVerticalBlock"]
|
||||
p {
|
||||
/* font-size: 1.2rem; */
|
||||
}
|
||||
|
||||
[data-testid="stForm"] .stImage {
|
||||
/* width: 6rem !important;
|
||||
*/
|
||||
width: 5rem !important;
|
||||
}
|
||||
|
||||
[data-testid="stForm"] .img {
|
||||
/* width: 6rem !important;
|
||||
*/
|
||||
max-width: 85px;
|
||||
}
|
||||
|
||||
@font-face {
|
||||
font-family: "Josefin Slab", serif;
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
}
|
||||
header,
|
||||
h1,
|
||||
h2,
|
||||
h3 [class*="css"] {
|
||||
font-family: "Josefin Slab", serif;
|
||||
font-style: normal;
|
||||
font-weight: 300;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
body {
|
||||
font-size: medium;
|
||||
}
|
||||
|
||||
/* .hhh{
|
||||
color: black;
|
||||
background-color:#fff;
|
||||
}
|
||||
|
||||
.show:hover .hhh{
|
||||
color: white;
|
||||
} */
|
||||
|
||||
[data-testid="stForm"] .stButton {
|
||||
box-shadow: 0 5px 6px -4px #c7cdce;
|
||||
/* width: auto;
|
||||
font-size: 20pt;
|
||||
float: right;
|
||||
box-shadow: rgba(44, 43, 43, 0.5) 0px 0px 0px 0.2rem; */
|
||||
/* width: 60%; */
|
||||
}
|
||||
|
||||
.stButton button {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar:vertical {
|
||||
width: 10px;
|
||||
}
|
||||
|
||||
/* Track */
|
||||
::-webkit-scrollbar-track:vertical {
|
||||
background: #f1f1f1;
|
||||
}
|
||||
|
||||
/* Handle */
|
||||
::-webkit-scrollbar-thumb:vertical {
|
||||
background: #888;
|
||||
}
|
||||
|
||||
/* Handle on hover */
|
||||
::-webkit-scrollbar-thumb:vertical:hover {
|
||||
background: #555;
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
from langchain.llms.base import LLM
|
||||
from typing import Optional, List, Mapping, Any
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
"""ChatGLM_G is a wrapper around the ChatGLM model to fit LangChain framework. May not be an optimal implementation"""
|
||||
|
||||
class ChatGLM_G(LLM):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
|
||||
history = []
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ChatGLM_G"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
response, updated_history = self.model.chat(self.tokenizer, prompt, history=self.history)
|
||||
print("ChatGLM: prompt: ", prompt)
|
||||
print("ChatGLM: response: ", response)
|
||||
if stop is not None:
|
||||
response = enforce_stop_tokens(response, stop)
|
||||
self.history = updated_history
|
||||
return response
|
||||
|
||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
response, updated_history = self.model.chat(self.tokenizer, prompt, history=self.history)
|
||||
print("ChatGLM: prompt: ", prompt)
|
||||
print("ChatGLM: response: ", response)
|
||||
if stop is not None:
|
||||
response = enforce_stop_tokens(response, stop)
|
||||
self.history = updated_history
|
||||
|
||||
return response
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 42 KiB |
|
@ -0,0 +1,305 @@
|
|||
import streamlit as st
|
||||
from chat_backend import chat_with_agent, init_wiki_agent, get_wiki_agent_answer
|
||||
from streamlit.components.v1 import html
|
||||
|
||||
import os
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
import html
|
||||
import uuid
|
||||
|
||||
path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
icon_img = Image.open(os.path.join(path, "logo.png"))
|
||||
|
||||
USER_NAME = "Me"
|
||||
AGENT_NAME = "Helpbot"
|
||||
|
||||
|
||||
st.set_page_config(
|
||||
page_title="ChatGLM",
|
||||
page_icon=icon_img,
|
||||
layout="wide",
|
||||
# initial_sidebar_state="collapsed",
|
||||
)
|
||||
|
||||
st.write(
|
||||
"<style>div.block-container{padding-top:1rem;}</style>", unsafe_allow_html=True
|
||||
)
|
||||
|
||||
|
||||
def local_css(file_name):
|
||||
with open(file_name) as f:
|
||||
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
||||
|
||||
|
||||
def remote_css(url):
|
||||
st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
|
||||
|
||||
|
||||
def icon(icon_name):
|
||||
st.markdown(f'<i class="material-icons">{icon_name}</i>', unsafe_allow_html=True)
|
||||
|
||||
|
||||
def javascript(source: str) -> None:
|
||||
"""loading javascript correctly"""
|
||||
div_id = uuid.uuid4()
|
||||
|
||||
st.markdown(
|
||||
f"""
|
||||
<div style="display:none" id="{div_id}">
|
||||
<iframe src="javascript: \
|
||||
var script = document.createElement('script'); \
|
||||
script.type = 'text/javascript'; \
|
||||
script.text = {html.escape(repr(source))}; \
|
||||
var div = window.parent.document.getElementById('{div_id}'); \
|
||||
div.appendChild(script); \
|
||||
div.parentElement.parentElement.parentElement.style.display = 'none'; \
|
||||
"/>
|
||||
</div>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
|
||||
local_css("chat_style.css")
|
||||
|
||||
|
||||
st.markdown(
|
||||
"""
|
||||
<link href="https://fonts.googleapis.com/css2?family=Josefin+Slab&display=swap" rel="stylesheet">
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
# User Input and Send button
|
||||
user_profile_image = "https://img1.baidu.com/it/u=3150659458,3834452201&fm=253&fmt=auto&app=138&f=JPEG?w=369&h=378"
|
||||
|
||||
chatgpt_profile_image = "https://avatars.githubusercontent.com/u/44095251?v=4"
|
||||
|
||||
|
||||
def display_chat_log(cur_container):
|
||||
for cur_conversation in st.session_state["chat_log"]:
|
||||
for msg in cur_conversation:
|
||||
if msg["role"] == USER_NAME:
|
||||
cur_container.markdown(
|
||||
"<div class=' conversation-container user'><img src='{}' class='image-left' width='50'><br> {} </div>".format(
|
||||
user_profile_image, html.escape(msg["content"])
|
||||
),
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
else:
|
||||
cur_container.markdown(
|
||||
"<div class='conversation-container bot'><img src='{}' class='image-left' width='50'></div>".format(
|
||||
chatgpt_profile_image
|
||||
),
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
cur_container.markdown(
|
||||
f"{' '+msg['content']}", unsafe_allow_html=True
|
||||
)
|
||||
|
||||
|
||||
def dict_to_github_markdown(data, has_section=False):
|
||||
wiki_logo = "https://upload.wikimedia.org/wikipedia/en/thumb/8/80/Wikipedia-logo-v2.svg/1200px-Wikipedia-logo-v2.svg.png"
|
||||
slack_logo = (
|
||||
"https://cdn.freebiesupply.com/logos/large/2x/slack-1-logo-png-transparent.png"
|
||||
)
|
||||
markdown = ""
|
||||
for item in data:
|
||||
title = item["title"]
|
||||
url = item["url"]
|
||||
if has_section:
|
||||
section = item["section"]
|
||||
title_text_and_section = f"{title} - {section}"
|
||||
else:
|
||||
title_text_and_section = title
|
||||
if "wikipedia" in url:
|
||||
logo = wiki_logo
|
||||
elif "slack" in url:
|
||||
logo = slack_logo
|
||||
else:
|
||||
logo = ""
|
||||
if len(title_text_and_section) > 50:
|
||||
title_text_and_section = title_text_and_section[:50] + "..."
|
||||
hyperlink = f"[{title_text_and_section}]({url})"
|
||||
if logo:
|
||||
markdown += (
|
||||
f" <img src='{logo}' width='20' height='20'> {hyperlink} "
|
||||
)
|
||||
else:
|
||||
markdown += f" {hyperlink}"
|
||||
return markdown
|
||||
|
||||
|
||||
def clean_agent():
|
||||
st.session_state["chat_log"] = [[]]
|
||||
st.session_state["messages"] = None
|
||||
st.session_state["agent"] = None
|
||||
st.session_state["agent_chat_history"] = []
|
||||
|
||||
|
||||
if "bot_desc" not in st.session_state:
|
||||
st.session_state["bot_desc"] = "ChatGLM with Vectorstore."
|
||||
|
||||
|
||||
# Sidebar
|
||||
st.sidebar.subheader("Model Settings")
|
||||
agent_selected = st.sidebar.selectbox(
|
||||
label="Agent",
|
||||
options=["Chat", "AI Wikipedia Agent"],
|
||||
index=0,
|
||||
on_change=clean_agent,
|
||||
help="Select the agent to chat with.\n\nChat: General conversational Chatbot based on ChatGLM.\n\AI Wikipedia Agent: Chat with knowlegebase. In this case, wikipedia articles within 2 degree of separation to OpenAI.",
|
||||
)
|
||||
max_token_selected = st.sidebar.slider(
|
||||
label="Model Max Output Length",
|
||||
min_value=50,
|
||||
max_value=4500,
|
||||
value=500,
|
||||
step=50,
|
||||
help="The maximum number of tokens to generate. Requests can use up to 2,048 or 4,000 tokens shared between prompt and completion. The exact limit varies by model. (One token is roughly 4 characters for normal English text)",
|
||||
)
|
||||
tempature_selected = st.sidebar.number_input(
|
||||
label="Model Tempature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.2,
|
||||
step=0.1,
|
||||
help="Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.",
|
||||
)
|
||||
|
||||
# Dynamic conversation display
|
||||
if "chat_log" not in st.session_state:
|
||||
st.session_state["chat_log"] = [[]]
|
||||
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = None
|
||||
if "agent_chat_history" not in st.session_state:
|
||||
st.session_state["agent_chat_history"] = []
|
||||
|
||||
if "agent" not in st.session_state:
|
||||
st.session_state["agent"] = None
|
||||
|
||||
|
||||
with st.form(key="user_question", clear_on_submit=True):
|
||||
|
||||
# Title and Image in same line
|
||||
# Use user chatgpt profile image
|
||||
c1, c2 = st.columns((9, 1))
|
||||
c1.write("# ChatGLM")
|
||||
c1.write(f"### Chatbot for general conversation.")
|
||||
|
||||
help_bot_icon = (
|
||||
f'<img src="{chatgpt_profile_image}" width="50" style="vertical-align:middle">'
|
||||
)
|
||||
app_log_image = Image.open("logo.png")
|
||||
c2.image(app_log_image)
|
||||
|
||||
conversation_main_container = st.container()
|
||||
|
||||
user_input = st.text_area(
|
||||
"", key="user_input", height=20, placeholder="Ask me anything!"
|
||||
)
|
||||
# set button on the right
|
||||
_, c_clean_btn, c_btn, _ = st.columns([5.2, 1, 1.8, 2])
|
||||
send_button = c_btn.form_submit_button(label="Send")
|
||||
clean_button = c_clean_btn.form_submit_button(label="Clear")
|
||||
if clean_button:
|
||||
clean_agent()
|
||||
|
||||
conversation = []
|
||||
if send_button:
|
||||
if user_input:
|
||||
with st.spinner("Thinking..."):
|
||||
# Determin which agent to call:
|
||||
if agent_selected == "Chat":
|
||||
|
||||
output, cur_chat_history = chat_with_agent(
|
||||
user_input,
|
||||
temperature=tempature_selected,
|
||||
max_tokens=max_token_selected,
|
||||
chat_history=st.session_state["messages"],
|
||||
)
|
||||
|
||||
# Update chat history
|
||||
st.session_state["messages"] = cur_chat_history
|
||||
# Update overall displayed conversations
|
||||
conversation.append({"role": USER_NAME, "content": user_input})
|
||||
conversation.append({"role": AGENT_NAME, "content": output})
|
||||
elif agent_selected == "AI Wikipedia Agent":
|
||||
if (
|
||||
"agent" not in st.session_state
|
||||
or st.session_state.agent is None
|
||||
):
|
||||
st.session_state.agent = init_wiki_agent(
|
||||
index_dir="index/wiki_faiss_2023_03_06",
|
||||
max_token=max_token_selected,
|
||||
temperature=tempature_selected,
|
||||
)
|
||||
output_dict = get_wiki_agent_answer(
|
||||
user_input,
|
||||
st.session_state.agent,
|
||||
chat_history=st.session_state["agent_chat_history"],
|
||||
)
|
||||
output = output_dict["answer"]
|
||||
|
||||
output_sources = [
|
||||
c.metadata for c in list(output_dict["source_documents"])
|
||||
]
|
||||
|
||||
st.session_state["agent_chat_history"].append((user_input, output))
|
||||
|
||||
conversation.append({"role": USER_NAME, "content": user_input})
|
||||
conversation.append(
|
||||
{
|
||||
"role": AGENT_NAME,
|
||||
"content": output
|
||||
+ "\n\n **Sources:** "
|
||||
+ dict_to_github_markdown(output_sources, has_section=True),
|
||||
}
|
||||
)
|
||||
|
||||
st.session_state["chat_log"].append(conversation)
|
||||
col99, col1 = st.columns([999, 1])
|
||||
with col99:
|
||||
display_chat_log(conversation_main_container)
|
||||
with col1:
|
||||
|
||||
# Scroll to bottom of conversation
|
||||
scroll_to_element = """
|
||||
var element = document.getElementsByClassName('conversation-container')[
|
||||
document.getElementsByClassName('conversation-container').length - 1
|
||||
];
|
||||
element.scrollIntoView({behavior: 'smooth', block: 'start'});
|
||||
"""
|
||||
javascript(scroll_to_element)
|
||||
|
||||
|
||||
def footer():
|
||||
style = """
|
||||
<style>
|
||||
# MainMenu {visibility: hidden;}
|
||||
footer {visibility: hidden;}
|
||||
</style>
|
||||
"""
|
||||
|
||||
myargs = [
|
||||
"Made with ChatGLM models, check out the models on ",
|
||||
'<img src="https://cdn-icons-png.flaticon.com/512/25/25231.png" width="18" height="18" margin="0em">',
|
||||
' <a href="https://github.com/THUDM/ChatGLM-6B" target="_blank">official repo</a>',
|
||||
"!",
|
||||
"<br>",
|
||||
]
|
||||
|
||||
st.markdown(style, unsafe_allow_html=True)
|
||||
st.markdown(
|
||||
'<div style="left: 0; bottom: 0; margin: 0px 0px 0px 0px; width: 100%; text-align: center; height: 30px; opacity: 0.8;">'
|
||||
+ "".join(myargs)
|
||||
+ "</div>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
|
||||
footer()
|
Loading…
Reference in New Issue