check_init

pull/502/head
lijiaxing 2023-11-16 10:15:49 +08:00
parent e05751e0c6
commit dd5fbd2edf
1 changed files with 9 additions and 5 deletions

View File

@ -225,7 +225,7 @@ def main(args):
ckpt_name = ( ckpt_name = (
f"model_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" f"model_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt"
) )
ckpt_path = os.path.join("...", ckpt_name) ckpt_path = os.path.join(os.environ["share_path"], "quailty_assurance/7B_init_8_tp=4_pp=2_ckpt", ckpt_name)
check_model_weights(model, ckpt_path, total_equal=True) check_model_weights(model, ckpt_path, total_equal=True)
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
@ -242,9 +242,10 @@ def main(args):
if batch_index == 0: if batch_index == 0:
data_local_rank = gpc.get_local_rank(ParallelMode.DATA) data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
batch_step = (batch_count // 1000 + 1) * 1000 batch_step = (batch_count // 1000 + 1) * 1000
data_path = ( data_path = os.path.join(
"..." os.environ["share_path"],
f"dp-11{data_local_rank}/batch-{batch_step}.pt" "quailty_assurance/debug_Qiansanqiang_7B_v16",
f"dp-11{data_local_rank}/batch-{batch_step}.pt",
) )
data_1000 = torch.load(data_path, map_location=torch.device("cpu")) data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
batch = data_1000[batch_index] batch = data_1000[batch_index]
@ -336,7 +337,10 @@ def main(args):
# check model weights # check model weights
if batch_count > 0 and batch_count % 100 == 0: if batch_count > 0 and batch_count % 100 == 0:
ckpt_path = os.path.join( ckpt_path = os.path.join(
"...", str(batch_count), "model_tp0_pp0.pt" os.environ["share_path"],
"quailty_assurance/7B_model_weights_ckpt",
str(batch_count),
"model_tp0_pp0.pt",
) )
check_model_weights(model, ckpt_path) check_model_weights(model, ckpt_path)