mirror of https://github.com/InternLM/InternLM
fix(ci): fix test model ckpt ci test (#518)
parent
b79d5ea7ae
commit
b3be333aa2
|
@ -8,14 +8,25 @@ import torch
|
|||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.train.utils import create_param_groups
|
||||
from internlm.utils.common import SingletonMeta
|
||||
|
||||
OSS_NAME = os.environ.get("OSS_BUCKET_NAME")
|
||||
OSS_IP = os.environ.get("OSS_IP")
|
||||
USER = os.environ.get("USER")
|
||||
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
|
||||
OSS_IP = os.environ.get("OSS_IP", None)
|
||||
USER = os.environ.get("USER", None)
|
||||
JOB_NAME = "CI_TEST"
|
||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||
|
||||
if OSS_NAME is None or OSS_IP is None:
|
||||
BOTO_SAVE_PATH = None
|
||||
BOTO_SAVE_PATH_NO_PRFIX = None
|
||||
|
||||
VOLC_SAVE_PATH = None
|
||||
VOLC_SAVE_PATH_NO_PRFIX = None
|
||||
|
||||
ALI_SAVE_PATH = None
|
||||
ALI_SAVE_PATH_NO_PRFIX = None
|
||||
else:
|
||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||
|
||||
|
@ -31,7 +42,12 @@ ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
|||
# 1B
|
||||
init_config = Config(
|
||||
dict(
|
||||
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
|
||||
parallel=dict(
|
||||
zero1=dict(size=1, fsdp=False),
|
||||
pipeline=dict(size=1, interleaved_overlap=False),
|
||||
sequence_parallel=False,
|
||||
tensor=1,
|
||||
),
|
||||
model_type="INTERNLM",
|
||||
adam=dict(
|
||||
lr=1e-4,
|
||||
|
@ -90,8 +106,9 @@ def init_naive_optim(model):
|
|||
|
||||
|
||||
def init_hybrid_optim(model):
|
||||
params = create_param_groups(model, 0.01)
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": 0.01}],
|
||||
params=params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.95),
|
||||
eps=1e-8,
|
||||
|
|
|
@ -32,7 +32,7 @@ ckpt_config_list = [
|
|||
checkpoint_every=0,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None,
|
||||
oss_snapshot_freq=0,
|
||||
stop_file_path=None,
|
||||
load_model_only_folder=None,
|
||||
|
@ -207,6 +207,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
|
|||
ckpt_config.checkpoint_every = checkpoint_every
|
||||
ckpt_config.oss_snapshot_freq = oss_snapshot_freq
|
||||
|
||||
if ckpt_config.save_ckpt_folder is None:
|
||||
return
|
||||
|
||||
bond_return_latest_save_path = partial(
|
||||
return_latest_save_path,
|
||||
ckpt_config.save_ckpt_folder,
|
||||
|
@ -298,12 +301,12 @@ def query_quit_file(rank, world_size=2):
|
|||
ckpt_config = Config(
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||
load_optimizer=True,
|
||||
checkpoint_every=0,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
||||
oss_snapshot_freq=0,
|
||||
stop_file_path=STOP_FILE_PATH,
|
||||
load_model_only_folder=None,
|
||||
|
|
Loading…
Reference in New Issue