check device

pull/672/head
wanghh2000 2023-12-29 10:35:08 +08:00
parent e4d57691d5
commit 03648a5a60
1 changed files with 7 additions and 1 deletions

View File

@ -7,6 +7,12 @@ import torch
# os.environ["TOKENIZERS_PARALLELISM"] = "false" # os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
print("device:", device) print("device:", device)
checkpoint = "/Users/hhwang/models/opt-125m" checkpoint = "/Users/hhwang/models/opt-125m"
@ -37,7 +43,7 @@ trainer = Trainer(
args=TrainingArguments( args=TrainingArguments(
per_device_train_batch_size=4, per_device_train_batch_size=4,
gradient_accumulation_steps=4, gradient_accumulation_steps=4,
warmup_steps=2, warmup_steps=3,
max_steps=10, max_steps=10,
learning_rate=2e-4, learning_rate=2e-4,
# fp16=True, # only works on cuda # fp16=True, # only works on cuda