From 03648a5a6081e4049328f554eb64992705ec8926 Mon Sep 17 00:00:00 2001
From: wanghh2000
Date: Fri, 29 Dec 2023 10:35:08 +0800
Subject: [PATCH] check device

---
 ptuning/finetune-opt.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/ptuning/finetune-opt.py b/ptuning/finetune-opt.py
index c5ea64a..648af8e 100644
--- a/ptuning/finetune-opt.py
+++ b/ptuning/finetune-opt.py
@@ -7,6 +7,12 @@ import torch
 # os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 device = torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
 print("device:", device)
 
 checkpoint = "/Users/hhwang/models/opt-125m"
@@ -37,7 +43,7 @@ trainer = Trainer(
     args=TrainingArguments(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
-        warmup_steps=2,
+        warmup_steps=3,
         max_steps=10,
         learning_rate=2e-4,
         # fp16=True, # only works on cuda
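
For reference, a minimal standalone sketch of the device-selection logic this patch adds to finetune-opt.py. It assumes only that torch is installed; the tensor at the end is illustrative and is not part of the original script.

import torch

# Pick the best available accelerator, falling back to CPU,
# mirroring the check introduced by the patch above.
if torch.cuda.is_available():
    device = torch.device("cuda")   # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = torch.device("mps")    # Apple Silicon GPU
else:
    device = torch.device("cpu")    # CPU fallback
print("device:", device)

# Illustrative usage: move a tensor (or a model) to the selected device.
x = torch.randn(2, 3).to(device)
print(x.device)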