diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index b8f7ad6..b6aab02 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -2,7 +2,6 @@
 # -*- encoding: utf-8 -*-
 
 import copy
-import fcntl
 import inspect
 import os
 import socket
@@ -545,12 +544,17 @@ class CheckpointManager:
         if self.stop_file_path is None:
             return now_break, now_save_ckpt, save_type
 
-        with open(self.stop_file_path, "a+", encoding="utf-8") as f:
-            fcntl.flock(f, fcntl.LOCK_EX)
-            f.seek(0)
-            msg = f.read()
-            fcntl.flock(f, fcntl.LOCK_UN)
-            action_step = int(msg)
+        with torch.no_grad():
+            action_step_t = torch.zeros((1,), dtype=torch.int64).cuda()
+            if gpc.get_global_rank() == 0:
+                with open(self.stop_file_path, "r+", encoding="utf-8") as f:
+                    f.seek(0)
+                    msg = f.read()
+                    action_step_t.fill_(int(msg))
+
+            torch.distributed.broadcast(action_step_t, src=0)
+            action_step = action_step_t.item()
+            del action_step_t
 
         if action_step < 0 and abs(action_step) == train_state.step_count:
             now_save_ckpt = True
@@ -627,41 +631,50 @@ now step_count is {train_state.step_count}",
             return None, None
 
         max_normal_step = 0
-        ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
-        ckpt_list.sort(reverse=True)
-        for ckpt in ckpt_list:
-            fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
-            for fn in fns_list:
-                if fn.endswith(".step"):
-                    max_normal_step = ckpt
+        # The returned ckpt_list looks like: ['pings', 'snapshot', '4'].
+        # Here we only consider ckpt folders named after a step, ignoring snapshot and other folders.
+        ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()]
+        if len(ckpt_list) == 0:
+            logger.warning("No available normal checkpoint found!")
+        else:
+            logger.info(f"Found available normal checkpoints: {ckpt_list}!")
+            ckpt_list.sort(reverse=True)
+            for ckpt in ckpt_list:
+                fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
+                for fn in fns_list:
+                    if fn.endswith(".step"):
+                        max_normal_step = ckpt
+                        break
+                if max_normal_step != 0:
                     break
-            if max_normal_step != 0:
-                break
 
-        max_normal_step = ckpt_list[0]
-        load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
+            max_normal_step = ckpt_list[0]
+            load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
 
         snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
         snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
-        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
-        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
-        max_step_0, max_step_1 = 0, 0
-        if ckpt_list_1:
-            for ckpt in ckpt_list_1:
-                ckpt = ckpt.strip("/")
-                if ckpt.endswith(".step"):
-                    max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
-        if ckpt_list_2:
-            for ckpt in ckpt_list_2:
-                ckpt = ckpt.strip("/")
-                if ckpt.endswith(".step"):
-                    max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
+        ckpt_list_0 = self.storage_manager.get_fns(snapshot_path_0)
+        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_1)
 
-        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
-        snap_step = max(max_step_0, max_step_1)
-        load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
-        load_step = max(snap_step, max_normal_step)
-        return load_path, load_step
+        def found_latest_snapshot(_ckpt_list):
+            _max_step_snapshot = 0
+            if _ckpt_list:
+                for ckpt in _ckpt_list:
+                    ckpt = ckpt.strip("/")
+                    if ckpt.endswith(".step"):
+                        _max_step_snapshot = max(_max_step_snapshot, int(ckpt.split(".")[0]))
+            return _max_step_snapshot
+
+        max_step_0 = found_latest_snapshot(ckpt_list_0)
+        max_step_1 = found_latest_snapshot(ckpt_list_1)
+
+        if sum([max_step_0, max_step_1, max_normal_step]) == 0:
+            return None, None
+        else:
+            snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
+            snap_step = max(max_step_0, max_step_1)
+            load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
+            return load_path, max(snap_step, max_normal_step)
 
     def query_latest_snapshot_step_local(self):
         max_step, max_step_path = 0, None
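
The new quit_signal_handler reads the stop file on global rank 0 only and distributes the parsed step with a collective, so no fcntl lock is needed and ranks on different nodes still agree on the value. A minimal sketch of the same rank-0-reads-then-broadcast pattern, assuming an already-initialized process group with a CUDA backend and a hypothetical stop-file path:

import torch
import torch.distributed as dist

def read_stop_step(stop_file: str) -> int:
    # Rank 0 is the only reader; every other rank receives the value via
    # broadcast, so no fcntl-style file lock is required.
    step_t = torch.zeros((1,), dtype=torch.int64).cuda()
    if dist.get_rank() == 0:
        with open(stop_file, "r", encoding="utf-8") as f:
            step_t.fill_(int(f.read()))
    dist.broadcast(step_t, src=0)  # all ranks now hold rank 0's value
    return int(step_t.item())
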
+ if ckpt.endswith(".step"): + _max_step_snapshot = max(_max_step_snapshot, int(ckpt.split(".")[0])) + return _max_step_snapshot + + max_step_0 = found_latest_snapshot(ckpt_list_0) + max_step_1 = found_latest_snapshot(ckpt_list_1) + + if sum([max_step_0, max_step_1, max_normal_step]) == 0: + return None, None + else: + snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 + snap_step = max(max_step_0, max_step_1) + load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path + return load_path, max(snap_step, max_normal_step) def query_latest_snapshot_step_local(self): max_step, max_step_path = 0, None diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index d6a19b6..80cb353 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -50,6 +50,8 @@ init_config = Config( ), resume_tb_folder="", tensorboard_folder="", + alert_address=None, + monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)), ) ) @@ -177,5 +179,5 @@ def del_tmp_file(): results += str(line.rstrip()) presults += line.rstrip().decode() + "\n" print(presults, flush=True) - except FileNotFoundError: + except: # noqa # pylint: disable=bare-except pass diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index bd93436..956880b 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -1,9 +1,10 @@ import os +from functools import partial import pytest import torch +import torch.distributed as dist -from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer @@ -15,27 +16,24 @@ from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-i BOTO_SAVE_PATH, LOCAL_SAVE_PATH, del_tmp_file, + init_config, init_dist_and_model, reset_singletons, ) -TOTAL_STEP = 6 - -CKPT_EVERY = 4 -SNPASHOT_EVERY = 2 - - +# (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY) +step_info_list = [(8, 4, 2), (3, 4, 2), (1, 6, 3)] ckpt_config_list = [ # Old interface format dict( enable_save_ckpt=True, save_ckpt_folder=BOTO_SAVE_PATH, load_optimizer=True, - checkpoint_every=CKPT_EVERY, + checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), - oss_snapshot_freq=SNPASHOT_EVERY, + oss_snapshot_freq=0, stop_file_path=None, load_model_only_folder=None, load_given_ckpt=False, @@ -47,11 +45,11 @@ ckpt_config_list = [ enable_save_ckpt=True, save_ckpt_folder=LOCAL_SAVE_PATH, load_optimizer=True, - checkpoint_every=CKPT_EVERY, + checkpoint_every=0, async_upload=False, async_upload_tmp_folder=ASYNC_TMP_FOLDER, snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), - oss_snapshot_freq=SNPASHOT_EVERY, + oss_snapshot_freq=0, stop_file_path=None, load_model_only_folder=None, load_given_ckpt=False, @@ -62,10 +60,10 @@ ckpt_config_list = [ dict( enable_save_ckpt=True, save_ckpt_folder=BOTO_SAVE_PATH, - checkpoint_every=CKPT_EVERY, + checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, - oss_snapshot_freq=SNPASHOT_EVERY, + oss_snapshot_freq=0, stop_file_path=None, is_old_api=False, auto_resume=True, @@ -73,10 +71,10 @@ ckpt_config_list = [ dict( enable_save_ckpt=True, save_ckpt_folder=LOCAL_SAVE_PATH, - 
-        checkpoint_every=CKPT_EVERY,
+        checkpoint_every=0,
         async_upload=False,
         async_upload_tmp_folder=ASYNC_TMP_FOLDER,
-        oss_snapshot_freq=SNPASHOT_EVERY,
+        oss_snapshot_freq=0,
         stop_file_path=None,
         load_ckpt_folder=None,
         is_old_api=False,
@@ -159,15 +157,63 @@ def del_tmp():
     del_tmp_file()
 
 
+def return_prefix_path(save_ckpt_folder):
+    if save_ckpt_folder.startswith("local:"):
+        return LOCAL_SAVE_PATH
+    else:
+        return BOTO_SAVE_PATH
+
+
+def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_freq):
+
+    snapshot_latest_step, normal_latest_step = 0, 0
+    snapshot_latest_count, normal_latest_count = 0, 0
+
+    for i in range(total_step):
+        if (i + 1) % ckpt_freq == 0:
+            normal_latest_step = i + 1
+            normal_latest_count += 1
+        else:
+            if (i + 1) % snapshot_freq == 0:
+                snapshot_latest_step = i + 1
+                snapshot_latest_count += 1
+
+    if snapshot_latest_step == 0:
+        return None, None
+
+    if normal_latest_step >= snapshot_latest_step:
+        return normal_latest_step, os.path.join(return_prefix_path(save_ckpt_folder), f"{normal_latest_step}")
+    elif normal_latest_step < snapshot_latest_step:
+        if snapshot_latest_count % 2 == 0:
+            re_path = f"{return_prefix_path(save_ckpt_folder)}/snapshot/0"
+        else:
+            re_path = f"{return_prefix_path(save_ckpt_folder)}/snapshot/1"
+        return snapshot_latest_step, re_path
+    else:
+        assert False
+
+
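
To make the helper above concrete, take the parametrization (total_step, ckpt_freq, snapshot_freq) = (8, 4, 2): normal checkpoints land at steps 4 and 8, snapshots at steps 2 and 6 (a step that already gets a normal checkpoint is not snapshotted), so the newest artifact is the normal folder "8". With (3, 4, 2) only the snapshot at step 2 exists, and since exactly one snapshot has been taken (an odd count) the round-robin slot is snapshot/1. A hypothetical sanity check of those cases, not part of the test suite:

# Assumes the return_latest_save_path helper and the BOTO_SAVE_PATH constant above.
assert return_latest_save_path(BOTO_SAVE_PATH, 8, 2, 4) == (8, os.path.join(BOTO_SAVE_PATH, "8"))
assert return_latest_save_path(BOTO_SAVE_PATH, 3, 2, 4) == (2, f"{BOTO_SAVE_PATH}/snapshot/1")
assert return_latest_save_path(BOTO_SAVE_PATH, 1, 3, 6) == (None, None)  # nothing saved yet
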
 @pytest.mark.usefixtures("del_tmp")
 @pytest.mark.usefixtures("reset_singletons")
+@pytest.mark.parametrize("step_info", step_info_list)
 @pytest.mark.parametrize("ckpt_config", ckpt_config_list)
-def test_ckpt_mm(ckpt_config, init_dist_and_model):  # noqa  # pylint: disable=unused-import
+def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model):  # noqa  # pylint: disable=unused-import
+    from internlm.core.context import global_context as gpc
     from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
 
     ckpt_config = Config(ckpt_config)
-    assert ckpt_config.checkpoint_every < TOTAL_STEP
-    assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
+    total_step, checkpoint_every, oss_snapshot_freq = step_info
+    print(total_step, checkpoint_every, oss_snapshot_freq, flush=True)
+    ckpt_config.checkpoint_every = checkpoint_every
+    ckpt_config.oss_snapshot_freq = oss_snapshot_freq
+
+    bond_return_latest_save_path = partial(
+        return_latest_save_path,
+        ckpt_config.save_ckpt_folder,
+        total_step,
+        ckpt_config.oss_snapshot_freq,
+        ckpt_config.checkpoint_every,
+    )
 
     model, opim = init_dist_and_model
     train_state = TrainState(gpc.config, None)
@@ -178,7 +224,7 @@ def test_ckpt_mm(ckpt_config, init_dist_and_model):  # noqa  # pylint: disable=un
     ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
     latest_ckpt_step = None
-    for i in range(TOTAL_STEP + 1):
+    for i in range(total_step):
         overwrite_model_value(model, i)
         overwrite_optim_state(opim, i)
 
@@ -193,54 +239,119 @@ def test_ckpt_mm(ckpt_config, init_dist_and_model):  # noqa  # pylint: disable=un
         wait_async_upload_finish()
         latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
-        assert latest_ckpt_info is not None
-        latest_ckpt = latest_ckpt_info["path"]
-        if ckpt_mm.save_ckpt_folder.startswith("local"):
-            assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
+        step, path = bond_return_latest_save_path()
+        assert latest_ckpt_info["path"] == path
+        if latest_ckpt_step is None:
+            assert latest_ckpt_step == step
         else:
-            assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
+            assert latest_ckpt_step == step - 1
 
+    # resume from the previously saved ckpt
     del ckpt_mm
     SingletonMeta._instances = {}
     ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
     ckpt_mm.try_resume_training(train_state)
-    assert latest_ckpt_step == 5
-    assert train_state.step_count == 6
-    assert train_state.batch_count == 6
-    assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
-    assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
-    if ckpt_mm.save_ckpt_folder.startswith("local:"):
-        ckpt_mm.load_ckpt_info = dict(
-            path=os.path.join(LOCAL_SAVE_PATH, "4"),
-            content=CheckpointLoadMask(("all",)),
-            ckpt_type=CheckpointLoadType.INTERNLM,
-        )
+    if ckpt_config.checkpoint_every < total_step:
+        # we use step_count to decide when to save a ckpt, so here latest_ckpt_step == step_count - 1
+        assert train_state.step_count == latest_ckpt_step + 1
+        assert train_state.batch_count == latest_ckpt_step + 1
+        assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
+        assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
+
+        if ckpt_mm.save_ckpt_folder.startswith("local:"):
+            ckpt_mm.load_ckpt_info = dict(
+                path=os.path.join(LOCAL_SAVE_PATH, f"{ckpt_config.checkpoint_every}"),
+                content=CheckpointLoadMask(("all",)),
+                ckpt_type=CheckpointLoadType.INTERNLM,
+            )
+        else:
+            ckpt_mm.load_ckpt_info = dict(
+                path=os.path.join(BOTO_SAVE_PATH, f"{ckpt_config.checkpoint_every}"),
+                content=CheckpointLoadMask(("all",)),
+                ckpt_type=CheckpointLoadType.INTERNLM,
+            )
+
+        ckpt_mm.try_resume_training(train_state)
+
+        assert train_state.step_count == ckpt_config.checkpoint_every
+        assert train_state.batch_count == ckpt_config.checkpoint_every
+        # the restored values should match the step index i at save time
+        assert compare_optim_value(ckpt_mm.optimizer, ckpt_config.checkpoint_every - 1), ckpt_mm.optimizer.param_groups[
+            0
+        ]["params"][0]
+        assert compare_model_value(ckpt_mm.model, ckpt_config.checkpoint_every - 1), list(ckpt_mm.model.parameters())[
+            0
+        ][0]
     else:
-        ckpt_mm.load_ckpt_info = dict(
-            path=os.path.join(BOTO_SAVE_PATH, "4"),
-            content=CheckpointLoadMask(("all",)),
-            ckpt_type=CheckpointLoadType.INTERNLM,
+        pass
+
+
+STOP_FILE_PATH = "./alter.log"
+
+
+def query_quit_file(rank, world_size=2):
+    from internlm.core.context import global_context as gpc
+    from internlm.initialize import initialize_distributed_env
+    from internlm.utils.model_checkpoint import CheckpointSaveType
+
+    ckpt_config = Config(
+        dict(
+            enable_save_ckpt=True,
+            save_ckpt_folder=BOTO_SAVE_PATH,
+            load_optimizer=True,
+            checkpoint_every=0,
+            async_upload=True,
+            async_upload_tmp_folder=ASYNC_TMP_FOLDER,
+            snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
+            oss_snapshot_freq=0,
+            stop_file_path=STOP_FILE_PATH,
+            load_model_only_folder=None,
+            load_given_ckpt=False,
+            load_ckpt_folder=None,
+            is_old_api=True,
+        ),
+    )
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "12376"
+
+    initialize_distributed_env(config=init_config, launcher="torch", master_port=12376, args_check=False)
+    train_state = TrainState(init_config, None)
+    ckpt_mm = CheckpointManager(ckpt_config, model=None, optimizer=None)
+    if rank == 0:
+        with open(STOP_FILE_PATH, "w+") as f:
+            f.write("5")
+    dist.barrier()
+    for i in range(10):
+        train_state.step_count = i
+        now_break, now_save_ckpt, save_type = ckpt_mm.quit_signal_handler(train_state)
+        print(
+            f"step:{i}, rank:{rank}, now_break:{now_break}, now_save_ckpt:{now_save_ckpt}, save_type:{save_type}",
+            flush=True,
         )
-
-    ckpt_mm.try_resume_training(train_state)
-
-    assert train_state.step_count == 4
-    assert train_state.batch_count == 4
-    assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
-    assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
+        if train_state.step_count == 5:
+            assert now_break is True
+            assert now_save_ckpt is True
+            assert save_type is CheckpointSaveType.NORMAL_CHECKPOINT
+    dist.barrier()
+    gpc.destroy()
 
 
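query_quit_file drives the stop-file protocol end to end: rank 0 writes a target step into STOP_FILE_PATH, every rank's quit_signal_handler then picks the value up via the broadcast shown earlier, and at step 5 both ranks must report now_break and now_save_ckpt with a NORMAL_CHECKPOINT save. Judging from the quit_signal_handler hunk above, a negative value whose absolute value matches step_count instead requests a checkpoint without quitting. A hypothetical operator-side sketch of the protocol:

# Ask a running job to checkpoint and stop once train_state.step_count reaches 100.
# Writing "-100" instead would request the checkpoint without stopping,
# per the action_step < 0 branch in quit_signal_handler.
with open("./alter.log", "w", encoding="utf-8") as f:
    f.write("100")
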
-@pytest.mark.usefixtures("del_tmp")
-@pytest.mark.usefixtures("reset_singletons")
-@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
-def test_ckpt_mm_ping(ckpt_config, init_dist_and_model):  # noqa  # pylint: disable=unused-import
-    ckpt_config = Config(ckpt_config)
+def test_quit_signal_handler():  # noqa  # pylint: disable=unused-import
+    import multiprocessing
+    from multiprocessing.pool import Pool
 
-    model, opim = init_dist_and_model
-    SingletonMeta._instances = {}
-    ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
-    ckpt_mm.try_ping_storage()
+    world_size = 2
+    with Pool(processes=world_size, context=multiprocessing.get_context("spawn")) as pool:
+        items = [(0,), (1,)]
+        for result in pool.starmap(query_quit_file, items):
+            print(f"Got result: {result}", flush=True)
+
+    os.remove(STOP_FILE_PATH)
 
 
 if __name__ == "__main__":