From b3be333aa2898749757eb4782b16c55b7cfb6400 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Thu, 30 Nov 2023 19:16:35 +0800 Subject: [PATCH] fix(ci): fix test model ckpt ci test (#518) --- tests/test_utils/common_fixture.py | 39 ++++++++++++++++------- tests/test_utils/test_model_checkpoint.py | 9 ++++-- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 56e7b21..d0f1455 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -8,22 +8,33 @@ import torch from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.train.utils import create_param_groups from internlm.utils.common import SingletonMeta -OSS_NAME = os.environ.get("OSS_BUCKET_NAME") -OSS_IP = os.environ.get("OSS_IP") -USER = os.environ.get("USER") +OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) +OSS_IP = os.environ.get("OSS_IP", None) +USER = os.environ.get("USER", None) JOB_NAME = "CI_TEST" LOCAL_SAVE_PATH = "local:local_ckpt" -BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" +if OSS_NAME is None or OSS_IP is None: + BOTO_SAVE_PATH = None + BOTO_SAVE_PATH_NO_PRFIX = None -VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + VOLC_SAVE_PATH = None + VOLC_SAVE_PATH_NO_PRFIX = None -ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + ALI_SAVE_PATH = None + ALI_SAVE_PATH_NO_PRFIX = None +else: + BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + + VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + + ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" ASYNC_TMP_FOLDER = "./async_tmp_folder" @@ -31,7 +42,12 @@ ASYNC_TMP_FOLDER = "./async_tmp_folder" # 1B init_config = Config( dict( - parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1), + parallel=dict( + zero1=dict(size=1, fsdp=False), + pipeline=dict(size=1, interleaved_overlap=False), + sequence_parallel=False, + tensor=1, + ), model_type="INTERNLM", adam=dict( lr=1e-4, @@ -90,8 +106,9 @@ def init_naive_optim(model): def init_hybrid_optim(model): + params = create_param_groups(model, 0.01) naive_optimizer = torch.optim.AdamW( - params=[{"params": model.parameters(), "weight_decay": 0.01}], + params=params, lr=1e-4, betas=(0.9, 0.95), eps=1e-8, diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 0804455..c50eec1 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -32,7 +32,7 @@ ckpt_config_list = [ checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None, oss_snapshot_freq=0, stop_file_path=None, load_model_only_folder=None, @@ -207,6 +207,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: ckpt_config.checkpoint_every = checkpoint_every ckpt_config.oss_snapshot_freq = oss_snapshot_freq + if ckpt_config.save_ckpt_folder is None: + return + bond_return_latest_save_path = partial( return_latest_save_path, ckpt_config.save_ckpt_folder, @@ -298,12 +301,12 @@ def query_quit_file(rank, world_size=2): ckpt_config = Config( dict( enable_save_ckpt=True, - save_ckpt_folder=BOTO_SAVE_PATH, + save_ckpt_folder=LOCAL_SAVE_PATH, load_optimizer=True, checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), oss_snapshot_freq=0, stop_file_path=STOP_FILE_PATH, load_model_only_folder=None,