mirror of https://github.com/THUDM/ChatGLM-6B
Merge branch 'dev'
commit 69a4c3193f
.gitignore (new file)
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
history/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Mac system file
model/
requirements.txt
@@ -4,3 +4,4 @@ icetk
 cpm_kernels
 torch>=1.10
 gradio
+mdtex2html
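
Note: mdtex2html is the one new dependency; the rewritten web_demo.py below uses it to render Markdown and LaTeX in chat messages. A minimal sketch of the call it relies on, assuming only the mdtex2html.convert(str) -> HTML-string usage visible in the diff:

import mdtex2html

# Convert mixed Markdown/LaTeX to an HTML fragment, as the demo's
# Chatbot.postprocess override does for every message and response.
html = mdtex2html.convert("**bold** and $e^{i\\pi} + 1 = 0$")
print(html)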
web_demo.py
@@ -1,45 +1,101 @@
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr
+import mdtex2html
 
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
 model = model.eval()
 
-MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
-
-
-def predict(input, max_length, top_p, temperature, history=None):
-    if history is None:
-        history = []
+"""Override Chatbot.postprocess"""
+
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert((message)),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>"+line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                                temperature=temperature):
-        updates = []
-        for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
-            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
-        if len(updates) < MAX_BOXES:
-            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
-        yield [history] + updates
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+
+def reset_user_input():
+    return gr.update(value='')
+
+
+def reset_state():
+    return [], []
 
 
 with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
 
+    chatbot = gr.Chatbot()
     with gr.Row():
         with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                    container=False)
+            with gr.Column(min_width=32, scale=1):
+                submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
+            emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-    button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+
+    history = gr.State([])
+
+    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                    show_progress=True)
+    submitBtn.click(reset_user_input, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+demo.queue().launch(share=True, inbrowser=True)
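
For reference, a minimal console sketch of the streaming API both demo scripts are built on, assuming only what the diff shows: model.stream_chat(tokenizer, input, history, ...) is a generator that yields (response, history) pairs, with response growing until generation finishes.

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

history = []
# Each iteration yields the partial response generated so far plus the
# updated conversation history; the Gradio predict() handler re-yields these.
for response, history in model.stream_chat(tokenizer, "你好", history,
                                           max_length=2048, top_p=0.7, temperature=0.95):
    print(response)

Because predict() is a generator, streaming output only works when the app is launched through demo.queue(), which both the old and new scripts do.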
web_demo_old.py (new file)
@@ -0,0 +1,45 @@
from transformers import AutoModel, AutoTokenizer
import gradio as gr

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2


def predict(input, max_length, top_p, temperature, history=None):
    if history is None:
        history = []
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        updates = []
        for query, response in history:
            updates.append(gr.update(visible=True, value="用户:" + query))
            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
        if len(updates) < MAX_BOXES:
            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
        yield [history] + updates


with gr.Blocks() as demo:
    state = gr.State([])
    text_boxes = []
    for i in range(MAX_BOXES):
        if i % 2 == 0:
            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
        else:
            text_boxes.append(gr.Markdown(visible=False, label="回复:"))

    with gr.Row():
        with gr.Column(scale=4):
            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
                container=False)
        with gr.Column(scale=1):
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
    button = gr.Button("Generate")
    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
demo.queue().launch(share=False, inbrowser=True)
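
The old demo kept above takes a different UI approach: instead of a Chatbot component, it pre-allocates MAX_BOXES hidden Markdown boxes and toggles them with gr.update. A minimal sketch of that pattern, assuming Gradio 3.x semantics where a handler can return gr.update(...) property patches in place of plain values:

import gradio as gr

def reveal(text):
    # Returning gr.update(...) patches properties of the bound output
    # component (here visibility and value) without recreating it.
    return gr.update(visible=True, value=text)

with gr.Blocks() as demo:
    box = gr.Markdown(visible=False)
    txt = gr.Textbox(placeholder="Type and press enter")
    txt.submit(reveal, inputs=txt, outputs=box)

demo.launch()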