fix(ci): fix the model checkpoint (ckpt) CI test (#518)

pull/520/head^2
Guoteng 2023-11-30 19:16:35 +08:00 committed by GitHub
parent b79d5ea7ae
commit b3be333aa2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 14 deletions

View File

@ -8,14 +8,25 @@ import torch
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.train.utils import create_param_groups
from internlm.utils.common import SingletonMeta from internlm.utils.common import SingletonMeta
OSS_NAME = os.environ.get("OSS_BUCKET_NAME") OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
OSS_IP = os.environ.get("OSS_IP") OSS_IP = os.environ.get("OSS_IP", None)
USER = os.environ.get("USER") USER = os.environ.get("USER", None)
JOB_NAME = "CI_TEST" JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt" LOCAL_SAVE_PATH = "local:local_ckpt"
if OSS_NAME is None or OSS_IP is None:
BOTO_SAVE_PATH = None
BOTO_SAVE_PATH_NO_PRFIX = None
VOLC_SAVE_PATH = None
VOLC_SAVE_PATH_NO_PRFIX = None
ALI_SAVE_PATH = None
ALI_SAVE_PATH_NO_PRFIX = None
else:
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
@ -31,7 +42,12 @@ ASYNC_TMP_FOLDER = "./async_tmp_folder"
# 1B # 1B
init_config = Config( init_config = Config(
dict( dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1), parallel=dict(
zero1=dict(size=1, fsdp=False),
pipeline=dict(size=1, interleaved_overlap=False),
sequence_parallel=False,
tensor=1,
),
model_type="INTERNLM", model_type="INTERNLM",
adam=dict( adam=dict(
lr=1e-4, lr=1e-4,
@ -90,8 +106,9 @@ def init_naive_optim(model):
def init_hybrid_optim(model): def init_hybrid_optim(model):
params = create_param_groups(model, 0.01)
naive_optimizer = torch.optim.AdamW( naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": 0.01}], params=params,
lr=1e-4, lr=1e-4,
betas=(0.9, 0.95), betas=(0.9, 0.95),
eps=1e-8, eps=1e-8,

View File

@ -32,7 +32,7 @@ ckpt_config_list = [
checkpoint_every=0, checkpoint_every=0,
async_upload=True, async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER, async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None,
oss_snapshot_freq=0, oss_snapshot_freq=0,
stop_file_path=None, stop_file_path=None,
load_model_only_folder=None, load_model_only_folder=None,
@ -207,6 +207,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
ckpt_config.checkpoint_every = checkpoint_every ckpt_config.checkpoint_every = checkpoint_every
ckpt_config.oss_snapshot_freq = oss_snapshot_freq ckpt_config.oss_snapshot_freq = oss_snapshot_freq
if ckpt_config.save_ckpt_folder is None:
return
bond_return_latest_save_path = partial( bond_return_latest_save_path = partial(
return_latest_save_path, return_latest_save_path,
ckpt_config.save_ckpt_folder, ckpt_config.save_ckpt_folder,
@ -298,12 +301,12 @@ def query_quit_file(rank, world_size=2):
ckpt_config = Config( ckpt_config = Config(
dict( dict(
enable_save_ckpt=True, enable_save_ckpt=True,
save_ckpt_folder=BOTO_SAVE_PATH, save_ckpt_folder=LOCAL_SAVE_PATH,
load_optimizer=True, load_optimizer=True,
checkpoint_every=0, checkpoint_every=0,
async_upload=True, async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER, async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
oss_snapshot_freq=0, oss_snapshot_freq=0,
stop_file_path=STOP_FILE_PATH, stop_file_path=STOP_FILE_PATH,
load_model_only_folder=None, load_model_only_folder=None,