From b54c6ad62462edcc653a6977a901a42854ae3870 Mon Sep 17 00:00:00 2001 From: Anderson Date: Sun, 19 Mar 2023 16:10:01 +0800 Subject: [PATCH] Add clear button Clear the CUDA cache so users don't need to restart the program when running out of VRAM. --- web_demo.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/web_demo.py b/web_demo.py index 39709a7..d81a1f8 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,5 +1,6 @@ -from transformers import AutoModel, AutoTokenizer +import torch import gradio as gr +from transformers import AutoModel, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() @@ -23,6 +24,11 @@ def predict(input, max_length, top_p, temperature, history=None): yield [history] + updates +def clear(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + with gr.Blocks() as demo: state = gr.State([]) text_boxes = [] @@ -40,6 +46,8 @@ with gr.Blocks() as demo: max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True) - button = gr.Button("Generate") - button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) -demo.queue().launch(share=True, inbrowser=True) \ No newline at end of file + generate_button = gr.Button("Generate") + clear_button = gr.Button("Clear") + generate_button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes) + clear_button.click(clear) +demo.queue().launch(share=True, inbrowser=True)