diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py
index b05772b..f3e7a9f 100644
--- a/tests/test_training/train_CI.py
+++ b/tests/test_training/train_CI.py
@@ -225,7 +225,7 @@ def main(args):
         ckpt_name = (
             f"model_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt"
         )
-        ckpt_path = os.path.join("/mnt/petrelfs/share/quailty_assurance/7B_init_8_tp=4_pp=2_ckpt", ckpt_name)
+        ckpt_path = os.path.join("...", ckpt_name)
         check_model_weights(model, ckpt_path, total_equal=True)
 
     with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
@@ -243,7 +243,7 @@ def main(args):
            data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
            batch_step = (batch_count // 1000 + 1) * 1000
            data_path = (
-                "/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/"
+                "..."
                f"dp-11{data_local_rank}/batch-{batch_step}.pt"
            )
            data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
@@ -336,7 +336,7 @@ def main(args):
            # check model weights
            if batch_count > 0 and batch_count % 100 == 0:
                ckpt_path = os.path.join(
-                    "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
+                    "...", str(batch_count), "model_tp0_pp0.pt"
                )
                check_model_weights(model, ckpt_path)
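
The "..." strings are placeholders for the site-specific storage paths that this change strips from the CI test; they are left elided here. If the reference-checkpoint and reference-data locations need to remain usable without hard-coding, one option is to resolve them at runtime. The sketch below is an assumption, not part of this change; the CI_CKPT_ROOT variable name and resolve_ckpt_root helper are hypothetical.

import os

# Hypothetical helper: read the reference-checkpoint root from an environment
# variable so the test script carries no cluster-specific paths.
def resolve_ckpt_root(env_var: str = "CI_CKPT_ROOT") -> str:
    root = os.environ.get(env_var)
    if root is None:
        raise RuntimeError(f"Set {env_var} to the directory holding the reference checkpoints.")
    return root

# Usage mirroring the pattern in the diff:
# ckpt_path = os.path.join(resolve_ckpt_root(), ckpt_name)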