mirror of https://github.com/InternLM/InternLM
check_weights
parent
25604ed040
commit
8c8883367a
|
@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
|||
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
|
||||
sys.path.append(project_root)
|
||||
|
||||
# pylint: disable=C0413
|
||||
# pylint: disable=C0413,W0612,W0611
|
||||
import socket
|
||||
import time
|
||||
import traceback
|
||||
|
@ -218,14 +218,14 @@ def main(args):
|
|||
|
||||
# load batch data
|
||||
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
# pylint: disable=C0301
|
||||
batch_index = batch_count % 1000
|
||||
if batch_index == 0:
|
||||
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
batch_step = (batch_count // 1000 + 1) * 1000
|
||||
data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt'
|
||||
data_1000 = torch.load(data_path, map_location=torch.device('cpu'))
|
||||
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
|
||||
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
|
||||
batch = data_1000[batch_index]
|
||||
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
|
@ -316,7 +316,7 @@ def main(args):
|
|||
"/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
|
||||
)
|
||||
check_model_weights(model, ckpt_path)
|
||||
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
now_break = ckpt_manager.try_save_checkpoint(train_state)
|
||||
|
|
Loading…
Reference in New Issue