Add clear button

Clear the CUDA cache so users don't need to restart the program when running out of VRAM.
pull/154/head
Anderson 2023-03-19 16:10:01 +08:00 committed by GitHub
parent 52aa3261d7
commit b54c6ad624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 4 deletions

View File

@ -1,5 +1,6 @@
from transformers import AutoModel, AutoTokenizer import torch
import gradio as gr import gradio as gr
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
@ -23,6 +24,11 @@ def predict(input, max_length, top_p, temperature, history=None):
yield [history] + updates yield [history] + updates
def clear():
    """Release cached CUDA memory held by PyTorch's allocator.

    Used as the click handler for the "Clear" button. Does nothing when
    CUDA is not available. Always returns ``None``.
    """
    if not torch.cuda.is_available():
        return

    # empty_cache() can only hand back blocks that no live tensor still
    # references, so collect unreachable objects (e.g. tensors kept alive
    # by reference cycles) first to make them eligible for release.
    import gc

    gc.collect()
    torch.cuda.empty_cache()
with gr.Blocks() as demo: with gr.Blocks() as demo:
state = gr.State([]) state = gr.State([])
text_boxes = [] text_boxes = []
@ -40,6 +46,8 @@ with gr.Blocks() as demo:
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
button = gr.Button("Generate") generate_button = gr.Button("Generate")
button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) clear_button = gr.Button("Clear")
demo.queue().launch(share=True, inbrowser=True) generate_button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
clear_button.click(clear)
demo.queue().launch(share=True, inbrowser=True)