mirror of https://github.com/THUDM/ChatGLM2-6B
check device
parent
e4d57691d5
commit
03648a5a60
|
@@ -7,6 +7,12 @@ import torch
|
|||
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
-device = torch.device("cpu")
|
||||
+if torch.cuda.is_available():
|
||||
+    device = torch.device("cuda")
|
||||
+elif torch.backends.mps.is_available():
|
||||
+    device = torch.device("mps")
|
||||
+else:
|
||||
+    device = torch.device("cpu")
|
||||
+print("device:", device)
|
||||
|
||||
checkpoint = "/Users/hhwang/models/opt-125m"
|
||||
|
@@ -37,7 +43,7 @@ trainer = Trainer(
|
|||
args=TrainingArguments(
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
-warmup_steps=2,
|
||||
+warmup_steps=3,
|
||||
max_steps=10,
|
||||
learning_rate=2e-4,
|
||||
# fp16=True, # only works on cuda
|
||||
|
|
Loading…
Reference in New Issue