mirror of https://github.com/InternLM/InternLM
fix(ckpt): fix snapshot none load error and remove file lock (#298)
parent
1ee31ff9b1
commit
85e39aae67
|
@ -2,7 +2,6 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import fcntl
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
@ -545,12 +544,17 @@ class CheckpointManager:
|
||||||
if self.stop_file_path is None:
|
if self.stop_file_path is None:
|
||||||
return now_break, now_save_ckpt, save_type
|
return now_break, now_save_ckpt, save_type
|
||||||
|
|
||||||
with open(self.stop_file_path, "a+", encoding="utf-8") as f:
|
with torch.no_grad():
|
||||||
fcntl.flock(f, fcntl.LOCK_EX)
|
action_step_t = torch.zeros((1,), dtype=torch.int64).cuda()
|
||||||
f.seek(0)
|
if gpc.get_global_rank() == 0:
|
||||||
msg = f.read()
|
with open(self.stop_file_path, "r+", encoding="utf-8") as f:
|
||||||
fcntl.flock(f, fcntl.LOCK_UN)
|
f.seek(0)
|
||||||
action_step = int(msg)
|
msg = f.read()
|
||||||
|
action_step_t.fill_(int(msg))
|
||||||
|
|
||||||
|
torch.distributed.broadcast(action_step_t, src=0)
|
||||||
|
action_step = action_step_t.item()
|
||||||
|
del action_step_t
|
||||||
|
|
||||||
if action_step < 0 and abs(action_step) == train_state.step_count:
|
if action_step < 0 and abs(action_step) == train_state.step_count:
|
||||||
now_save_ckpt = True
|
now_save_ckpt = True
|
||||||
|
@ -627,41 +631,50 @@ now step_count is {train_state.step_count}",
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
max_normal_step = 0
|
max_normal_step = 0
|
||||||
ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
|
# Return ckpt_list look like: ['pings', 'snapshot', '4']
|
||||||
ckpt_list.sort(reverse=True)
|
# Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders.
|
||||||
for ckpt in ckpt_list:
|
ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()]
|
||||||
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
|
if len(ckpt_list) == 0:
|
||||||
for fn in fns_list:
|
logger.warning("Not found avaliable normal checkpoint!")
|
||||||
if fn.endswith(".step"):
|
else:
|
||||||
max_normal_step = ckpt
|
logger.info(f"Found avaliable normal checkpoint: {ckpt_list}!")
|
||||||
|
ckpt_list.sort(reverse=True)
|
||||||
|
for ckpt in ckpt_list:
|
||||||
|
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
|
||||||
|
for fn in fns_list:
|
||||||
|
if fn.endswith(".step"):
|
||||||
|
max_normal_step = ckpt
|
||||||
|
break
|
||||||
|
if max_normal_step != 0:
|
||||||
break
|
break
|
||||||
if max_normal_step != 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
max_normal_step = ckpt_list[0]
|
max_normal_step = ckpt_list[0]
|
||||||
load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
|
load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
|
||||||
|
|
||||||
snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
|
snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
|
||||||
snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
|
snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
|
||||||
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
|
ckpt_list_0 = self.storage_manager.get_fns(snapshot_path_0)
|
||||||
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
|
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_1)
|
||||||
max_step_0, max_step_1 = 0, 0
|
|
||||||
if ckpt_list_1:
|
|
||||||
for ckpt in ckpt_list_1:
|
|
||||||
ckpt = ckpt.strip("/")
|
|
||||||
if ckpt.endswith(".step"):
|
|
||||||
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
|
|
||||||
if ckpt_list_2:
|
|
||||||
for ckpt in ckpt_list_2:
|
|
||||||
ckpt = ckpt.strip("/")
|
|
||||||
if ckpt.endswith(".step"):
|
|
||||||
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
|
|
||||||
|
|
||||||
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
|
def found_latest_snapshot(_ckpt_list):
|
||||||
snap_step = max(max_step_0, max_step_1)
|
_max_step_snapshot = 0
|
||||||
load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
|
if _ckpt_list:
|
||||||
load_step = max(snap_step, max_normal_step)
|
for ckpt in _ckpt_list:
|
||||||
return load_path, load_step
|
ckpt = ckpt.strip("/")
|
||||||
|
if ckpt.endswith(".step"):
|
||||||
|
_max_step_snapshot = max(_max_step_snapshot, int(ckpt.split(".")[0]))
|
||||||
|
return _max_step_snapshot
|
||||||
|
|
||||||
|
max_step_0 = found_latest_snapshot(ckpt_list_0)
|
||||||
|
max_step_1 = found_latest_snapshot(ckpt_list_1)
|
||||||
|
|
||||||
|
if sum([max_step_0, max_step_1, max_normal_step]) == 0:
|
||||||
|
return None, None
|
||||||
|
else:
|
||||||
|
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
|
||||||
|
snap_step = max(max_step_0, max_step_1)
|
||||||
|
load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
|
||||||
|
return load_path, max(snap_step, max_normal_step)
|
||||||
|
|
||||||
def query_latest_snapshot_step_local(self):
|
def query_latest_snapshot_step_local(self):
|
||||||
max_step, max_step_path = 0, None
|
max_step, max_step_path = 0, None
|
||||||
|
|
|
@ -50,6 +50,8 @@ init_config = Config(
|
||||||
),
|
),
|
||||||
resume_tb_folder="",
|
resume_tb_folder="",
|
||||||
tensorboard_folder="",
|
tensorboard_folder="",
|
||||||
|
alert_address=None,
|
||||||
|
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -177,5 +179,5 @@ def del_tmp_file():
|
||||||
results += str(line.rstrip())
|
results += str(line.rstrip())
|
||||||
presults += line.rstrip().decode() + "\n"
|
presults += line.rstrip().decode() + "\n"
|
||||||
print(presults, flush=True)
|
print(presults, flush=True)
|
||||||
except FileNotFoundError:
|
except: # noqa # pylint: disable=bare-except
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from internlm.core.context import global_context as gpc
|
|
||||||
from internlm.core.context.parallel_context import Config
|
from internlm.core.context.parallel_context import Config
|
||||||
from internlm.core.trainer import TrainState
|
from internlm.core.trainer import TrainState
|
||||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||||
|
@ -15,27 +16,24 @@ from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-i
|
||||||
BOTO_SAVE_PATH,
|
BOTO_SAVE_PATH,
|
||||||
LOCAL_SAVE_PATH,
|
LOCAL_SAVE_PATH,
|
||||||
del_tmp_file,
|
del_tmp_file,
|
||||||
|
init_config,
|
||||||
init_dist_and_model,
|
init_dist_and_model,
|
||||||
reset_singletons,
|
reset_singletons,
|
||||||
)
|
)
|
||||||
|
|
||||||
TOTAL_STEP = 6
|
# (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY)
|
||||||
|
step_info_list = [(8, 4, 2), (3, 4, 2), (1, 6, 3)]
|
||||||
CKPT_EVERY = 4
|
|
||||||
SNPASHOT_EVERY = 2
|
|
||||||
|
|
||||||
|
|
||||||
ckpt_config_list = [
|
ckpt_config_list = [
|
||||||
# Old interface format
|
# Old interface format
|
||||||
dict(
|
dict(
|
||||||
enable_save_ckpt=True,
|
enable_save_ckpt=True,
|
||||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||||
load_optimizer=True,
|
load_optimizer=True,
|
||||||
checkpoint_every=CKPT_EVERY,
|
checkpoint_every=0,
|
||||||
async_upload=True,
|
async_upload=True,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=None,
|
stop_file_path=None,
|
||||||
load_model_only_folder=None,
|
load_model_only_folder=None,
|
||||||
load_given_ckpt=False,
|
load_given_ckpt=False,
|
||||||
|
@ -47,11 +45,11 @@ ckpt_config_list = [
|
||||||
enable_save_ckpt=True,
|
enable_save_ckpt=True,
|
||||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||||
load_optimizer=True,
|
load_optimizer=True,
|
||||||
checkpoint_every=CKPT_EVERY,
|
checkpoint_every=0,
|
||||||
async_upload=False,
|
async_upload=False,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
||||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=None,
|
stop_file_path=None,
|
||||||
load_model_only_folder=None,
|
load_model_only_folder=None,
|
||||||
load_given_ckpt=False,
|
load_given_ckpt=False,
|
||||||
|
@ -62,10 +60,10 @@ ckpt_config_list = [
|
||||||
dict(
|
dict(
|
||||||
enable_save_ckpt=True,
|
enable_save_ckpt=True,
|
||||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||||
checkpoint_every=CKPT_EVERY,
|
checkpoint_every=0,
|
||||||
async_upload=True,
|
async_upload=True,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=None,
|
stop_file_path=None,
|
||||||
is_old_api=False,
|
is_old_api=False,
|
||||||
auto_resume=True,
|
auto_resume=True,
|
||||||
|
@ -73,10 +71,10 @@ ckpt_config_list = [
|
||||||
dict(
|
dict(
|
||||||
enable_save_ckpt=True,
|
enable_save_ckpt=True,
|
||||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||||
checkpoint_every=CKPT_EVERY,
|
checkpoint_every=0,
|
||||||
async_upload=False,
|
async_upload=False,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=None,
|
stop_file_path=None,
|
||||||
load_ckpt_folder=None,
|
load_ckpt_folder=None,
|
||||||
is_old_api=False,
|
is_old_api=False,
|
||||||
|
@ -159,15 +157,63 @@ def del_tmp():
|
||||||
del_tmp_file()
|
del_tmp_file()
|
||||||
|
|
||||||
|
|
||||||
|
def return_prefix_path(save_ckpt_folder):
|
||||||
|
if save_ckpt_folder.startswith("local:"):
|
||||||
|
return LOCAL_SAVE_PATH
|
||||||
|
else:
|
||||||
|
return BOTO_SAVE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_freq):
|
||||||
|
|
||||||
|
snapshot_latest_step, normal_latest_step = 0, 0
|
||||||
|
snapshot_latest_count, normal_latest_count = 0, 0
|
||||||
|
|
||||||
|
for i in range(total_step):
|
||||||
|
if (i + 1) % ckpt_freq == 0:
|
||||||
|
normal_latest_step = i + 1
|
||||||
|
normal_latest_count += 1
|
||||||
|
else:
|
||||||
|
if (i + 1) % snapshot_freq == 0:
|
||||||
|
snapshot_latest_step = i + 1
|
||||||
|
snapshot_latest_count += 1
|
||||||
|
|
||||||
|
if snapshot_latest_step == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if normal_latest_step >= snapshot_latest_step:
|
||||||
|
return normal_latest_step, os.path.join(return_prefix_path(save_ckpt_folder), f"{normal_latest_step}")
|
||||||
|
elif normal_latest_step < snapshot_latest_step:
|
||||||
|
if snapshot_latest_count % 2 == 0:
|
||||||
|
re_path = f"{return_prefix_path(save_ckpt_folder)}/snapshot/0"
|
||||||
|
else:
|
||||||
|
re_path = f"{return_prefix_path(save_ckpt_folder)}/snapshot/1"
|
||||||
|
return snapshot_latest_step, re_path
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("del_tmp")
|
@pytest.mark.usefixtures("del_tmp")
|
||||||
@pytest.mark.usefixtures("reset_singletons")
|
@pytest.mark.usefixtures("reset_singletons")
|
||||||
|
@pytest.mark.parametrize("step_info", step_info_list)
|
||||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||||
def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
||||||
|
from internlm.core.context import global_context as gpc
|
||||||
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
|
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
|
||||||
|
|
||||||
ckpt_config = Config(ckpt_config)
|
ckpt_config = Config(ckpt_config)
|
||||||
assert ckpt_config.checkpoint_every < TOTAL_STEP
|
total_step, checkpoint_every, oss_snapshot_freq = step_info
|
||||||
assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
|
print(total_step, checkpoint_every, oss_snapshot_freq, flush=True)
|
||||||
|
ckpt_config.checkpoint_every = checkpoint_every
|
||||||
|
ckpt_config.oss_snapshot_freq = oss_snapshot_freq
|
||||||
|
|
||||||
|
bond_return_latest_save_path = partial(
|
||||||
|
return_latest_save_path,
|
||||||
|
ckpt_config.save_ckpt_folder,
|
||||||
|
total_step,
|
||||||
|
ckpt_config.oss_snapshot_freq,
|
||||||
|
ckpt_config.checkpoint_every,
|
||||||
|
)
|
||||||
|
|
||||||
model, opim = init_dist_and_model
|
model, opim = init_dist_and_model
|
||||||
train_state = TrainState(gpc.config, None)
|
train_state = TrainState(gpc.config, None)
|
||||||
|
@ -178,7 +224,7 @@ def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=un
|
||||||
|
|
||||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||||
latest_ckpt_step = None
|
latest_ckpt_step = None
|
||||||
for i in range(TOTAL_STEP + 1):
|
for i in range(total_step):
|
||||||
overwrite_model_value(model, i)
|
overwrite_model_value(model, i)
|
||||||
overwrite_optim_state(opim, i)
|
overwrite_optim_state(opim, i)
|
||||||
|
|
||||||
|
@ -193,54 +239,119 @@ def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=un
|
||||||
|
|
||||||
wait_async_upload_finish()
|
wait_async_upload_finish()
|
||||||
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
|
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
|
||||||
assert latest_ckpt_info is not None
|
step, path = bond_return_latest_save_path()
|
||||||
latest_ckpt = latest_ckpt_info["path"]
|
assert latest_ckpt_info["path"] == path
|
||||||
if ckpt_mm.save_ckpt_folder.startswith("local"):
|
if latest_ckpt_step is None:
|
||||||
assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
|
assert latest_ckpt_step == step
|
||||||
else:
|
else:
|
||||||
assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
|
assert latest_ckpt_step == step - 1
|
||||||
|
|
||||||
|
# resume from before save skpt
|
||||||
del ckpt_mm
|
del ckpt_mm
|
||||||
SingletonMeta._instances = {}
|
SingletonMeta._instances = {}
|
||||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||||
ckpt_mm.try_resume_training(train_state)
|
ckpt_mm.try_resume_training(train_state)
|
||||||
assert latest_ckpt_step == 5
|
|
||||||
assert train_state.step_count == 6
|
|
||||||
assert train_state.batch_count == 6
|
|
||||||
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
|
||||||
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
|
|
||||||
|
|
||||||
if ckpt_mm.save_ckpt_folder.startswith("local:"):
|
if ckpt_config.checkpoint_every < total_step:
|
||||||
ckpt_mm.load_ckpt_info = dict(
|
# we use step_count to decide when save ckpt, os here latest_ckpt_step = step_count - 1
|
||||||
path=os.path.join(LOCAL_SAVE_PATH, "4"),
|
assert train_state.step_count == latest_ckpt_step + 1
|
||||||
content=CheckpointLoadMask(("all",)),
|
assert train_state.batch_count == latest_ckpt_step + 1
|
||||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
||||||
)
|
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
|
||||||
|
|
||||||
|
if ckpt_mm.save_ckpt_folder.startswith("local:"):
|
||||||
|
ckpt_mm.load_ckpt_info = dict(
|
||||||
|
path=os.path.join(LOCAL_SAVE_PATH, f"{ckpt_config.checkpoint_every}"),
|
||||||
|
content=CheckpointLoadMask(("all",)),
|
||||||
|
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ckpt_mm.load_ckpt_info = dict(
|
||||||
|
path=os.path.join(BOTO_SAVE_PATH, f"{ckpt_config.checkpoint_every}"),
|
||||||
|
content=CheckpointLoadMask(("all",)),
|
||||||
|
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_mm.try_resume_training(train_state)
|
||||||
|
|
||||||
|
assert train_state.step_count == ckpt_config.checkpoint_every
|
||||||
|
assert train_state.batch_count == ckpt_config.checkpoint_every
|
||||||
|
# compare value is same with i.
|
||||||
|
assert compare_optim_value(ckpt_mm.optimizer, ckpt_config.checkpoint_every - 1), ckpt_mm.optimizer.param_groups[
|
||||||
|
0
|
||||||
|
]["params"][0]
|
||||||
|
assert compare_model_value(ckpt_mm.model, ckpt_config.checkpoint_every - 1), list(ckpt_mm.model.parameters())[
|
||||||
|
0
|
||||||
|
][0]
|
||||||
else:
|
else:
|
||||||
ckpt_mm.load_ckpt_info = dict(
|
pass
|
||||||
path=os.path.join(BOTO_SAVE_PATH, "4"),
|
|
||||||
content=CheckpointLoadMask(("all",)),
|
|
||||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
STOP_FILE_PATH = "./alter.log"
|
||||||
|
|
||||||
|
|
||||||
|
def query_quit_file(rank, world_size=2):
|
||||||
|
from internlm.core.context import global_context as gpc
|
||||||
|
from internlm.initialize import initialize_distributed_env
|
||||||
|
from internlm.utils.model_checkpoint import CheckpointSaveType
|
||||||
|
|
||||||
|
ckpt_config = Config(
|
||||||
|
dict(
|
||||||
|
enable_save_ckpt=True,
|
||||||
|
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||||
|
load_optimizer=True,
|
||||||
|
checkpoint_every=0,
|
||||||
|
async_upload=True,
|
||||||
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
|
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||||
|
oss_snapshot_freq=0,
|
||||||
|
stop_file_path=STOP_FILE_PATH,
|
||||||
|
load_model_only_folder=None,
|
||||||
|
load_given_ckpt=False,
|
||||||
|
load_ckpt_folder=None,
|
||||||
|
is_old_api=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["RANK"] = str(rank)
|
||||||
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
|
os.environ["MASTER_PORT"] = "12376"
|
||||||
|
|
||||||
|
initialize_distributed_env(config=init_config, launcher="torch", master_port=12376, args_check=False)
|
||||||
|
train_state = TrainState(init_config, None)
|
||||||
|
ckpt_mm = CheckpointManager(ckpt_config, model=None, optimizer=None)
|
||||||
|
if rank == 0:
|
||||||
|
with open(STOP_FILE_PATH, "w+") as f:
|
||||||
|
f.write("5")
|
||||||
|
dist.barrier()
|
||||||
|
for i in range(10):
|
||||||
|
train_state.step_count = i
|
||||||
|
now_break, now_save_ckpt, save_type = ckpt_mm.quit_signal_handler(train_state)
|
||||||
|
print(
|
||||||
|
f"step:{i}, rank:{rank}, now_break:{now_break}, now_save_ckpt:{now_save_ckpt}, save_type:{save_type}",
|
||||||
|
flush=True,
|
||||||
)
|
)
|
||||||
|
if train_state.step_count == 5:
|
||||||
ckpt_mm.try_resume_training(train_state)
|
assert now_break is True
|
||||||
|
assert now_save_ckpt is True
|
||||||
assert train_state.step_count == 4
|
assert save_type is CheckpointSaveType.NORMAL_CHECKPOINT
|
||||||
assert train_state.batch_count == 4
|
dist.barrier()
|
||||||
assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
gpc.destroy()
|
||||||
assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("del_tmp")
|
def test_quit_siganl_handler(): # noqa # pylint: disable=unused-import
|
||||||
@pytest.mark.usefixtures("reset_singletons")
|
import multiprocessing
|
||||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
from multiprocessing.pool import Pool
|
||||||
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
|
||||||
ckpt_config = Config(ckpt_config)
|
|
||||||
|
|
||||||
model, opim = init_dist_and_model
|
world_size = 2
|
||||||
SingletonMeta._instances = {}
|
with Pool(processes=world_size, context=multiprocessing.get_context("spawn")) as pool:
|
||||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
items = [(0,), (1,)]
|
||||||
ckpt_mm.try_ping_storage()
|
for result in pool.starmap(query_quit_file, items):
|
||||||
|
print(f"Got result: {result}", flush=True)
|
||||||
|
|
||||||
|
os.remove(STOP_FILE_PATH)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue