From 6148d6d6ac41a416846df1ebfa9b1341a1f69859 Mon Sep 17 00:00:00 2001 From: tuteng0915 Date: Mon, 3 Apr 2023 23:11:31 +0800 Subject: [PATCH 01/12] add web_demo3 --- .gitignore | 133 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + web_demo3.py | 69 ++++++++++++++++++++++++ 3 files changed, 203 insertions(+) create mode 100644 .gitignore create mode 100644 web_demo3.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c3dd476 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +history/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Mac system file +model/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 00707fe..072d12c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ icetk cpm_kernels torch>=1.10 gradio +mdtex2html \ No newline at end of file diff --git a/web_demo3.py b/web_demo3.py new file mode 100644 index 0000000..d6a62ec --- /dev/null +++ b/web_demo3.py @@ -0,0 +1,69 @@ +from transformers import AutoModel, AutoTokenizer +import gradio as gr +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type +import mdtex2html + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda() +model = model.eval() + +# MAX_TURNS = 20 +# MAX_BOXES = MAX_TURNS * 2 + +"""Override Chatbot.postprocess""" +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y +gr.Chatbot.postprocess = postprocess + + +def predict(input, chatbot, max_length, top_p, temperature, history): + chatbot.append((input, "")) + for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, + temperature=temperature): + chatbot[-1] = (input, response) + yield chatbot, history + +def reset_user_input(): + return gr.update(value='') + + +def reset_state(): + return [], [] + +with gr.Blocks() as demo: + gr.HTML("""

<h1 align="center">ChatGLM</h1>
""") + + with gr.Row(): + with gr.Column(scale=4): + chatbot = gr.Chatbot() + with gr.Row(): + with gr.Column(scale=12): + user_input = gr.Textbox(show_label=False, placeholder="Input...").style( + container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") + with gr.Column(scale=1): + emptyBtn = gr.Button("Clear History") + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) + + history = gr.State([]) + + user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) + user_input.submit(reset_user_input, [], [user_input]) + + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) + submitBtn.click(reset_user_input, [], [user_input]) + + emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) + + +demo.queue().launch(share=False, inbrowser=True) From ec069419becceac2d69a07149c07b9fc564e19db Mon Sep 17 00:00:00 2001 From: duzx16 Date: Mon, 3 Apr 2023 23:29:04 +0800 Subject: [PATCH 02/12] Add another web demo with Gradio --- web_demo3.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/web_demo3.py b/web_demo3.py index d6a62ec..203ba1f 100644 --- a/web_demo3.py +++ b/web_demo3.py @@ -4,22 +4,23 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type import mdtex2html tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda() +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() -# MAX_TURNS = 20 -# MAX_BOXES = MAX_TURNS * 2 - """Override Chatbot.postprocess""" + + def postprocess(self, y): - if y is None: - return [] - for i, (message, response) in enumerate(y): - y[i] = ( - None if message is None else mdtex2html.convert((message)), - None if response is None else mdtex2html.convert(response), - ) - return y + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y + + gr.Chatbot.postprocess = postprocess @@ -27,9 +28,10 @@ def predict(input, chatbot, max_length, top_p, temperature, history): chatbot.append((input, "")) for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - chatbot[-1] = (input, response) + chatbot[-1] = (input, response) yield chatbot, history + def reset_user_input(): return gr.update(value='') @@ -37,6 +39,7 @@ def reset_user_input(): def reset_state(): return [], [] + with gr.Blocks() as demo: gr.HTML("""

<h1 align="center">ChatGLM</h1>
""") @@ -54,16 +57,17 @@ with gr.Blocks() as demo: max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - + history = gr.State([]) - user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) + user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) user_input.submit(reset_user_input, [], [user_input]) - submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) submitBtn.click(reset_user_input, [], [user_input]) emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) - -demo.queue().launch(share=False, inbrowser=True) +demo.queue().launch(share=True, inbrowser=True) From 119caa15ef98de6faf3c66e82fa900f9b21b505c Mon Sep 17 00:00:00 2001 From: tuteng0915 Date: Mon, 3 Apr 2023 23:31:30 +0800 Subject: [PATCH 03/12] add parse_text --- web_demo3.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/web_demo3.py b/web_demo3.py index d6a62ec..80ffce9 100644 --- a/web_demo3.py +++ b/web_demo3.py @@ -23,11 +23,44 @@ def postprocess(self, y): gr.Chatbot.postprocess = postprocess +def parse_text(text): + """revise from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + lines = text.split("\n") + lines = [line for line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split('`') + if count % 2 == 1: + lines[i] = f'
<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>
"+line + text = "".join(lines) + return text + + def predict(input, chatbot, max_length, top_p, temperature, history): - chatbot.append((input, "")) + chatbot.append((parse_text(input), "")) for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - chatbot[-1] = (input, response) + chatbot[-1] = (parse_text(input), parse_text(response)) yield chatbot, history def reset_user_input(): From d21f891a76e9df2da7d3e0f6e5c5d28ef1dde337 Mon Sep 17 00:00:00 2001 From: tuteng0915 Date: Mon, 3 Apr 2023 23:36:18 +0800 Subject: [PATCH 04/12] add parse_text --- web_demo3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web_demo3.py b/web_demo3.py index 80ffce9..ad5ba11 100644 --- a/web_demo3.py +++ b/web_demo3.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type import mdtex2html tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda() +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() # MAX_TURNS = 20 @@ -24,7 +24,7 @@ gr.Chatbot.postprocess = postprocess def parse_text(text): - """revise from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 From cc4be399ff1a88f5459b3b8793b83c4372409517 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Thu, 6 Apr 2023 16:58:40 +0800 Subject: [PATCH 05/12] Update web demo3 --- web_demo3.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/web_demo3.py b/web_demo3.py index 7c0777c..0e39968 100644 --- a/web_demo3.py +++ b/web_demo3.py @@ -1,10 +1,9 @@ from transformers import AutoModel, AutoTokenizer import gradio as gr -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type import mdtex2html -tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +tokenizer = AutoTokenizer.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True) +model = AutoModel.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True).half().cuda() model = model.eval() """Override Chatbot.postprocess""" @@ -77,15 +76,14 @@ def reset_state(): with gr.Blocks() as demo: gr.HTML("""

<h1 align="center">ChatGLM</h1>
""") + chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): - chatbot = gr.Chatbot() - with gr.Row(): - with gr.Column(scale=12): - user_input = gr.Textbox(show_label=False, placeholder="Input...").style( - container=False) - with gr.Column(min_width=32, scale=1): - submitBtn = gr.Button("Submit", variant="primary") + with gr.Column(scale=12): + user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( + container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") with gr.Column(scale=1): emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) @@ -94,10 +92,6 @@ with gr.Blocks() as demo: history = gr.State([]) - user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], - show_progress=True) - user_input.submit(reset_user_input, [], [user_input]) - submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) submitBtn.click(reset_user_input, [], [user_input]) From 40d83f32feb6fcbec54ab8c8479a4830378edb3e Mon Sep 17 00:00:00 2001 From: duzx16 Date: Thu, 6 Apr 2023 17:00:51 +0800 Subject: [PATCH 06/12] Update model path --- web_demo3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web_demo3.py b/web_demo3.py index 0e39968..df7f983 100644 --- a/web_demo3.py +++ b/web_demo3.py @@ -2,8 +2,8 @@ from transformers import AutoModel, AutoTokenizer import gradio as gr import mdtex2html -tokenizer = AutoTokenizer.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True) -model = AutoModel.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True).half().cuda() +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() """Override Chatbot.postprocess""" From 28335463394983ddfc7b554f8fd3ee894a7b98b5 Mon Sep 17 00:00:00 2001 From: duzx16 Date: Thu, 6 Apr 2023 17:01:24 +0800 Subject: [PATCH 07/12] Use chatbot web demo --- web_demo.py | 104 +++++++++++++++++++++++++++++++++++++----------- web_demo3.py | 101 ---------------------------------------------- web_demo_old.py | 45 +++++++++++++++++++++ 3 files changed, 125 insertions(+), 125 deletions(-) delete mode 100644 web_demo3.py create mode 100644 web_demo_old.py diff --git a/web_demo.py b/web_demo.py index 88a6dc8..df7f983 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,45 +1,101 @@ from transformers import AutoModel, AutoTokenizer import gradio as gr +import mdtex2html tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = model.eval() -MAX_TURNS = 20 -MAX_BOXES = MAX_TURNS * 2 +"""Override Chatbot.postprocess""" -def predict(input, max_length, top_p, temperature, history=None): - if history is None: - history = [] +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert((message)), + None if response is None else mdtex2html.convert(response), + ) + return y + + +gr.Chatbot.postprocess = postprocess + + +def parse_text(text): + """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" + lines = text.split("\n") + lines = [line for 
line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split('`') + if count % 2 == 1: + lines[i] = f'
<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>
"+line + text = "".join(lines) + return text + + +def predict(input, chatbot, max_length, top_p, temperature, history): + chatbot.append((parse_text(input), "")) for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, temperature=temperature): - updates = [] - for query, response in history: - updates.append(gr.update(visible=True, value="用户:" + query)) - updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) - if len(updates) < MAX_BOXES: - updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) - yield [history] + updates + chatbot[-1] = (parse_text(input), parse_text(response)) + + yield chatbot, history + + +def reset_user_input(): + return gr.update(value='') + + +def reset_state(): + return [], [] with gr.Blocks() as demo: - state = gr.State([]) - text_boxes = [] - for i in range(MAX_BOXES): - if i % 2 == 0: - text_boxes.append(gr.Markdown(visible=False, label="提问:")) - else: - text_boxes.append(gr.Markdown(visible=False, label="回复:")) + gr.HTML("""

<h1 align="center">ChatGLM</h1>
""") + chatbot = gr.Chatbot() with gr.Row(): with gr.Column(scale=4): - txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( - container=False) + with gr.Column(scale=12): + user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( + container=False) + with gr.Column(min_width=32, scale=1): + submitBtn = gr.Button("Submit", variant="primary") with gr.Column(scale=1): + emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - button = gr.Button("Generate") - button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) -demo.queue().launch(share=False, inbrowser=True) + + history = gr.State([]) + + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) + submitBtn.click(reset_user_input, [], [user_input]) + + emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) + +demo.queue().launch(share=True, inbrowser=True) diff --git a/web_demo3.py b/web_demo3.py deleted file mode 100644 index df7f983..0000000 --- a/web_demo3.py +++ /dev/null @@ -1,101 +0,0 @@ -from transformers import AutoModel, AutoTokenizer -import gradio as gr -import mdtex2html - -tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() -model = model.eval() - -"""Override Chatbot.postprocess""" - - -def postprocess(self, y): - if y is None: - return [] - for i, (message, response) in enumerate(y): - y[i] = ( - None if message is None else mdtex2html.convert((message)), - None if response is None else mdtex2html.convert(response), - ) - return y - - -gr.Chatbot.postprocess = postprocess - - -def parse_text(text): - """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" - lines = text.split("\n") - lines = [line for line in lines if line != ""] - count = 0 - for i, line in enumerate(lines): - if "```" in line: - count += 1 - items = line.split('`') - if count % 2 == 1: - lines[i] = f'
<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = f'<br></code></pre>'
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>
"+line - text = "".join(lines) - return text - - -def predict(input, chatbot, max_length, top_p, temperature, history): - chatbot.append((parse_text(input), "")) - for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, - temperature=temperature): - chatbot[-1] = (parse_text(input), parse_text(response)) - - yield chatbot, history - - -def reset_user_input(): - return gr.update(value='') - - -def reset_state(): - return [], [] - - -with gr.Blocks() as demo: - gr.HTML("""

<h1 align="center">ChatGLM</h1>
""") - - chatbot = gr.Chatbot() - with gr.Row(): - with gr.Column(scale=4): - with gr.Column(scale=12): - user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( - container=False) - with gr.Column(min_width=32, scale=1): - submitBtn = gr.Button("Submit", variant="primary") - with gr.Column(scale=1): - emptyBtn = gr.Button("Clear History") - max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) - top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) - temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - - history = gr.State([]) - - submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], - show_progress=True) - submitBtn.click(reset_user_input, [], [user_input]) - - emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) - -demo.queue().launch(share=True, inbrowser=True) diff --git a/web_demo_old.py b/web_demo_old.py new file mode 100644 index 0000000..88a6dc8 --- /dev/null +++ b/web_demo_old.py @@ -0,0 +1,45 @@ +from transformers import AutoModel, AutoTokenizer +import gradio as gr + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +model = model.eval() + +MAX_TURNS = 20 +MAX_BOXES = MAX_TURNS * 2 + + +def predict(input, max_length, top_p, temperature, history=None): + if history is None: + history = [] + for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, + temperature=temperature): + updates = [] + for query, response in history: + updates.append(gr.update(visible=True, value="用户:" + query)) + updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response)) + if len(updates) < MAX_BOXES: + updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates)) + yield [history] + updates + + +with gr.Blocks() as demo: + state = gr.State([]) + text_boxes = [] + for i in range(MAX_BOXES): + if i % 2 == 0: + text_boxes.append(gr.Markdown(visible=False, label="提问:")) + else: + text_boxes.append(gr.Markdown(visible=False, label="回复:")) + + with gr.Row(): + with gr.Column(scale=4): + txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style( + container=False) + with gr.Column(scale=1): + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) + button = gr.Button("Generate") + button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) +demo.queue().launch(share=False, inbrowser=True) From 7131d29f2d49ed984dba8aeeb88b5a16a83b07ab Mon Sep 17 00:00:00 2001 From: duzx16 Date: Thu, 6 Apr 2023 17:51:20 +0800 Subject: [PATCH 08/12] Add English readme --- README_en.md | 4 ++++ ptuning/README.md | 2 ++ 2 files changed, 6 insertions(+) diff --git a/README_en.md b/README_en.md index d5c05bb..da2b8dc 100644 --- a/README_en.md +++ b/README_en.md @@ -9,6 +9,8 @@ ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dial Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces. 
## Update
+**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). At the minimum INT4 quantization level, 7GB of GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details.
+
 **[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon.

 **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).

@@ -168,6 +170,8 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal
 ```
 Then you can use GPU-accelerated model inference on Mac.

+## Parameter-efficient Tuning
+Parameter-efficient tuning is based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it.

 ## ChatGLM-6B Examples

diff --git a/ptuning/README.md b/ptuning/README.md
index ca1fc73..11ee326 100644
--- a/ptuning/README.md
+++ b/ptuning/README.md
@@ -3,6 +3,8 @@

 下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。

+*Read this in [English](README_en.md).*
+
 ## 软件依赖
 运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要按照以下依赖
 ```
From 66e641d572d612e905ca5e16b1ebdab029eb6910 Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 6 Apr 2023 17:54:46 +0800
Subject: [PATCH 09/12] Add English readme

---
 .idea/ChatGLM-6B.iml | 7 +++++++
 1 file changed, 7 insertions(+)
 create mode 100644 .idea/ChatGLM-6B.iml

diff --git a/.idea/ChatGLM-6B.iml b/.idea/ChatGLM-6B.iml
new file mode 100644
index 0000000..ec63674
--- /dev/null
+++ b/.idea/ChatGLM-6B.iml
@@ -0,0 +1,7 @@
+
+
+
+
\ No newline at end of file
From e79e4f2859321406272c322ea9909fb5395f285d Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 6 Apr 2023 17:55:27 +0800
Subject: [PATCH 10/12] Revert "Add English readme"

This reverts commit 66e641d572d612e905ca5e16b1ebdab029eb6910.

---
 .idea/ChatGLM-6B.iml | 7 -------
 1 file changed, 7 deletions(-)
 delete mode 100644 .idea/ChatGLM-6B.iml

diff --git a/.idea/ChatGLM-6B.iml b/.idea/ChatGLM-6B.iml
deleted file mode 100644
index ec63674..0000000
--- a/.idea/ChatGLM-6B.iml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
\ No newline at end of file
From 6792ca6805dcae1fbf83c2eb33a7ffc3a96b243a Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 6 Apr 2023 17:55:31 +0800
Subject: [PATCH 11/12] Add English readme

---
 ptuning/README_en.md | 115 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 ptuning/README_en.md

diff --git a/ptuning/README_en.md b/ptuning/README_en.md
new file mode 100644
index 0000000..9282da3
--- /dev/null
+++ b/ptuning/README_en.md
@@ -0,0 +1,115 @@
+# ChatGLM-6B-PT
+This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the number of parameters that need to be optimized to 0.1% of full fine-tuning; combined with model quantization, gradient checkpointing and other methods, it needs as little as 7GB of GPU memory to run.
+
+The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.
+
+## Software dependencies
+Running P-Tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are also required:
+```
+pip install rouge_chinese nltk jieba datasets
+```
+## Instructions
+
+### Download the dataset
+The task of the ADGEN dataset is to generate an advertising text (summary) based on the input keywords (content).
+
+```json
+{
+    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
+    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
+}
+```
+
+Download the processed ADGEN dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and put the decompressed `AdvertiseGen` directory into this directory.
+
+### Training
+Run the following commands for training:
+```shell
+bash train.sh
+```
+`PRE_SEQ_LEN` and `LR` in `train.sh` are the soft prompt length and the training learning rate respectively, and can be adjusted to achieve the best results. The P-Tuning-v2 method freezes all model parameters, and the quantization level of the original model can be adjusted with `quantization_bit`; if this option is not added, the model is loaded with FP16 precision.
+
+Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and one training iteration performs 16 cumulative forward and backward passes with a batch size of 1, equivalent to a total batch size of 16; only 6.7GB of GPU memory is required at this time with `quantization_bit=4`. If you want to improve training efficiency at the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but this also brings more GPU memory consumption, so please adjust it according to the actual situation.
+
+### Inference
+
+Change `CHECKPOINT` in `evaluate.sh` to the checkpoint name saved during training, and run the following commands for model inference and evaluation:
+```shell
+bash evaluate.sh
+```
+
+The evaluation metrics are Chinese Rouge score and BLEU-4. The generated results are saved in `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
+
+### Example
+#### Example 1
+* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
+* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
+* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
+* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
+
+#### Example 2
+
+* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
+* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
+* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
+* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
+
+### Evaluation results
+
+|         | P-tuning v2 | LoRA  |
+| ------- | ----------- | ----- |
+| BLEU-4  | 7.71        | 6.13  |
+| Rouge-1 | 31.35       | 28.36 |
+| Rouge-2 | 7.19        | 4.38  |
+| Rouge-l | 25.17       | 17.54 |
+
+#### Experiment settings
+
+```
+max_source_length=64
+max_target_length=64
+per_device_train_batch_size=1
+gradient_accumulation_steps=16
+max_steps=3000
+```
+
+##### P-tuning v2
+
+```
+pre_seq_len=128
+learning_rate=2e-2
+quantization_bit=4
+```
+
+##### LoRA
+
+```
+learning_rate=5e-4
+```
+
+The LoRA implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b).
+
+## Model Deployment
+Replace `THUDM/chatglm-6b` in the corresponding demo or code with the path of the checkpoint saved after P-Tuning (in the example, `./output/adgen-chatglm-6b-pt-8-1e-2/checkpoint-3000`). Note that the current fine-tuning does not support multi-turn data, so only the responses from the first turn of the conversation are fine-tuned.
+
+## Use your own dataset
+Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON-format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to the input text and the output text.
+
+## TODO
+* [ ] Support for chat data
+* [ ] Support for full finetuning
+
+## Citation
+
+```
+@inproceedings{liu2022p,
+  title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
+  author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
+  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
+  pages={61--68},
+  year={2022}
+}
+```
\ No newline at end of file
From 8a809d4ab712eb61c1ae070452b61f7ab8e4164e Mon Sep 17 00:00:00 2001
From: duzx16
Date: Thu, 6 Apr 2023 19:28:07 +0800
Subject: [PATCH 12/12] Drop icetk dependency

---
 requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 072d12c..4788707 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 protobuf>=3.19.5,<3.20.1
 transformers==4.27.1
-icetk
 cpm_kernels
 torch>=1.10
 gradio
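
The series leaves `web_demo.py` built entirely around `model.stream_chat`, so the same generation loop can be exercised without Gradio. Below is a minimal sketch, not part of the patches themselves, assuming the public `THUDM/chatglm-6b` checkpoint and reusing the `stream_chat` call signature and the demo's default generation settings (`max_length=2048`, `top_p=0.7`, `temperature=0.95`) exactly as they appear above:

```python
from transformers import AutoModel, AutoTokenizer

# Load the tokenizer and FP16 model the same way the final web_demo.py does.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

history = []
query = "你好"  # sample query; any string works
# stream_chat yields (partial_response, updated_history) pairs; the final
# iteration carries the complete reply, mirroring the demo's predict() loop.
for response, history in model.stream_chat(tokenizer, query, history,
                                           max_length=2048, top_p=0.7, temperature=0.95):
    pass
print(response)
```

For smaller GPUs, patch 02 shows the knob it removed from the demo: chaining `.quantize(8)` before `.half().cuda()`, as in the original patch 01 loading line, loads INT8-quantized weights instead.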