mirror of https://github.com/THUDM/ChatGLM-6B
				
				
				
			add web_demo3
							parent
							
								
									a1fcd52182
								
							
						
					
					
						commit
						6148d6d6ac
					
				| 
						 | 
				
			
			@ -0,0 +1,133 @@
 | 
			
		|||
# Byte-compiled / optimized / DLL files
 | 
			
		||||
__pycache__/
 | 
			
		||||
*.py[cod]
 | 
			
		||||
*$py.class
 | 
			
		||||
 | 
			
		||||
# C extensions
 | 
			
		||||
*.so
 | 
			
		||||
 | 
			
		||||
# Distribution / packaging
 | 
			
		||||
.Python
 | 
			
		||||
build/
 | 
			
		||||
develop-eggs/
 | 
			
		||||
dist/
 | 
			
		||||
downloads/
 | 
			
		||||
eggs/
 | 
			
		||||
.eggs/
 | 
			
		||||
lib/
 | 
			
		||||
lib64/
 | 
			
		||||
parts/
 | 
			
		||||
sdist/
 | 
			
		||||
var/
 | 
			
		||||
wheels/
 | 
			
		||||
pip-wheel-metadata/
 | 
			
		||||
share/python-wheels/
 | 
			
		||||
*.egg-info/
 | 
			
		||||
.installed.cfg
 | 
			
		||||
*.egg
 | 
			
		||||
MANIFEST
 | 
			
		||||
history/
 | 
			
		||||
 | 
			
		||||
# PyInstaller
 | 
			
		||||
#  Usually these files are written by a python script from a template
 | 
			
		||||
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
 | 
			
		||||
*.manifest
 | 
			
		||||
*.spec
 | 
			
		||||
 | 
			
		||||
# Installer logs
 | 
			
		||||
pip-log.txt
 | 
			
		||||
pip-delete-this-directory.txt
 | 
			
		||||
 | 
			
		||||
# Unit test / coverage reports
 | 
			
		||||
htmlcov/
 | 
			
		||||
.tox/
 | 
			
		||||
.nox/
 | 
			
		||||
.coverage
 | 
			
		||||
.coverage.*
 | 
			
		||||
.cache
 | 
			
		||||
nosetests.xml
 | 
			
		||||
coverage.xml
 | 
			
		||||
*.cover
 | 
			
		||||
*.py,cover
 | 
			
		||||
.hypothesis/
 | 
			
		||||
.pytest_cache/
 | 
			
		||||
 | 
			
		||||
# Translations
 | 
			
		||||
*.mo
 | 
			
		||||
*.pot
 | 
			
		||||
 | 
			
		||||
# Django stuff:
 | 
			
		||||
*.log
 | 
			
		||||
local_settings.py
 | 
			
		||||
db.sqlite3
 | 
			
		||||
db.sqlite3-journal
 | 
			
		||||
 | 
			
		||||
# Flask stuff:
 | 
			
		||||
instance/
 | 
			
		||||
.webassets-cache
 | 
			
		||||
 | 
			
		||||
# Scrapy stuff:
 | 
			
		||||
.scrapy
 | 
			
		||||
 | 
			
		||||
# Sphinx documentation
 | 
			
		||||
docs/_build/
 | 
			
		||||
 | 
			
		||||
# PyBuilder
 | 
			
		||||
target/
 | 
			
		||||
 | 
			
		||||
# Jupyter Notebook
 | 
			
		||||
.ipynb_checkpoints
 | 
			
		||||
 | 
			
		||||
# IPython
 | 
			
		||||
profile_default/
 | 
			
		||||
ipython_config.py
 | 
			
		||||
 | 
			
		||||
# pyenv
 | 
			
		||||
.python-version
 | 
			
		||||
 | 
			
		||||
# pipenv
 | 
			
		||||
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 | 
			
		||||
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
 | 
			
		||||
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
 | 
			
		||||
#   install all needed dependencies.
 | 
			
		||||
#Pipfile.lock
 | 
			
		||||
 | 
			
		||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
 | 
			
		||||
__pypackages__/
 | 
			
		||||
 | 
			
		||||
# Celery stuff
 | 
			
		||||
celerybeat-schedule
 | 
			
		||||
celerybeat.pid
 | 
			
		||||
 | 
			
		||||
# SageMath parsed files
 | 
			
		||||
*.sage.py
 | 
			
		||||
 | 
			
		||||
# Environments
 | 
			
		||||
.env
 | 
			
		||||
.venv
 | 
			
		||||
env/
 | 
			
		||||
venv/
 | 
			
		||||
ENV/
 | 
			
		||||
env.bak/
 | 
			
		||||
venv.bak/
 | 
			
		||||
 | 
			
		||||
# Spyder project settings
 | 
			
		||||
.spyderproject
 | 
			
		||||
.spyproject
 | 
			
		||||
 | 
			
		||||
# Rope project settings
 | 
			
		||||
.ropeproject
 | 
			
		||||
 | 
			
		||||
# mkdocs documentation
 | 
			
		||||
/site
 | 
			
		||||
 | 
			
		||||
# mypy
 | 
			
		||||
.mypy_cache/
 | 
			
		||||
.dmypy.json
 | 
			
		||||
dmypy.json
 | 
			
		||||
 | 
			
		||||
# Pyre type checker
 | 
			
		||||
.pyre/
 | 
			
		||||
 | 
			
		||||
# Mac system file
 | 
			
		||||
model/
 | 
			
		||||
| 
						 | 
				
			
			@ -4,3 +4,4 @@ icetk
 | 
			
		|||
cpm_kernels
 | 
			
		||||
torch>=1.10
 | 
			
		||||
gradio
 | 
			
		||||
mdtex2html
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,69 @@
 | 
			
		|||
from transformers import AutoModel, AutoTokenizer
 | 
			
		||||
import gradio as gr
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
 | 
			
		||||
import mdtex2html
 | 
			
		||||
 | 
			
		||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 | 
			
		||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
 | 
			
		||||
model = model.eval()
 | 
			
		||||
 | 
			
		||||
# MAX_TURNS = 20
 | 
			
		||||
# MAX_BOXES = MAX_TURNS * 2
 | 
			
		||||
 | 
			
		||||
"""Override Chatbot.postprocess"""
 | 
			
		||||
def postprocess(self, y):
 | 
			
		||||
        if y is None:
 | 
			
		||||
            return []
 | 
			
		||||
        for i, (message, response) in enumerate(y):
 | 
			
		||||
            y[i] = (
 | 
			
		||||
                None if message is None else mdtex2html.convert((message)),
 | 
			
		||||
                None if response is None else mdtex2html.convert(response),
 | 
			
		||||
            )
 | 
			
		||||
        return y
 | 
			
		||||
gr.Chatbot.postprocess = postprocess
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(input, chatbot, max_length, top_p, temperature, history):
 | 
			
		||||
    chatbot.append((input, ""))
 | 
			
		||||
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
 | 
			
		||||
                                               temperature=temperature):
 | 
			
		||||
        chatbot[-1] = (input, response)       
 | 
			
		||||
        yield chatbot, history
 | 
			
		||||
 | 
			
		||||
def reset_user_input():
 | 
			
		||||
    return gr.update(value='')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def reset_state():
 | 
			
		||||
    return [], []
 | 
			
		||||
 | 
			
		||||
with gr.Blocks() as demo:
 | 
			
		||||
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
 | 
			
		||||
 | 
			
		||||
    with gr.Row():
 | 
			
		||||
        with gr.Column(scale=4):
 | 
			
		||||
            chatbot = gr.Chatbot()
 | 
			
		||||
            with gr.Row():
 | 
			
		||||
                with gr.Column(scale=12):
 | 
			
		||||
                    user_input = gr.Textbox(show_label=False, placeholder="Input...").style(
 | 
			
		||||
                        container=False)
 | 
			
		||||
                with gr.Column(min_width=32, scale=1):
 | 
			
		||||
                    submitBtn = gr.Button("Submit", variant="primary")
 | 
			
		||||
        with gr.Column(scale=1):
 | 
			
		||||
            emptyBtn = gr.Button("Clear History")
 | 
			
		||||
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
 | 
			
		||||
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
 | 
			
		||||
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
 | 
			
		||||
    
 | 
			
		||||
    history = gr.State([])
 | 
			
		||||
 | 
			
		||||
    user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
 | 
			
		||||
    user_input.submit(reset_user_input, [], [user_input])
 | 
			
		||||
 | 
			
		||||
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
 | 
			
		||||
    submitBtn.click(reset_user_input, [], [user_input])
 | 
			
		||||
 | 
			
		||||
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
demo.queue().launch(share=False, inbrowser=True)
 | 
			
		||||
		Loading…
	
		Reference in New Issue