mirror of https://github.com/THUDM/ChatGLM2-6B
check device
parent e4d57691d5
commit 03648a5a60
@@ -7,6 +7,12 @@ import torch
 # os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 device = torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
 print("device:", device)
 
 checkpoint = "/Users/hhwang/models/opt-125m"
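For context, the device-selection logic added by this hunk runs standalone as below. This is a minimal sketch assuming only the torch import visible in the hunk header; since the variable is initialized to cpu up front, the final else branch is redundant but harmless.

import torch

# Pick the best available backend: CUDA (NVIDIA GPU), then MPS (Apple
# Silicon), falling back to CPU. The initial assignment already covers
# the else branch.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("device:", device)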
@@ -37,7 +43,7 @@ trainer = Trainer(
     args=TrainingArguments(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
-        warmup_steps=2,
+        warmup_steps=3,
         max_steps=10,
         learning_rate=2e-4,
         # fp16=True, # only works on cuda
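The second hunk only bumps warmup_steps from 2 to 3. A hedged sketch of the surrounding Trainer setup follows: the TrainingArguments values and the checkpoint path come from the diff, while the model/tokenizer loading, the toy dataset, the data collator, and output_dir are assumptions for illustration, not the repository's actual code.

from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

checkpoint = "/Users/hhwang/models/opt-125m"  # local path from the diff
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Assumed toy training data; the real script's dataset is not shown in the diff.
raw = Dataset.from_dict({"text": ["hello world", "a tiny causal LM example"]})
tokenized = raw.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="outputs",            # assumed; not visible in the diff
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=3,                  # value after this commit
        max_steps=10,
        learning_rate=2e-4,
        # fp16=True, # only works on cuda
    ),
    train_dataset=tokenized,
    # mlm=False makes the collator produce causal-LM labels (labels = input_ids).
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()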