check_weights

pull/476/head
lijiaxing 2023-11-07 15:34:34 +08:00
parent 25604ed040
commit 8c8883367a
1 changed file with 5 additions and 5 deletions

View File

@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)
# pylint: disable=C0413
# pylint: disable=C0413,W0612,W0611
import socket
import time
import traceback
@ -218,14 +218,14 @@ def main(args):
# load batch data
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# pylint: disable=C0301
batch_index = batch_count % 1000
if batch_index == 0:
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
batch_step = (batch_count // 1000 + 1) * 1000
data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt'
data_1000 = torch.load(data_path, map_location=torch.device('cpu'))
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
batch = data_1000[batch_index]
# record the consumed samples in training
train_state.batch_count = batch_count
@ -316,7 +316,7 @@ def main(args):
"/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
)
check_model_weights(model, ckpt_path)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# # save batch sampler that tracks the true consumed samples
now_break = ckpt_manager.try_save_checkpoint(train_state)