From 8c8883367a51d462b1e38e45f9e651d64de861a1 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 7 Nov 2023 15:34:34 +0800 Subject: [PATCH] check_weights --- tests/test_training/train_CI.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 897475d..f81b4c6 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(script_dir, "../../")) sys.path.append(project_root) -# pylint: disable=C0413 +# pylint: disable=C0413,W0612,W0611 import socket import time import traceback @@ -218,14 +218,14 @@ def main(args): # load batch data # batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) + # pylint: disable=C0301 batch_index = batch_count % 1000 if batch_index == 0: data_local_rank = gpc.get_local_rank(ParallelMode.DATA) batch_step = (batch_count // 1000 + 1) * 1000 - data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt' - data_1000 = torch.load(data_path, map_location=torch.device('cpu')) + data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt" + data_1000 = torch.load(data_path, map_location=torch.device("cpu")) batch = data_1000[batch_index] - # record the consumed samples in training train_state.batch_count = batch_count @@ -316,7 +316,7 @@ def main(args): "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt" ) check_model_weights(model, ckpt_path) - + # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples now_break = ckpt_manager.try_save_checkpoint(train_state)