mirror of https://github.com/InternLM/InternLM
fix(ci): fix test model ckpt ci test (#518)
parent
b79d5ea7ae
commit
b3be333aa2
|
@ -8,14 +8,25 @@ import torch
|
||||||
from internlm.core.context import global_context as gpc
|
from internlm.core.context import global_context as gpc
|
||||||
from internlm.core.context.parallel_context import Config
|
from internlm.core.context.parallel_context import Config
|
||||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||||
|
from internlm.train.utils import create_param_groups
|
||||||
from internlm.utils.common import SingletonMeta
|
from internlm.utils.common import SingletonMeta
|
||||||
|
|
||||||
OSS_NAME = os.environ.get("OSS_BUCKET_NAME")
|
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
|
||||||
OSS_IP = os.environ.get("OSS_IP")
|
OSS_IP = os.environ.get("OSS_IP", None)
|
||||||
USER = os.environ.get("USER")
|
USER = os.environ.get("USER", None)
|
||||||
JOB_NAME = "CI_TEST"
|
JOB_NAME = "CI_TEST"
|
||||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||||
|
|
||||||
|
if OSS_NAME is None or OSS_IP is None:
|
||||||
|
BOTO_SAVE_PATH = None
|
||||||
|
BOTO_SAVE_PATH_NO_PRFIX = None
|
||||||
|
|
||||||
|
VOLC_SAVE_PATH = None
|
||||||
|
VOLC_SAVE_PATH_NO_PRFIX = None
|
||||||
|
|
||||||
|
ALI_SAVE_PATH = None
|
||||||
|
ALI_SAVE_PATH_NO_PRFIX = None
|
||||||
|
else:
|
||||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||||
|
|
||||||
|
@ -31,7 +42,12 @@ ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||||
# 1B
|
# 1B
|
||||||
init_config = Config(
|
init_config = Config(
|
||||||
dict(
|
dict(
|
||||||
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
|
parallel=dict(
|
||||||
|
zero1=dict(size=1, fsdp=False),
|
||||||
|
pipeline=dict(size=1, interleaved_overlap=False),
|
||||||
|
sequence_parallel=False,
|
||||||
|
tensor=1,
|
||||||
|
),
|
||||||
model_type="INTERNLM",
|
model_type="INTERNLM",
|
||||||
adam=dict(
|
adam=dict(
|
||||||
lr=1e-4,
|
lr=1e-4,
|
||||||
|
@ -90,8 +106,9 @@ def init_naive_optim(model):
|
||||||
|
|
||||||
|
|
||||||
def init_hybrid_optim(model):
|
def init_hybrid_optim(model):
|
||||||
|
params = create_param_groups(model, 0.01)
|
||||||
naive_optimizer = torch.optim.AdamW(
|
naive_optimizer = torch.optim.AdamW(
|
||||||
params=[{"params": model.parameters(), "weight_decay": 0.01}],
|
params=params,
|
||||||
lr=1e-4,
|
lr=1e-4,
|
||||||
betas=(0.9, 0.95),
|
betas=(0.9, 0.95),
|
||||||
eps=1e-8,
|
eps=1e-8,
|
||||||
|
|
|
@ -32,7 +32,7 @@ ckpt_config_list = [
|
||||||
checkpoint_every=0,
|
checkpoint_every=0,
|
||||||
async_upload=True,
|
async_upload=True,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None,
|
||||||
oss_snapshot_freq=0,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=None,
|
stop_file_path=None,
|
||||||
load_model_only_folder=None,
|
load_model_only_folder=None,
|
||||||
|
@ -207,6 +207,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
|
||||||
ckpt_config.checkpoint_every = checkpoint_every
|
ckpt_config.checkpoint_every = checkpoint_every
|
||||||
ckpt_config.oss_snapshot_freq = oss_snapshot_freq
|
ckpt_config.oss_snapshot_freq = oss_snapshot_freq
|
||||||
|
|
||||||
|
if ckpt_config.save_ckpt_folder is None:
|
||||||
|
return
|
||||||
|
|
||||||
bond_return_latest_save_path = partial(
|
bond_return_latest_save_path = partial(
|
||||||
return_latest_save_path,
|
return_latest_save_path,
|
||||||
ckpt_config.save_ckpt_folder,
|
ckpt_config.save_ckpt_folder,
|
||||||
|
@ -298,12 +301,12 @@ def query_quit_file(rank, world_size=2):
|
||||||
ckpt_config = Config(
|
ckpt_config = Config(
|
||||||
dict(
|
dict(
|
||||||
enable_save_ckpt=True,
|
enable_save_ckpt=True,
|
||||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||||
load_optimizer=True,
|
load_optimizer=True,
|
||||||
checkpoint_every=0,
|
checkpoint_every=0,
|
||||||
async_upload=True,
|
async_upload=True,
|
||||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
||||||
oss_snapshot_freq=0,
|
oss_snapshot_freq=0,
|
||||||
stop_file_path=STOP_FILE_PATH,
|
stop_file_path=STOP_FILE_PATH,
|
||||||
load_model_only_folder=None,
|
load_model_only_folder=None,
|
||||||
|
|
Loading…
Reference in New Issue