check_weights

pull/476/head
lijiaxing 2023-11-07 15:34:34 +08:00
parent 25604ed040
commit 8c8883367a
1 changed file with 5 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)
# pylint: disable=C0413
# pylint: disable=C0413,W0612,W0611
import socket
import time
import traceback
@@ -218,15 +218,15 @@ def main(args):
# load batch data
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# pylint: disable=C0301
batch_index = batch_count % 1000
if batch_index == 0:
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
batch_step = (batch_count // 1000 + 1) * 1000
data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt'
data_1000 = torch.load(data_path, map_location=torch.device('cpu'))
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
batch = data_1000[batch_index]
# record the consumed samples in training
train_state.batch_count = batch_count
train_state.num_consumed_samples_in_epoch += len(batch[1])