mirror of https://github.com/InternLM/InternLM
check_weights
parent
25604ed040
commit
8c8883367a
|
@ -7,7 +7,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
|
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
|
||||||
sys.path.append(project_root)
|
sys.path.append(project_root)
|
||||||
|
|
||||||
# pylint: disable=C0413
|
# pylint: disable=C0413,W0612,W0611
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -218,15 +218,15 @@ def main(args):
|
||||||
|
|
||||||
# load batch data
|
# load batch data
|
||||||
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||||
|
# pylint: disable=C0301
|
||||||
batch_index = batch_count % 1000
|
batch_index = batch_count % 1000
|
||||||
if batch_index == 0:
|
if batch_index == 0:
|
||||||
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
|
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||||
batch_step = (batch_count // 1000 + 1) * 1000
|
batch_step = (batch_count // 1000 + 1) * 1000
|
||||||
data_path = f'/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt'
|
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
|
||||||
data_1000 = torch.load(data_path, map_location=torch.device('cpu'))
|
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
|
||||||
batch = data_1000[batch_index]
|
batch = data_1000[batch_index]
|
||||||
|
|
||||||
|
|
||||||
# record the consumed samples in training
|
# record the consumed samples in training
|
||||||
train_state.batch_count = batch_count
|
train_state.batch_count = batch_count
|
||||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||||
|
|
Loading…
Reference in New Issue