volc_path (#454)

pull/464/head
jiaxingli 2023-10-27 18:53:06 +08:00 committed by GitHub
parent 87a3c5c374
commit e6d8ebc3e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -4,6 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import json
import os
from collections import deque
from typing import Iterable, Optional
@ -120,13 +121,17 @@ class TrainState:
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
def state_dict(self):
if os.environ.get("CLUSTER_NAME") == "volc" and os.environ.get("petrelfs_tb_path") is not None:
tensorboard_folder = os.path.join(os.environ["petrelfs_tb_path"], os.environ["MLP_TASK_ID"])
else:
tensorboard_folder = self.tensorboard_folder
return {
"batch_count": self.batch_count,
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
"num_consumed_tokens": self.num_consumed_tokens,
"inf_nan_skip_batches": self.inf_nan_skip_batches,
"step_count": self.step_count,
"tensorboard_folder": self.tensorboard_folder,
"tensorboard_folder": tensorboard_folder,
}