check_weights

pull/476/head
lijiaxing 2023-11-07 15:34:34 +08:00
parent 25604ed040
commit 8c8883367a
1 changed files with 5 additions and 5 deletions

View File

@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../")) project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root) sys.path.append(project_root)
# pylint: disable=C0413 # pylint: disable=C0413,W0612,W0611
import socket import socket
import time import time
import traceback import traceback
@ -218,14 +218,14 @@ def main(args):
# load batch data # load batch data
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) # batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# pylint: disable=C0301
batch_index = batch_count % 1000 batch_index = batch_count % 1000
if batch_index == 0: if batch_index == 0:
data_local_rank = gpc.get_local_rank(ParallelMode.DATA) data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
batch_step = (batch_count // 1000 + 1) * 1000 batch_step = (batch_count // 1000 + 1) * 1000
data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt' data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
data_1000 = torch.load(data_path, map_location=torch.device('cpu')) data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
batch = data_1000[batch_index] batch = data_1000[batch_index]
# record the consumed samples in training # record the consumed samples in training
train_state.batch_count = batch_count train_state.batch_count = batch_count
@ -316,7 +316,7 @@ def main(args):
"/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt" "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
) )
check_model_weights(model, ckpt_path) check_model_weights(model, ckpt_path)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# # save batch sampler that tracks the true consumed samples # # save batch sampler that tracks the true consumed samples
now_break = ckpt_manager.try_save_checkpoint(train_state) now_break = ckpt_manager.try_save_checkpoint(train_state)