From dd5fbd2edf5b6f67326db2b178cb7daeeac15d66 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Thu, 16 Nov 2023 10:15:49 +0800 Subject: [PATCH] check_init --- tests/test_training/train_CI.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index f3e7a9f..348c780 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -225,7 +225,7 @@ def main(args): ckpt_name = ( f"model_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" ) - ckpt_path = os.path.join("...", ckpt_name) + ckpt_path = os.path.join(os.environ["share_path"], "quailty_assurance/7B_init_8_tp=4_pp=2_ckpt", ckpt_name) check_model_weights(model, ckpt_path, total_equal=True) with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: @@ -242,9 +242,10 @@ def main(args): if batch_index == 0: data_local_rank = gpc.get_local_rank(ParallelMode.DATA) batch_step = (batch_count // 1000 + 1) * 1000 - data_path = ( - "..." - f"dp-11{data_local_rank}/batch-{batch_step}.pt" + data_path = os.path.join( + os.environ["share_path"], + "quailty_assurance/debug_Qiansanqiang_7B_v16", + f"dp-11{data_local_rank}/batch-{batch_step}.pt", ) data_1000 = torch.load(data_path, map_location=torch.device("cpu")) batch = data_1000[batch_index] @@ -336,7 +337,10 @@ def main(args): # check model weights if batch_count > 0 and batch_count % 100 == 0: ckpt_path = os.path.join( - "...", str(batch_count), "model_tp0_pp0.pt" + os.environ["share_path"], + "quailty_assurance/7B_model_weights_ckpt", + str(batch_count), + "model_tp0_pp0.pt", ) check_model_weights(model, ckpt_path)