diff --git a/example_with_langchain_and_vectorstore/chat_backend.py b/example_with_langchain_and_vectorstore/chat_backend.py
new file mode 100644
index 0000000..747d352
--- /dev/null
+++ b/example_with_langchain_and_vectorstore/chat_backend.py
@@ -0,0 +1,94 @@
+import os
+from typing import List, Dict, Tuple, Any
+import streamlit as st
+import pandas as pd
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.vectorstores import Chroma
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.chains import (
+    ChatVectorDBChain,
+    QAWithSourcesChain,
+    VectorDBQAWithSourcesChain,
+)
+from langchain.docstore.document import Document
+from langchain.vectorstores.faiss import FAISS
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    AIMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+from transformers import AutoTokenizer, AutoModel
+
+# Set up the OpenAI API key.
+# The key is needed only for the semantic-search (embedding) part of the
+# LangChain vector search; completion is still done entirely by the ChatGLM model.
+os.environ["OPENAI_API_KEY"] = ""
+
+
+@st.cache_resource()
+def get_chat_glm():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b-int4", trust_remote_code=True
+    )
+    model = (
+        AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+        .half()
+        .cuda()
+    )
+    model = model.eval()
+    return model, tokenizer
+
+
+def chat_with_agent(user_input, temperature=0.2, max_tokens=800, chat_history=[]):
+    model, tokenizer = get_chat_glm()
+    response, updated_history = model.chat(
+        tokenizer,
+        user_input,
+        history=chat_history,
+        temperature=temperature,
+        max_length=max_tokens,
+    )
+    return response, updated_history
+
+
+# LangChain-related features
+def init_wiki_agent(
+    index_dir,
+    max_token=800,
+    temperature=0.3,
+):
+    embeddings = OpenAIEmbeddings()
+    if index_dir:
+        vectorstore = FAISS.load_local(index_dir, embeddings=embeddings)
+    else:
+        raise ValueError("Need a saved vector store location")
+    # System prompt (in Chinese): "Use the following Wikipedia snippets to answer
+    # the user's question. If the answer cannot be found in them, say 'I don't
+    # know' or 'There is not enough relevant information'. Do not make up an answer."
+    system_template = """使用以下wikipedia的片段来回答用户的问题。
+如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息".
+不要试图编造答案。
+----------------
+{context}"""
+    messages = [
+        SystemMessagePromptTemplate.from_template(system_template),
+        HumanMessagePromptTemplate.from_template("{question}"),
+    ]
+    prompt = ChatPromptTemplate.from_messages(messages)
+    # qa = ChatVectorDBChain.from_llm(llm=ChatOpenAI(temperature=temperature, max_tokens=max_token),
+    #                                 vectorstore=vectorstore,
+    #                                 qa_prompt=prompt)
+    from chatglm_llm import ChatGLM_G
+    qa = ChatVectorDBChain.from_llm(
+        llm=ChatGLM_G(), vectorstore=vectorstore, qa_prompt=prompt
+    )
+    qa.return_source_documents = True
+    qa.top_k_docs_for_context = 2
+    return qa
+
+
+def get_wiki_agent_answer(query, qa, chat_history=[]):
+    result = qa({"question": query, "chat_history": chat_history})
+    return result
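For reference, a minimal usage sketch of the two entry points above (illustrative, not part of the committed files); the question is made up, the index path is the FAISS index added in this change, and a valid OPENAI_API_KEY is assumed for the embedding lookup:

    from chat_backend import init_wiki_agent, get_wiki_agent_answer

    # Load the bundled FAISS index and build the retrieval chain.
    qa = init_wiki_agent(index_dir="index/wiki_faiss_2023_03_06")

    # Ask one question with an empty conversation history.
    result = get_wiki_agent_answer("What is machine learning?", qa, chat_history=[])
    print(result["answer"])            # completion produced by ChatGLM
    print(result["source_documents"])  # present because return_source_documents=True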
diff --git a/example_with_langchain_and_vectorstore/chat_style.css b/example_with_langchain_and_vectorstore/chat_style.css
new file mode 100644
index 0000000..109208c
--- /dev/null
+++ b/example_with_langchain_and_vectorstore/chat_style.css
@@ -0,0 +1,146 @@
+.image-left {
+  display: inline-block;
+  vertical-align: middle;
+  margin-right: 1em;
+}
+.conversation-container {
+  padding: 1em;
+  border-radius: 10px;
+  /* margin-bottom: 1em; */
+  margin-top: 1em;
+}
+
+.conversation-container.user {
+  background-color: rgba(217, 217, 227, 0.4);
+}
+
+.conversation-container.bot {
+  /* background-color: rgba(247,247,248,0.3); */
+  margin-bottom: 0em;
+}
+.text-area-input {
+  height: 5em;
+  margin-bottom: 1em;
+  font-size: 1.2rem;
+}
+.conversation-scroll {
+  height: 70vh;
+  overflow-y: scroll;
+}
+
+[data-testid="stForm"] {
+  /* width: 55vw;
+max-width: 80wh;
+margin-left: -15vw; */
+  /* max-height: 70vh;
+  overflow-y: scroll; */
+  width: 70%;
+  margin-left: 15%;
+}
+
+[data-testid="stForm"] .stTextArea {
+  box-shadow: 0 5px 6px -4px #c7cdce;
+  width: 60% !important;
+  margin-left: 20% !important;
+}
+
+footer {
+  /* margin-left: -20vw; */
+}
+
+/* [data-testid="stMarkdownContainer"]:has(.bot){
+  margin-bottom: 0px;
+  margin-top: 0px;
+} */
+
+[data-testid="stForm"]
+  [data-testid="stVerticalBlock"]
+  [data-testid="stVerticalBlock"]:has(div.conversation-container) {
+  max-height: 55vh !important;
+  overflow-y: auto !important;
+  overflow-x: hidden;
+  /* font-size: 1.2rem; */
+  margin-right: 5px;
+}
+
+[data-testid="stForm"]
+  [data-testid="stVerticalBlock"]
+  [data-testid="stVerticalBlock"]
+  p {
+  /* font-size: 1.2rem; */
+}
+
+[data-testid="stForm"] .stImage {
+  /* width: 6rem !important;
+  */
+  width: 5rem !important;
+}
+
+[data-testid="stForm"] .img {
+  /* width: 6rem !important;
+  */
+  max-width: 85px;
+}
+
+@font-face {
+  font-family: "Josefin Slab", serif;
+  font-style: normal;
+  font-weight: 300;
+}
+header,
+h1,
+h2,
+h3 [class*="css"] {
+  font-family: "Josefin Slab", serif;
+  font-style: normal;
+  font-weight: 300;
+}
+
+h1 {
+  font-weight: bold;
+}
+
+body {
+  font-size: medium;
+}
+
+/* .hhh{
+  color: black;
+  background-color:#fff;
+}
+
+.show:hover .hhh{
+  color: white;
+} */
+
+[data-testid="stForm"] .stButton {
+  box-shadow: 0 5px 6px -4px #c7cdce;
+  /* width: auto;
+  font-size: 20pt;
+  float: right;
+  box-shadow: rgba(44, 43, 43, 0.5) 0px 0px 0px 0.2rem; */
+  /* width: 60%; */
+}
+
+.stButton button {
+  width: 100%;
+}
+
+::-webkit-scrollbar:vertical {
+  width: 10px;
+}
+
+/* Track */
+::-webkit-scrollbar-track:vertical {
+  background: #f1f1f1;
+}
+
+/* Handle */
+::-webkit-scrollbar-thumb:vertical {
+  background: #888;
+}
+
+/* Handle on hover */
+::-webkit-scrollbar-thumb:vertical:hover {
+  background: #555;
+}
diff --git a/example_with_langchain_and_vectorstore/chatglm_llm.py b/example_with_langchain_and_vectorstore/chatglm_llm.py
new file mode 100644
index 0000000..d7943f6
--- /dev/null
+++ b/example_with_langchain_and_vectorstore/chatglm_llm.py
@@ -0,0 +1,30 @@
+from typing import Optional, List
+
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+from transformers import AutoTokenizer, AutoModel
+
+
+class ChatGLM_G(LLM):
+    """Wrapper around the ChatGLM model to fit the LangChain framework.
+
+    May not be an optimal implementation.
+    """
+
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
+    history = []
+
+    @property
+    def _llm_type(self) -> str:
+        return "ChatGLM_G"
+
+    # LLM.__call__ already delegates to _call, so one implementation is enough.
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        response, updated_history = self.model.chat(self.tokenizer, prompt, history=self.history)
+        print("ChatGLM: prompt: ", prompt)
+        print("ChatGLM: response: ", response)
+        if stop is not None:
+            response = enforce_stop_tokens(response, stop)
+        self.history = updated_history
+        return response
diff --git a/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.faiss b/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.faiss
new file mode 100644
index 0000000..6558696
Binary files /dev/null and b/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.faiss differ
diff --git a/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.pkl b/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.pkl
new file mode 100644
index 0000000..3191ae8
Binary files /dev/null and b/example_with_langchain_and_vectorstore/index/wiki_faiss_2023_03_06/index.pkl differ
diff --git a/example_with_langchain_and_vectorstore/logo.png b/example_with_langchain_and_vectorstore/logo.png
new file mode 100644
index 0000000..63af50a
Binary files /dev/null and b/example_with_langchain_and_vectorstore/logo.png differ
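Before the web-app wiring below, a quick sanity check: a sketch (illustrative, not part of the committed files) of the ChatGLM_G wrapper from chatglm_llm.py used on its own, outside any chain. The prompt and stop sequence are made up:

    from chatglm_llm import ChatGLM_G

    llm = ChatGLM_G()
    # enforce_stop_tokens cuts the reply at the first stop sequence.
    reply = llm("用一句话介绍长城。", stop=["\n\n"])
    print(reply)

Because the wrapper keeps its running chat history on the class, successive calls, and every chain that shares ChatGLM_G, continue the same conversation; assigning ChatGLM_G.history = [] starts a fresh context.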
diff --git a/example_with_langchain_and_vectorstore/webapp_with_vectorstore.py b/example_with_langchain_and_vectorstore/webapp_with_vectorstore.py
new file mode 100644
index 0000000..f70fc43
--- /dev/null
+++ b/example_with_langchain_and_vectorstore/webapp_with_vectorstore.py
@@ -0,0 +1,305 @@
+import streamlit as st
+from chat_backend import chat_with_agent, init_wiki_agent, get_wiki_agent_answer
+
+import os
+from PIL import Image
+import html  # stdlib html, used by javascript() below for escaping
+import uuid
+
+path = os.path.dirname(__file__)
+
+
+icon_img = Image.open(os.path.join(path, "logo.png"))
+
+USER_NAME = "Me"
+AGENT_NAME = "Helpbot"
+
+
+st.set_page_config(
+    page_title="ChatGLM",
+    page_icon=icon_img,
+    layout="wide",
+    # initial_sidebar_state="collapsed",
+)
+
+st.write(
+    "", unsafe_allow_html=True
+)
+
+
+def local_css(file_name):
+    with open(file_name) as f:
+        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
+
+
+def remote_css(url):
+    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
+
+
+def icon(icon_name):
+    st.markdown(f'<i class="material-icons">{icon_name}</i>', unsafe_allow_html=True)
+
+
+def javascript(source: str) -> None:
+    """loading javascript correctly"""
+    div_id = uuid.uuid4()
+
+    st.markdown(
+        f"""
+        <div style="display:none" id="{div_id}">
+            <iframe src="javascript: \
+                var script = document.createElement('script'); \
+                script.type = 'text/javascript'; \
+                script.text = {html.escape(repr(source))}; \
+                var div = window.parent.document.getElementById('{div_id}'); \
+                div.appendChild(script); \
+                div.parentElement.parentElement.parentElement.style.display = 'none'; \
+            "/>
+        </div>
+ """, + unsafe_allow_html=True, + ) + + +local_css("chat_style.css") + + +st.markdown( + """ + + """, + unsafe_allow_html=True, +) + +# User Input and Send button +user_profile_image = "https://img1.baidu.com/it/u=3150659458,3834452201&fm=253&fmt=auto&app=138&f=JPEG?w=369&h=378" + +chatgpt_profile_image = "https://avatars.githubusercontent.com/u/44095251?v=4" + + +def display_chat_log(cur_container): + for cur_conversation in st.session_state["chat_log"]: + for msg in cur_conversation: + if msg["role"] == USER_NAME: + cur_container.markdown( + "