mirror of https://github.com/THUDM/ChatGLM-6B

Add clear button

Clear the CUDA cache so users don't need to restart the program when running out of VRAM.

branch: pull/154/head
parent: 52aa3261d7
commit: b54c6ad624

web_demo.py (16 lines changed)
@@ -1,5 +1,6 @@
-from transformers import AutoModel, AutoTokenizer
+import torch
 import gradio as gr
+from transformers import AutoModel, AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
@@ -23,6 +24,11 @@ def predict(input, max_length, top_p, temperature, history=None):
         yield [history] + updates
 
 
+def clear():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 with gr.Blocks() as demo:
     state = gr.State([])
     text_boxes = []
@@ -40,6 +46,8 @@ with gr.Blocks() as demo:
             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+            generate_button = gr.Button("Generate")
+            clear_button = gr.Button("Clear")
+    generate_button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+    clear_button.click(clear)
 demo.queue().launch(share=True, inbrowser=True)
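
For reference, a minimal standalone sketch of the pattern this commit adds: a Gradio button whose handler releases PyTorch's cached GPU memory. It assumes recent gradio and torch releases, and the echo handler is a hypothetical stand-in for the model-backed predict function in web_demo.py.

import torch
import gradio as gr


def clear():
    # Release unused blocks held by PyTorch's CUDA caching allocator.
    # The model itself stays loaded; this only frees cached memory so the
    # next generation is less likely to hit an out-of-memory error.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def echo(text):
    # Hypothetical stand-in for the real predict() in web_demo.py.
    return text


with gr.Blocks() as demo:
    txt = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    generate_button = gr.Button("Generate")
    clear_button = gr.Button("Clear")
    generate_button.click(echo, inputs=txt, outputs=out)
    # Same wiring as clear_button.click(clear) in the diff: no inputs,
    # no outputs, the handler just empties the CUDA cache.
    clear_button.click(clear)

demo.queue().launch()

Note that torch.cuda.empty_cache() only returns unused blocks from the caching allocator to the driver; tensors still referenced by the chat history keep their memory, so the button helps most after a generation has failed or its outputs have been discarded.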