diff --git a/web_demo.py b/web_demo.py
index 88a6dc8..df7f983 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,45 +1,101 @@
from transformers import AutoModel, AutoTokenizer
import gradio as gr
+import mdtex2html
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
-MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+"""Override Chatbot.postprocess"""
-def predict(input, max_length, top_p, temperature, history=None):
- if history is None:
- history = []
+def postprocess(self, y):
+ if y is None:
+ return []
+ for i, (message, response) in enumerate(y):
+ y[i] = (
+ None if message is None else mdtex2html.convert((message)),
+ None if response is None else mdtex2html.convert(response),
+ )
+ return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+ lines = text.split("\n")
+ lines = [line for line in lines if line != ""]
+ count = 0
+ for i, line in enumerate(lines):
+ if "```" in line:
+ count += 1
+ items = line.split('`')
+ if count % 2 == 1:
+                    lines[i] = f'<pre><code class="language-{items[-1]}">'
+ else:
+                    lines[i] = f'<br></code></pre>'
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>"+line
+ text = "".join(lines)
+ return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+ chatbot.append((parse_text(input), ""))
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
temperature=temperature):
- updates = []
- for query, response in history:
- updates.append(gr.update(visible=True, value="用户:" + query))
- updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
- if len(updates) < MAX_BOXES:
- updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
- yield [history] + updates
+ chatbot[-1] = (parse_text(input), parse_text(response))
+
+ yield chatbot, history
+
+
+def reset_user_input():
+ return gr.update(value='')
+
+
+def reset_state():
+ return [], []
with gr.Blocks() as demo:
- state = gr.State([])
- text_boxes = []
- for i in range(MAX_BOXES):
- if i % 2 == 0:
- text_boxes.append(gr.Markdown(visible=False, label="提问:"))
- else:
- text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+ chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
- txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
- container=False)
+ with gr.Column(scale=12):
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+ container=False)
+ with gr.Column(min_width=32, scale=1):
+ submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
+ emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
- button = gr.Button("Generate")
- button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+
+ history = gr.State([])
+
+ submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+ show_progress=True)
+ submitBtn.click(reset_user_input, [], [user_input])
+
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+demo.queue().launch(share=True, inbrowser=True)
diff --git a/web_demo3.py b/web_demo3.py
deleted file mode 100644
index df7f983..0000000
--- a/web_demo3.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from transformers import AutoModel, AutoTokenizer
-import gradio as gr
-import mdtex2html
-
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-model = model.eval()
-
-"""Override Chatbot.postprocess"""
-
-
-def postprocess(self, y):
- if y is None:
- return []
- for i, (message, response) in enumerate(y):
- y[i] = (
- None if message is None else mdtex2html.convert((message)),
- None if response is None else mdtex2html.convert(response),
- )
- return y
-
-
-gr.Chatbot.postprocess = postprocess
-
-
-def parse_text(text):
- """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
- lines = text.split("\n")
- lines = [line for line in lines if line != ""]
- count = 0
- for i, line in enumerate(lines):
- if "```" in line:
- count += 1
- items = line.split('`')
- if count % 2 == 1:
-                    lines[i] = f'<pre><code class="language-{items[-1]}">'
- else:
-                    lines[i] = f'<br></code></pre>'
- else:
- if i > 0:
- if count % 2 == 1:
- line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>"+line
- text = "".join(lines)
- return text
-
-
-def predict(input, chatbot, max_length, top_p, temperature, history):
- chatbot.append((parse_text(input), ""))
- for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
- temperature=temperature):
- chatbot[-1] = (parse_text(input), parse_text(response))
-
- yield chatbot, history
-
-
-def reset_user_input():
- return gr.update(value='')
-
-
-def reset_state():
- return [], []
-
-
-with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
-
- chatbot = gr.Chatbot()
- with gr.Row():
- with gr.Column(scale=4):
- with gr.Column(scale=12):
- user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
- container=False)
- with gr.Column(min_width=32, scale=1):
- submitBtn = gr.Button("Submit", variant="primary")
- with gr.Column(scale=1):
- emptyBtn = gr.Button("Clear History")
- max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
- top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
- temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-
- history = gr.State([])
-
- submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
- show_progress=True)
- submitBtn.click(reset_user_input, [], [user_input])
-
- emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
-
-demo.queue().launch(share=True, inbrowser=True)
diff --git a/web_demo_old.py b/web_demo_old.py
new file mode 100644
index 0000000..88a6dc8
--- /dev/null
+++ b/web_demo_old.py
@@ -0,0 +1,45 @@
+from transformers import AutoModel, AutoTokenizer
+import gradio as gr
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+MAX_TURNS = 20
+MAX_BOXES = MAX_TURNS * 2
+
+
+def predict(input, max_length, top_p, temperature, history=None):
+ if history is None:
+ history = []
+ for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+ temperature=temperature):
+ updates = []
+ for query, response in history:
+ updates.append(gr.update(visible=True, value="用户:" + query))
+ updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
+ if len(updates) < MAX_BOXES:
+ updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+ yield [history] + updates
+
+
+with gr.Blocks() as demo:
+ state = gr.State([])
+ text_boxes = []
+ for i in range(MAX_BOXES):
+ if i % 2 == 0:
+ text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+ else:
+ text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+
+ with gr.Row():
+ with gr.Column(scale=4):
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+ container=False)
+ with gr.Column(scale=1):
+ max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+ top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+ temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+ button = gr.Button("Generate")
+ button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+demo.queue().launch(share=False, inbrowser=True)