mirror of https://github.com/THUDM/ChatGLM-6B
Add clear button

Clear the CUDA cache so users don't need to restart the program when running out of VRAM.

pull/154/head
parent 52aa3261d7
commit b54c6ad624
web_demo.py | 16 ++++++++++++----
@@ -1,5 +1,6 @@
-from transformers import AutoModel, AutoTokenizer
+import torch
 import gradio as gr
+from transformers import AutoModel, AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
@@ -23,6 +24,11 @@ def predict(input, max_length, top_p, temperature, history=None):
         yield [history] + updates
 
 
+def clear():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 with gr.Blocks() as demo:
     state = gr.State([])
     text_boxes = []
@@ -40,6 +46,8 @@ with gr.Blocks() as demo:
             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-            button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=True, inbrowser=True)
+            generate_button = gr.Button("Generate")
+            clear_button = gr.Button("Clear")
+            generate_button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+            clear_button.click(clear)
+demo.queue().launch(share=True, inbrowser=True)
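Note on what the new clear() actually frees: torch.cuda.empty_cache() returns memory blocks that PyTorch's caching allocator is holding for reuse back to the GPU driver; it does not free tensors that are still referenced, so the model weights and any live activations stay in VRAM. A minimal standalone sketch (not part of this commit) illustrating that behavior:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")   # live tensor: survives empty_cache()
        y = torch.randn(4096, 4096, device="cuda")
        del y                                        # unreferenced, but its block stays cached
        print(torch.cuda.memory_allocated())         # bytes held by live tensors (x)
        print(torch.cuda.memory_reserved())          # bytes reserved by the caching allocator
        torch.cuda.empty_cache()                     # hand unused cached blocks back to the driver
        print(torch.cuda.memory_reserved())          # now drops toward memory_allocated()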
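As wired above, clear_button.click(clear) registers clear with no inputs or outputs, so pressing Clear only empties the cache; the chat history in state and the visible text boxes are left untouched. A hypothetical extension (clear_all and its output wiring are illustrative, not from this commit) that resets the UI at the same time, following the same gr.update pattern predict uses:

    def clear_all():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # first element resets the gr.State history; the rest hide the output boxes
        return [[]] + [gr.update(visible=False, value="")] * len(text_boxes)

    clear_button.click(clear_all, inputs=None, outputs=[state] + text_boxes)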