mirror of https://github.com/InternLM/InternLM
feat(ckpt): fix checkpoint bugs and add feature enhancements. (#259)
* fix(ckpt): ckpt bug fix and api refactor
  1. fix latest ckpt query bug
  2. add ckpt unit test
  3. fix storage manager boto3/local client get_fns bug
  4. fix the bug where the zero fp32 buffer overwrites model weights in the model-only load case
  5. add ckpt_type and add zero reload ci-test
* fix(ckpt): fix ckpt and trainer bug
* fix and refactor
* fix based on comments
* feat: add legacy api

pull/275/head^2
parent 860de0aa46
commit f6e007f95b

@@ -22,13 +22,16 @@ CHECKPOINT_EVERY = 50
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
    load_optimizer=True,  # Whether to load optimizer states when continuing training.
    # load_ckpt_folder=dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
    load_ckpt_folder="local:llm_ckpts/",
    # 'load_ckpt_info' setting guide:
    # 1. the 'path' indicates the ckpt path,
    # 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only the 'normal' type is supported.
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage path.
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
)

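As a quick illustration of the setting guide above, here is a minimal sketch of the two common `load_ckpt_info` shapes; the folder paths are placeholders, not real checkpoints:

```python
# Initialize from pretrained weights only; sampler/optimizer/scheduler start fresh.
ckpt = dict(
    load_ckpt_info=dict(path="local:/path/to/init/ckpt", content=("model",), ckpt_type="internlm"),
)

# Fully resume an interrupted run: model, sampler, optimizer and scheduler states.
ckpt = dict(
    load_ckpt_info=dict(path="boto3:s3://bucket.endpoint/path/to/ckpt", content=("all",), ckpt_type="internlm"),
)
```
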
@@ -115,19 +115,19 @@ If you want to load a model checkpoint when starting the training, you can configure it as follows:

```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save the model and optimizer checkpoints
    checkpoint_every=float("inf"),  # Save a checkpoint every specified number of steps, default value is inf
    load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to load the initial model weights; only model weights are loaded, not optimizer weights, and training starts from the first step
    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Path to load the model and optimizer weights for resuming training; training resumes from the specified step
    load_optimizer=True,  # Whether to load optimizer weights when resuming training, default value is True
    # When resuming training from a checkpoint:
    # (1) 'path' is the path of the checkpoint to load.
    # (2) 'content' indicates which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
    # (3) 'ckpt_type' indicates which type of ckpt will be loaded, currently supported: "internlm"
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```

Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
- If the path starts with `local:`, the file is stored in the local file system. If it starts with `boto3:`, the file is stored in the remote OSS, as sketched below.
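
A minimal sketch of the two path prefixes; the bucket name and endpoint are placeholders:

```python
# Local file system: everything after "local:" is an ordinary directory path.
SAVE_CKPT_FOLDER = "local:/mnt/data/llm_ckpts"

# Remote OSS via boto3: everything after "boto3:" is an s3-style URL.
SAVE_CKPT_FOLDER = "boto3:s3://my-bucket.my-endpoint/my-user/llm_ckpts"
```
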
The configuration for the model is as follows:

@@ -103,18 +103,17 @@ data = dict(
If you want to load a model `checkpoint` when starting training, you can configure it as follows:

```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path for storing model and optimizer checkpoints
    checkpoint_every=float("inf"),  # Save a checkpoint every given number of steps, default value is inf
    load_model_only_folder=MODEL_ONLY_FOLDER,  # Path for loading initial model weights; only model weights are loaded, not optimizer weights, and training starts from the first step
    load_ckpt_folder=LOAD_CKPT_FOLDER,  # When resuming training, the path for loading model and optimizer weights; training resumes from the specified step
    load_optimizer=True,  # When resuming training, whether to load optimizer weights, default value is True
    # When resuming training, the path for loading model and optimizer weights; training resumes from the specified step
    # 'content' indicates which states will be loaded, supporting: "model", "sampler", "optimizer", "scheduler", "all"
    # 'ckpt_type' indicates the type of model to load, currently supporting: "internlm"
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```

Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
- If a path starts with `local:`, it is stored in the local file system; if it starts with `boto3:`, it is stored in the remote OSS.

The key model parameters are configured as follows:

@@ -23,7 +23,15 @@ class TrainState:
        train_dl (DataLoader): The DataLoader object used for training.
    """

    def __init__(self, config) -> None:
    def __init__(self, config, batch_sampler) -> None:
        """
        Args:
            config (Config): internlm config
            batch_sampler (torch.utils.data.Sampler): Because dataloader loading is
                asynchronous and prefetched, the batch_sampler state maintained inside the
                dataloader runs ahead of the actual training progress, so we copy the
                batch_sampler as the anchor point for ckpt reload.
        """
        # The number of batches produced by the data iterator
        self.batch_count: int = 0
        # Used to store the number of samples consumed in the current epoch

@@ -43,9 +51,20 @@ class TrainState:

        self.tensorboard_folder = config.tensorboard_folder

    def init_batch_sampler(self, train_dl):
        # Copy of the batch sampler from the DataLoader
        self.batch_sampler = train_dl.batch_sampler.copy()
        # learning rate
        self.lr = config.adam.lr

        # sampler state
        if batch_sampler:
            self.init_batch_sampler(batch_sampler)

    def init_batch_sampler(self, batch_sampler):
        """
        Args:
            batch_sampler (torch.utils.data.Sampler): sampler.
        """
        # make a copy of batch_sampler.
        self.batch_sampler = batch_sampler.copy()
        # Iterator for the batch sampler
        self.batch_sampler_iter = iter(self.batch_sampler)
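
The "anchor copy" idea in the docstring above can be sketched independently of InternLM; `ToyBatchSampler` is a hypothetical stand-in for the real batch sampler:

```python
import copy

class ToyBatchSampler:  # hypothetical stand-in for the training batch sampler
    def __init__(self):
        self.batch_count = 0  # advanced by the dataloader's prefetching

    def copy(self):
        return copy.deepcopy(self)

live = ToyBatchSampler()
live.batch_count = 12    # prefetching has already pulled 12 batches...

anchor = live.copy()     # ...so checkpointing keeps a detached copy
anchor.batch_count = 10  # pinned to the number of batches actually trained on
```
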
@@ -61,26 +80,22 @@ class TrainState:

        return json.dumps(info, indent=4, sort_keys=True)

    def load_state_dict(self, other_stuffs, train_dl):
    def load_state_dict(self, other_stuffs):
        """
        Resumes training from a checkpoint.

        Args:
            other_stuffs (dict): Other information needed to resume training.
            train_dl (DataLoader): The DataLoader object used for training.
        """

        self.batch_count = other_stuffs["batch_count"] + 1  # here we need to shift one batch backward
        self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
        self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
        self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
        # compatible with previous checkpoints without this parameter
        self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1

        # track the actual updates of the sampler when using weighted sampling
        if hasattr(self, "batch_sampler"):
            self.batch_sampler = train_dl.batch_sampler.copy()
            self.batch_sampler_iter = iter(self.batch_sampler)
        # Because the ckpt save occurs after 'step_count' is updated, there is no need to
        # increment 'step_count' here (does our step count start from 0?). However,
        # 'batch_count' is updated before ckpt storage, so it needs to be incremented by 1 on resume.
        self.batch_count = other_stuffs["batch_count"] + 1  # here we need to shift one batch backward
        self.step_count = other_stuffs.get("step_count", self.batch_count)

        # resume tensorboard from the older tensorboard_folder
        self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
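
A minimal sketch of the resume bookkeeping described in the comments above; `other_stuffs` mirrors the fields the method reads, with made-up values:

```python
# Fields written at save time, right after finishing batch 99 / step 100 (values made up):
other_stuffs = {"batch_count": 99, "step_count": 100, "num_consumed_tokens": 207_618_048}

batch_count = other_stuffs["batch_count"] + 1             # shift one batch backward -> 100
step_count = other_stuffs.get("step_count", batch_count)  # already points past the save -> 100
assert (batch_count, step_count) == (100, 100)
```
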
@@ -12,7 +12,6 @@ from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager

logger = get_logger(__file__)

@@ -111,7 +110,7 @@ def args_sanity_check():
    # processing the checkpoint config
    ckpt = gpc.config.ckpt
    if "enable_save_ckpt" not in ckpt:
        ckpt._add_item("enable_save_ckpt", False)
        ckpt._add_item("enable_save_ckpt", True)

    # Saving checkpoint args.
    if ckpt.enable_save_ckpt:

@@ -137,9 +136,6 @@ def args_sanity_check():
        if not ckpt.async_upload:
            ckpt._add_item("async_upload_tmp_folder", None)

        if "snapshot_ckpt_folder" not in ckpt:
            ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))

        if "oss_snapshot_freq" not in ckpt:
            ckpt._add_item("oss_snapshot_freq", float("inf"))  # if oss_snapshot_freq is not given, we disable it.
    else:

@@ -149,44 +145,23 @@ def args_sanity_check():
        ckpt._add_item("async_upload", False)
        ckpt._add_item("async_upload_tmp_folder", None)
        ckpt._add_item("snapshot_ckpt_folder", None)
        ckpt._add_item("snapshot_ckpt_folder", None)

    # Loading checkpoint args.
    if "load_model_only_folder" not in ckpt:
        ckpt._add_item("load_model_only_folder", None)

    if "load_ckpt_folder" not in ckpt:
        ckpt._add_item("load_ckpt_folder", None)

    if "load_optimizer" not in ckpt:
        ckpt._add_item("load_optimizer", True)

    if "stop_file_path" not in ckpt:
        ckpt._add_item("stop_file_path", None)

    if "load_given_ckpt" not in ckpt:
        # If 'load_given_ckpt' is not given, we set it to False, so internlm can have the opportunity
    if "auto_resume" not in ckpt:
        # If 'auto_resume' is not given, we set it to True, so internlm can have the opportunity
        # to auto-load the latest checkpoint.
        ckpt._add_item("load_given_ckpt", False)

    if ckpt.load_given_ckpt:
        # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
        if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
            logger.warning(
                "Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
            )
            ckpt.load_model_only_folder = None
        ckpt._add_item("auto_resume", True)

    if gpc.is_rank_for_log():
        logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
        logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
        logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
        logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
        logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")

    # initialize storage manager
    init_storage_manager(ckpt)

    # tensorboard writer config
    if "enable_tb" not in gpc.config:

@@ -459,3 +434,11 @@ def initialize_distributed_env(

    if args_check:
        args_sanity_check()


def get_config_value(config, key, defalut):
    try:
        value = config[key]
    except KeyError:
        value = defalut
    return value
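
A quick usage sketch of the helper above; the config dict is made up:

```python
ckpt_config = {"save_ckpt_folder": "local:llm_ckpts/"}

folder = get_config_value(ckpt_config, "save_ckpt_folder", None)  # -> "local:llm_ckpts/"
resume = get_config_value(ckpt_config, "auto_resume", True)       # missing key -> default True
```
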
@@ -0,0 +1,40 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from internlm.initialize.launch import get_config_value
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


def auto_resume_sanity_check(ckpt_config):
    load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
    if load_given_ckpt is None:
        return True  # default value is True
    else:
        return not load_given_ckpt


def ckpt_info_sanity_check(ckpt_config):
    load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)

    load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)

    if load_model_only_folder is not None:
        assert (
            load_ckpt_folder is None
        ), "Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time; \
internlm will load from 'load_ckpt_folder'"
        return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
    else:
        load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)

        if isinstance(load_ckpt_folder, str):
            if load_optimizer:
                return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
            else:
                return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
        elif load_ckpt_folder is None:
            return None
        else:
            assert False, f"Unsupported data type: '{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
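
To see what the legacy shim above produces, a hedged sketch of the conversion; the folder value is a placeholder:

```python
legacy = dict(load_ckpt_folder="local:llm_ckpts/49", load_optimizer=True)

info = ckpt_info_sanity_check(legacy)
# -> dict(path="local:llm_ckpts/49", content=("model", "sampler", "optimizer"), ckpt_type="internlm")
assert info["content"] == ("model", "sampler", "optimizer")
```
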
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .hybrid_zero_optim import HybridZeroOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff

__all__ = ["HybridZeroOptimizer"]
__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"]

@@ -775,3 +775,17 @@ class HybridZeroOptimizer(BaseOptimizer):

        if "zero_devide_optim_plan" in states:
            self.params_per_rank_id_dict = states["zero_devide_optim_plan"]


def reload_zero_fp32_buff(optimizer):
    # If we use an AMP optimizer, we need to update its fp32 buffer with the newly loaded weight values.
    # Otherwise we must ensure that model weights are loaded before zero is initialized.
    if isinstance(optimizer, HybridZeroOptimizer):
        for group_id, param_group in enumerate(optimizer.optim.param_groups):
            if optimizer.param_group_has_params[group_id]:
                # flattened fp16 params have already been updated by 'load_model_checkpoint'
                fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
                    optimizer._zero_local_rank, group_id
                )
                # param_group["params"] is the flattened fp32 optimizer state of this zero rank.
                param_group["params"][0].copy_(fp16_flat_current_rank.float())
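
The fp32-buffer refresh above follows the usual mixed-precision master-weight pattern; a self-contained sketch in plain PyTorch (not the repo's optimizer):

```python
import torch

fp16_flat = torch.zeros(8, dtype=torch.float16)      # working copy, rewritten by the ckpt load
fp32_master = torch.randn(8)                         # stale fp32 optimizer buffer

fp16_flat.copy_(torch.ones(8, dtype=torch.float16))  # checkpoint load updates fp16 weights...

# ...so the fp32 master must be refreshed, or the next optimizer step would
# write the stale fp32 values back over the freshly loaded weights.
fp32_master.copy_(fp16_flat.float())
assert torch.equal(fp32_master, torch.ones(8))
```
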
@@ -3,37 +3,135 @@

import copy
import fcntl
import inspect
import os
import socket
import time
from enum import Enum
from typing import Dict
from typing import Callable, Dict, Union

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.initialize.launch import get_config_value
from internlm.initialize.legacy.launch import (
    auto_resume_sanity_check,
    ckpt_info_sanity_check,
)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
    get_fns,
    get_storage_manager,
    init_storage_manager,
    llm_load,
    llm_save,
    try_get_storage_backend,
)

logger = get_logger(__file__)


class CheckpointType(Enum):
class CheckpointSaveType(Enum):
    NORMAL_CHECKPOINT = 1
    SNAPSHOT_CHECKPOINT = 2


class CheckpointLoadType(Enum):
    INTERNLM = "internlm"


# The load methods implemented by internlm by default do not use string representations,
# but enumeration types defined in advance.
LOAD_TYPE_DICT = {
    "internlm": CheckpointLoadType.INTERNLM,
}


class CheckpointLoadContent:
    MODEL = "model"
    SAMPLER = "sampler"
    OPIMIZER = "optimizer"
    SCHEDULAER = "scheduler"


class CheckpointLoadMethod:
    """The registration class for checkpoint loading methods;
    users can define their own custom ckpt loading methods."""

    LOAD_FUNC_SIG = None
    LOAD_TYPE_FUNC = {}

    @staticmethod
    def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
        if load_type.lower() in LOAD_TYPE_DICT:
            # The ckpt load method implemented by internlm by default.
            return LOAD_TYPE_DICT[load_type.lower()]
        else:
            # If it is a user-defined field, we do not do any conversion and represent it as a string.
            return load_type

    @staticmethod
    def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
        if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
            logger.warning(f"{load_type} has already been registered!")
            return

        CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})

        if load_type == CheckpointLoadType.INTERNLM:
            CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
        else:
            if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
                logger.warning(
                    f"registered load model ckpt signature is not the same as: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
                )

    @staticmethod
    def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]):
        return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]


class CheckpointLoadMask:
    """
    According to the content field in the incoming ckpt_info, decide which components to load.
    """

    LOAD_CONTENT_DICT = {
        "model": CheckpointLoadContent.MODEL,
        "sampler": CheckpointLoadContent.SAMPLER,
        "optimizer": CheckpointLoadContent.OPIMIZER,
        "scheduler": CheckpointLoadContent.SCHEDULAER,
    }

    def __init__(self, content: tuple) -> None:
        self.load_set = set(map(lambda x: x.lower(), content))
        if "all" in self.load_set:
            self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values())
        else:
            self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content))

    def need_load(self, content: CheckpointLoadContent):
        return content in self.load_set

    def not_only_load(self, content: CheckpointLoadContent):
        return content in self.load_set and len(self.load_set) > 1

    def only_load(self, content: CheckpointLoadContent):
        return {content} == self.load_set

    def __str__(self) -> str:
        return f"{self.load_set}."

    def __repr__(self) -> str:
        return f"{self.load_set}."
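
The registration hook above targets user-defined checkpoint formats. A hedged sketch of how a custom loader might be plugged in; the "hf" key and the loader body are hypothetical, not part of the repo:

```python
def try_load_hf_ckpt(ckpt_mm, load_info, train_state):  # hypothetical custom loader
    # Same signature as the default internlm loader, otherwise a warning is logged.
    print(f"loading external weights from {load_info['path']}")
    return "model, "

CheckpointLoadMethod.register_ckpt_load_type("hf", try_load_hf_ckpt)

# A config can then select it via ckpt_type="hf", with the usual content mask:
mask = CheckpointLoadMask(("model",))
assert mask.need_load(CheckpointLoadContent.MODEL)
assert not mask.need_load(CheckpointLoadContent.OPIMIZER)
```
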

def get_model_topology(model):
    """
    Returns:

@@ -55,6 +153,66 @@ def get_model_topology(model):
    return topos


def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
    load_content_str = ""
    load_ckpt_folder = load_info["path"]
    load_content: CheckpointLoadMask = load_info["content"]

    if gpc.is_rank_for_log():
        logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")

    if load_content.need_load(CheckpointLoadContent.MODEL):
        load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
        load_content_str += f"{CheckpointLoadContent.MODEL}, "

    if load_content.not_only_load(CheckpointLoadContent.MODEL):
        # load training states.
        load_context(load_ckpt_folder, train_state)

        # load optimizer states.
        if load_content.need_load(CheckpointLoadContent.OPIMIZER):
            load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
            load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
        else:
            if gpc.is_rank_for_log():
                logger.warning("CheckpointManager has no 'optimizer', skip reloading the optimizer checkpoint!")

        # load lr scheduler states.
        if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
            if ckpt_mm.lr_scheduler:
                load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state)
                load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning("CheckpointManager has no 'lr_scheduler', skip reloading the lr_scheduler checkpoint!")

        # load dataloader sampler states.
        if load_content.need_load(CheckpointLoadContent.SAMPLER):
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler)
                # track the actual updates of the sampler when using weighted sampling
                train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler)
                load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning("CheckpointManager skips reloading 'batch_sampler'")

            # reload data state dict.
            if hasattr(train_state, "data_state_dict"):
                ckpt_mm.train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder
                )
                load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning(
                        "CheckpointManager has no 'data_state_dict', skip reloading the data_state_dict checkpoint!"
                    )
    return load_content_str


def save_model_checkpoint(folder, model):
    """
    Save the model according to the relationship between tp and dp. The principle is that the data of each tp

@@ -233,15 +391,16 @@ def load_sampler(ckpt_path: str, sampler):
    torch.cuda.empty_cache()


def load_context(ckpt_path: str, train_dl, train_state: TrainState):
def load_context(ckpt_path: str, train_state: TrainState):
    context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
    train_state.load_state_dict(context_stuffs, train_dl)
    train_state.load_state_dict(context_stuffs)
    if gpc.is_rank_for_log():
        logger.info(f"reload train_state:{train_state}")
    torch.cuda.empty_cache()


def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState):
    learning_rate = train_state.lr
    scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
    if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
        logger.warning(

@@ -270,7 +429,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
class CheckpointManager:
    """StorageManagerContext"""

    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
    def __init__(
        self,
        ckpt_config,
        model,
        train_dl=None,
        optimizer=None,
        lr_scheduler=None,
        model_config=None,
        model_config_file=None,
        feishu_address=None,
    ) -> None:
        """
        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
        upload mode, you must call wait_async_upload_finish at the end of the program to wait

@@ -283,22 +452,44 @@ class CheckpointManager:
            lr_scheduler (object): lr_scheduler obj.
            model_config (dict): model config.
        """
        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
        self.checkpoint_every = ckpt_config.checkpoint_every
        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
        self.stop_file_path = ckpt_config.stop_file_path
        self.load_model_only_folder = ckpt_config.load_model_only_folder
        self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
        self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100)
        self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
        self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
        self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
        if self.save_ckpt_folder:
            self.snapshot_ckpt_folder = get_config_value(
                ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
            )
            self.async_upload_tmp_folder = get_config_value(
                ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/"
            )
        else:
            self.snapshot_ckpt_folder = None
            self.async_upload_tmp_folder = None

        self.async_upload = get_config_value(ckpt_config, "async_upload", False)

        # initialize storage manager
        init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)

        self.feishu_address = feishu_address
        self.storage_manager = get_storage_manager()
        self.snapshot_counter = 0
        self.load_optimizer = gpc.config.ckpt.load_optimizer

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dl = train_dl
        self.model_config = model_config
        self.model_config_file = model_config_file

        # Register the default internlm ckpt load types.
        self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
        for ckpt_load_type in CheckpointLoadType:
            CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])

        # Init alert file.
        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
            if dir_path != "" and not os.path.exists(dir_path):

@@ -306,21 +497,35 @@ class CheckpointManager:
            with open(self.stop_file_path, "w", encoding="utf-8") as f:
                f.write("0")

        if ckpt_config.load_given_ckpt is False:
            # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
            latest_ckpt_path = self.query_lastest_ckpt()
            if latest_ckpt_path:
                self.load_ckpt_folder = latest_ckpt_path
            else:
                # At this point, we have to load the model init weights and train from step 0.
                self.load_ckpt_folder = self.load_model_only_folder
        else:
            self.load_ckpt_folder = ckpt_config.load_ckpt_folder
        self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
        if self.load_ckpt_info is None:  # (legacy): try to stay compatible with old interfaces
            self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)

        if gpc.is_rank_for_log():
            logger.info(f"load_ckpt_folder will be set to: '{self.load_ckpt_folder}'")
            if self.stop_file_path is None:
                logger.warning("stop_file_path is not set, quit_signal_handler is disabled")
        # Auto-reload the latest checkpoint; this will overwrite the 'load_ckpt_info' setting.
        self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
        if self.auto_resume is None:  # (legacy): try to stay compatible with old interfaces
            self.auto_resume = auto_resume_sanity_check(ckpt_config)
        if self.auto_resume:
            self.load_ckpt_info = self.query_lastest_ckpt()

        if self.stop_file_path is None and gpc.is_rank_for_log():
            logger.warning("stop_file_path is not set, quit_signal_handler is disabled")

        # convert to internal representation
        if self.load_ckpt_info:
            assert (
                "path" in self.load_ckpt_info
                and "content" in self.load_ckpt_info
                and "ckpt_type" in self.load_ckpt_info
            ), "please set content in the ckpt setting, e.g.: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"

            # replace load_ckpt
            self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
            self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])

        # test that the storage setting is ok.
        if self.enable_save_ckpt:
            self.try_ping_storage()

    def quit_signal_handler(self, train_state) -> bool:
        """

@@ -334,7 +539,7 @@ class CheckpointManager:
        Returns:
            bool: whether to quit.
        """
        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
        now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

        if self.stop_file_path is None:
            return now_break, now_save_ckpt, save_type

@@ -365,24 +570,29 @@ now step_count is {train_state.step_count}",

        return now_break, now_save_ckpt, save_type

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
        save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
            save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
        if train_state.step_count % self.checkpoint_every == 0:
            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
            save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
        if save_ckpts is False:
            save_ckpts = singal_save_ckpts
            save_type = singal_save_type

        return save_ckpts, save_type, now_break

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)

        if save_ckpts:
            # Wait for the previous round of asynchronous upload storage to complete.
            self.storage_manager.wait()
            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
            if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
                # Snapshot number, with only two snapshots written alternately.
                self.snapshot_counter = (self.snapshot_counter + 1) % 2
                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")

@@ -412,7 +622,7 @@ now step_count is {train_state.step_count}",
            Tuple(str, int): path of the latest ckpt and the ckpt step; if not found, None is returned.
        """
        ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
        if len(ckpt_list) == 0:
        if ckpt_list is None or len(ckpt_list) == 0:
            return None, None

        max_normal_step = 0

@@ -435,14 +645,16 @@ now step_count is {train_state.step_count}",
        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
        max_step_0, max_step_1 = 0, 0
        for ckpt in ckpt_list_1:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        for ckpt in ckpt_list_2:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
        if ckpt_list_1:
            for ckpt in ckpt_list_1:
                ckpt = ckpt.strip("/")
                if ckpt.endswith(".step"):
                    max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        if ckpt_list_2:
            for ckpt in ckpt_list_2:
                ckpt = ckpt.strip("/")
                if ckpt.endswith(".step"):
                    max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))

        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
        snap_step = max(max_step_0, max_step_1)

@@ -452,11 +664,12 @@ now step_count is {train_state.step_count}",

    def query_latest_snapshot_step_local(self):
        max_step, max_step_path = 0, None
        for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
        save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
        for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
            for fn in files:
                fn = fn.strip("/")
                if fn.endswith(".step"):
                    # We assume that both normal ckpt and snapshot ckpt will store the '.step' file
                    # We assume that both internlm ckpt and snapshot ckpt will store the '.step' file
                    # as an integrity flag.
                    step = int(fn.rsplit(".", maxsplit=1)[0])
                    if max_step < step:

@@ -466,99 +679,53 @@ now step_count is {train_state.step_count}",
        return max_step_path, max_step

    def query_lastest_ckpt(self):
        latest_checkpoint = None
        latest_ckpt, step = None, -1
        # Training was automatically restarted by the process, forcing the latest snapshot to be read.
        if self.save_ckpt_folder:
            if self.save_ckpt_folder.startswith("boto3"):
                latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
            elif self.save_ckpt_folder.startswith("local"):
                latest_checkpoint, step = self.query_latest_snapshot_step_local()
            else:
                latest_checkpoint, step = None, 0
            backend, _ = try_get_storage_backend(self.save_ckpt_folder)
            if backend == "boto3":
                latest_ckpt, step = self.query_latest_snapshot_step_boto3()
                if latest_ckpt and not latest_ckpt.startswith("boto3:"):
                    latest_ckpt = ":".join(["boto3", latest_ckpt])
            elif backend == "local":
                latest_ckpt, step = self.query_latest_snapshot_step_local()
                if latest_ckpt and not latest_ckpt.startswith("local:"):
                    latest_ckpt = ":".join(["local", latest_ckpt])

        if latest_checkpoint is not None:
            if gpc.is_rank_for_log():
                logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
                send_alert_message(
                    address=self.feishu_address,
                    message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
                )
        else:
            if gpc.is_rank_for_log():
                send_alert_message(
                    address=self.feishu_address,
                    message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
                )
        if gpc.is_rank_for_log():
            logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")

        return latest_checkpoint
        return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm")

    def try_load_model(self, current_time=""):
        model_load_path = None
    def try_resume_training(self, train_state: TrainState, current_time=""):

        if self.load_ckpt_folder and self.load_model_only_folder:
            raise ValueError(
                "Error: trying to use both the load_ckpt_folder and load_model_only_folder paths. \
If you only need to load model weights (for example, when starting an SFT task for the first time), \
set the load_model_only_folder path; if you need to resume training from a ckpt, \
set load_ckpt_folder or use default value \
(if it is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
            )

        if self.load_ckpt_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_ckpt_folder
        elif self.load_model_only_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_model_only_folder
        else:
        if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
                    f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                    f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
                )
        else:
            load_path = self.load_ckpt_info["path"]
            load_content = self.load_ckpt_info["content"]
            load_type = self.load_ckpt_info["ckpt_type"]

        # Loading model weights must be done before zero is initialized.
        if model_load_path is not None:
            load_model_checkpoint(folder=model_load_path, model=self.model)
            load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
            load_content_str = load_func(self, self.load_ckpt_info, train_state)

    def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
        """Attempt to restore the training state of the last ckpt.
            # If we only load model weights, we need to rewrite the zero optimizer's fp32 buffer.
            if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer):
                reload_zero_fp32_buff(self.optimizer)

        Args:
            lr_scheduler (_LRScheduler): lr_scheduler object.
            optimizer (Optimizer): optimizer object.
            lr (float): learning rate.
            train_state (dict): training states.
            train_dl (DataLoader): training dataloader object
        """
        if self.load_ckpt_folder is not None:
            # load optimizer states.
            if self.load_optimizer:
                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
            # load lr scheduler states.
            load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
            # load training states.
            load_context(self.load_ckpt_folder, train_dl, train_state)
            # load dataloader sampler states.
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
            if hasattr(train_state, "data_state_dict"):
                train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
            if gpc.is_rank_for_log():
                logger.info(f"load_ckpt_info : {self.load_ckpt_info}")
                logger.info(
                    f"===========Resume training from `{load_path}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
                if load_content_str:
                    logger.info(f"===========Load contents are: {load_content_str}")

    def save_checkpoint(
        self,

@@ -600,8 +767,10 @@ set load_ckpt_folder or use default value \
        )

        if gpc.is_rank_for_log():
            scheduler_states = scheduler.state_dict()
            llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
            if scheduler:
                scheduler_states = scheduler.state_dict()
                llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)

        if hasattr(train_state, "batch_sampler") and not isinstance(
            train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
        ):

@@ -631,3 +800,12 @@ set load_ckpt_folder or use default value \
    def set_save_folder(self, folder, step):
        self.storage_manager.latest_save_folder = folder
        self.storage_manager.latest_save_step = step

    def try_ping_storage(self):
        if gpc.get_global_rank() % 8 == 0:
            buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
            test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
            self.storage_manager.save(test_fn, buff)
            self.storage_manager.wait()
            self.storage_manager.load(test_fn)
            del buff

@@ -136,6 +136,22 @@ def compute_file_md5_by_chunk(file_name: str):
    return hash_md5.hexdigest()


def try_get_storage_backend(path: str):
    sre = path.split(":", maxsplit=1)
    if len(sre) == 1:
        if path.startswith("s3:"):
            backend = "boto3"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' does not start with a backend prefix, guessing the boto3 backend.")
        else:
            backend = "local"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' does not start with a backend prefix, guessing the local backend.")
        return backend, sre
    else:
        return sre[0], sre[1]  # (backend_prefix, splited_path)
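
A quick sketch of what the prefix parser above returns for the path shapes it handles; the paths are illustrative (note the prefix-less branch returns the split list rather than a string):

```python
try_get_storage_backend("local:/mnt/ckpts")     # -> ("local", "/mnt/ckpts")
try_get_storage_backend("boto3:s3://bucket/x")  # -> ("boto3", "s3://bucket/x")
try_get_storage_backend("/mnt/ckpts")           # -> ("local", ["/mnt/ckpts"]), with a warning logged
```
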

class Boto3Client(StorageClient):
    """
    Boto3Client

@@ -231,21 +247,34 @@ class Boto3Client(StorageClient):
    def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
        assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp

    @staticmethod
    def is_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
        re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
        if "Contents" in re:
            return len(list(re["Contents"])) > 0
        else:
            return False

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs):  # pylint: disable=W0613
        """
        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
        folder_name_list = []
        for page in pages:
            if "Contents" in page:
                for obj in page["Contents"]:
                    pth: str = obj["Key"]
                    folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
        return list(set(folder_name_list))
        if Boto3Client.is_fp_exists(handler, bucket_name, fp, None):
            paginator = handler.client.get_paginator("list_objects_v2")
            pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
            folder_name_list = []
            for page in pages:
                if "Contents" in page:
                    for obj in page["Contents"]:
                        pth: str = obj["Key"]
                        folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
            return list(set(folder_name_list))
        else:
            if gpc.is_rank_for_log():
                logger.warning(f"'{fp}' not found!")
            return None

    @staticmethod
    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):

@@ -297,9 +326,12 @@ class LocalClient(StorageClient):
    @staticmethod
    def get_fns(handler, folder):
        assert isinstance(handler, LocalClient)
        assert os.path.exists(folder), f"folder '{folder}' not exists!"
        fns = os.listdir(folder)
        return fns
        if not os.path.exists(folder):
            if gpc.is_rank_for_log():
                logger.warning(f"'{folder}' not found!")
            return None
        else:
            return os.listdir(folder)

    @staticmethod
    def delete_obj(handler, fp: str):

@@ -436,10 +468,7 @@ class StorageManager(metaclass=SingletonMeta):
        Args:
            path (str): _description_
        """
        try:
            backend, path = path.split(":", maxsplit=1)
        except Exception as exc:
            raise AttributeError(f"Given path '{path}' does not start with a backend prefix: 'local/boto3'") from exc
        backend, path = try_get_storage_backend(path)

        init_args = (None,)
        if backend == "local":

@@ -594,23 +623,24 @@ class StorageManager(metaclass=SingletonMeta):

        if gpc.is_rank_for_log():
            self.upload_count += 1
            if self.async_mode:
            if self.async_mode and self.latest_save_folder:
                self.save(
                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
                    saved_obj=dict({"step": self.latest_save_step}),
                    async_upload=False,
                )
                self.latest_save_folder = None


storage_manager: StorageManager = None


def init_storage_manager(ckpt_config):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
    global storage_manager
    storage_manager = StorageManager(
        ckpt_config.enable_save_ckpt,
        tmp_local_folder=ckpt_config.async_upload_tmp_folder,
        async_mode=ckpt_config.async_upload,
        enable_save_ckpt,
        tmp_local_folder=async_upload_tmp_folder,
        async_mode=async_upload,
    )

@@ -0,0 +1,143 @@
import os

import pytest
import torch

from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta

# 1B
init_config = Config(
    dict(
        parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
        model_type="INTERNLM",
        adam=dict(
            lr=1e-4,
        ),
        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            checkpoint=False,
            num_attention_heads=2,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=True,
            hidden_size=1024,
            num_layers=2,
            mlp_ratio=1,
            apply_post_layer_norm=False,
            dtype=torch.bfloat16,
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
    )
)


def init_naive_model():
    # let MODEL_INITIALIZER do its work
    import internlm.model.modeling_internlm  # noqa # pylint: disable=unused-import
    from internlm.core.naive_amp import NaiveAMPModel
    from internlm.utils.registry import MODEL_INITIALIZER

    model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model))
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=False,
        dtype=torch.bfloat16,
        sync_buffer=False,
    )
    return model


def init_naive_optim(model):
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": 0.01}],
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    return naive_optimizer


def init_hybrid_optim(model):
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": 0.01}],
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=Config(
            dict(
                fp16=dict(
                    initial_scale=2**16,
                    min_scale=1,
                    growth_interval=1000,
                ),
                growth_factor=2,
                backoff_factor=0.5,
                max_scale=2**24,
                hysteresis=2,
            )
        ),
        zero_cfg=Config(
            dict(
                overlap_sync_grad=False,
                overlap_sync_param=False,
                reduce_bucket_size=512 * 1024 * 1024,
                clip_grad_norm=1.0,
            )
        ),
        param_bcast_sync_handler=None,
    )
    return optimizer


@pytest.fixture(autouse=True, scope="function")
def reset_singletons():
    SingletonMeta._instances = {}


def reset_seed():
    from internlm.core.context.random import _SEED_MANAGER

    _SEED_MANAGER.reset()


@pytest.fixture(scope="module")
def init_dist_and_model():
    from internlm.initialize import initialize_distributed_env

    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12377"
    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

    # setup
    print("set up", flush=True)
    model = init_naive_model()
    # opim = init_naive_optim(model)
    opim = init_hybrid_optim(model)

    yield model, opim

    # teardown
    del model, opim
    print("teardown", flush=True)
    gpc.destroy()
    reset_seed()


def enter_flag(text):
    print(f"{text} begin!", flush=True)
    yield
    print(f"{text} end!", flush=True)

@ -0,0 +1,278 @@
|
|||
import os
|
||||
import shutil
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.storage_manager import wait_async_upload_finish
|
||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||
init_dist_and_model,
|
||||
reset_singletons,
|
||||
)
|
||||
|
||||
TOTAL_STEP = 6
|
||||
|
||||
CKPT_EVERY = 4
|
||||
SNPASHOT_EVERY = 2
|
||||
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
|
||||
OSS_IP = os.environ["OSS_IP"]
|
||||
USER = os.environ["USER"]
|
||||
JOB_NAME = "CI_TEST"
|
||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||
|
||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
|
||||
|
||||
def del_tmp_file():
|
||||
try:
|
||||
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
|
||||
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
|
||||
results, presults = "", ""
|
||||
for line in iter(output.stdout.readline, b""):
|
||||
results += str(line.rstrip())
|
||||
presults += line.rstrip().decode() + "\n"
|
||||
print(presults, flush=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
ckpt_config_list = [
|
||||
# Old interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||
load_optimizer=True,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_model_only_folder=None,
|
||||
load_given_ckpt=False,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=True,
|
||||
),
|
||||
# Old interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||
load_optimizer=True,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=False,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_model_only_folder=None,
|
||||
load_given_ckpt=False,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=True,
|
||||
),
|
||||
# New interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
is_old_api=False,
|
||||
auto_resume=True,
|
||||
),
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=False,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=False,
|
||||
auto_resume=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def overwrite_optim_state(optim, set_value):
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||
p.data.fill_(set_value)
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
)
|
||||
fp16_p.fill_(set_value)
|
||||
else:
|
||||
for group in optim.param_groups:
|
||||
for p in group["params"]:
|
||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||
p.data.fill_(set_value)
|
||||
|
||||
|
||||
def compare_optim_state(optim1, optim2):
|
||||
re = True
|
||||
if isinstance(optim1, HybridZeroOptimizer):
|
||||
fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank
|
||||
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
||||
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
||||
re &= group_id_1 == group_id_2
|
||||
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
||||
else:
|
||||
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
||||
for p1, p2 in zip(group1["params"], group2["params"]):
|
||||
re &= torch.equal(p1, p2)
|
||||
return re
|
||||
|
||||
|
||||
def compare_optim_value(optim, value):
|
||||
re = True
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
)
|
||||
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
|
||||
else:
|
||||
for group in optim.param_groups:
|
||||
for p in group["params"]:
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
return re
|
||||
|
||||
|
||||
def overwrite_model_value(model, value):
|
||||
for p in model.parameters():
|
||||
# p.copy_(torch.full_like(p, value, dtype=p.dtype))
|
||||
p.data.fill_(value)
|
||||
|
||||
|
||||
def compare_model_value(model, value):
|
||||
re = True
|
||||
for p in model.parameters():
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
return re
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def del_tmp():
|
||||
del_tmp_file()
|
||||
yield
|
||||
del_tmp_file()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
||||
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
|
||||
|
||||
ckpt_config = Config(ckpt_config)
|
||||
assert ckpt_config.checkpoint_every < TOTAL_STEP
|
||||
assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
|
||||
|
||||
model, opim = init_dist_and_model
|
||||
train_state = TrainState(gpc.config, None)
|
||||
if isinstance(opim, HybridZeroOptimizer):
|
||||
print("Is HybridZeroOptimizer!", flush=True)
|
||||
else:
|
||||
print("Is naive Adam!", flush=True)
|
||||
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
latest_ckpt_step = None
|
||||
for i in range(TOTAL_STEP + 1):
|
||||
overwrite_model_value(model, i)
|
||||
overwrite_optim_state(opim, i)
|
||||
|
||||
train_state.batch_count = i
|
||||
train_state.step_count += 1
|
||||
|
||||
save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state)
|
||||
if save_ckpts:
|
||||
latest_ckpt_step = i
|
||||
|
||||
ckpt_mm.try_save_checkpoint(train_state)
|
||||
|
||||
wait_async_upload_finish()
|
||||
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
|
||||
assert latest_ckpt_info is not None
|
||||
latest_ckpt = latest_ckpt_info["path"]
|
||||
if ckpt_mm.save_ckpt_folder.startswith("local"):
|
||||
assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
|
||||
else:
|
||||
assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
|
||||
|
||||
del ckpt_mm
|
||||
SingletonMeta._instances = {}
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
ckpt_mm.try_resume_training(train_state)
|
||||
assert latest_ckpt_step == 5
|
||||
assert train_state.step_count == 6
|
||||
assert train_state.batch_count == 6
|
||||
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
||||
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
|
||||
|
||||
if ckpt_mm.save_ckpt_folder.startswith("local:"):
|
||||
ckpt_mm.load_ckpt_info = dict(
|
||||
path=os.path.join(LOCAL_SAVE_PATH, "4"),
|
||||
content=CheckpointLoadMask(("all",)),
|
||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||
)
|
||||
else:
|
||||
ckpt_mm.load_ckpt_info = dict(
|
||||
path=os.path.join(BOTO_SAVE_PATH, "4"),
|
||||
content=CheckpointLoadMask(("all",)),
|
||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||
)
|
||||
|
||||
ckpt_mm.try_resume_training(train_state)
|
||||
|
||||
assert train_state.step_count == 4
|
||||
assert train_state.batch_count == 4
|
||||
assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
||||
assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
|
||||
|
||||
|
||||
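

# --- Editor's sketch of assumed semantics, not the real implementation in
# internlm.utils.model_checkpoint. test_ckpt_mm leans on two behaviours:
# CheckpointLoadMask(("all",)) is taken to expand to every loadable state, and
# snapshots are taken to rotate between two slots, which is why the latest
# path above ends in "snapshot/0".
_LOAD_STATES = ("model", "sampler", "optimizer", "scheduler")


def _make_load_mask(contents):
    return set(_LOAD_STATES) if "all" in contents else set(contents)


def _snapshot_slot(snapshot_counter):
    return snapshot_counter % 2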
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
||||
ckpt_config = Config(ckpt_config)
|
||||
|
||||
model, opim = init_dist_and_model
|
||||
SingletonMeta._instances = {}
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
ckpt_mm.try_ping_storage()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
|
@@ -0,0 +1,26 @@
import pytest

from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    BOTO_SAVE_PATH,
    TOTAL_STEP,
    ckpt_config_list,
    del_tmp_file,
    init_dist_and_model,
    reset_singletons,
)


@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_storage_mm(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-argument
    from internlm.utils.storage_manager import get_storage_manager, init_storage_manager

    ckpt_config = Config(ckpt_config)
    enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
    async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
    async_upload = get_config_value(ckpt_config, "async_upload", False)

    init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
    get_storage_manager()
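

# --- Editor's sketch, hedged: get_config_value is assumed to behave like a
# dict lookup with a default, so absent ckpt fields fall back cleanly. The
# helper name below is illustrative, not InternLM API.
def _get_config_value_sketch(config, key, default):
    try:
        return config[key]
    except KeyError:
        return default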
train.py
@@ -35,7 +35,6 @@ from internlm.utils.common import (
     parse_args,
 )
 from internlm.utils.evaluation import evaluate_on_val_dls
-from internlm.utils.gputest import bench_gpu, bench_net
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import CheckpointManager
@@ -73,7 +72,6 @@ def main(args):
     total_steps = gpc.config.data.total_steps
     valid_every = gpc.config.data.valid_every
     label_smoothing = gpc.config.loss.label_smoothing
-    lr = gpc.config.adam.lr
 
     get_tflops_func = partial(
         get_megatron_flops,
@@ -96,21 +94,11 @@
     # initialize customed llm logger
     uniscale_logger = initialize_llm_logger(start_time=current_time)
 
-    # initialize and resume train state
-    train_state = TrainState(gpc.config)
-
     # initialize model
     model = initialize_model()
 
     with open(args.config, "r") as f:
         config_lines = f.readlines()
-    ckpt_manager = CheckpointManager(
-        ckpt_config=gpc.config.ckpt,
-        model=model,
-        model_config=gpc.config.model,
-        model_config_file="".join(config_lines),
-        feishu_address=gpc.config.alert_address,
-    )
 
     # initialize loss function
     criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@@ -118,15 +106,25 @@
     # initialize the train and validation data loader
     train_dl, dataset_types = get_train_data_loader(num_worker=4)
     val_dls = get_validation_data_loader()
-    train_state.init_batch_sampler(train_dl)
 
-    # Loading model weights must be done before zero is initialized.
-    ckpt_manager.try_load_model(current_time)
+    # initialize and resume train state
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
+    ckpt_manager = CheckpointManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dl=train_dl,
+        model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
+        feishu_address=gpc.config.alert_address,
+    )
+
     # Loading other persistent training states.
-    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
+    ckpt_manager.try_resume_training(train_state, current_time)
 
     # initialize customed llm writer
     writer = Writer(
@@ -197,8 +195,6 @@
     for batch_count in range(train_state.batch_count, total_steps):
         if batch_count % 50 == 0:
             torch.cuda.empty_cache()
-            bench_gpu()
-            bench_net()
 
         start_time = time.time()
         timer("one-batch").start()
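
Taken together, the train.py hunks show the shape of the refactor: the checkpoint manager is now constructed only after the optimizer and scheduler exist and owns them, so resuming collapses to a single `try_resume_training(train_state, current_time)` call instead of the old five-argument variant. A rough stand-in for that wiring, a hypothetical mock rather than the InternLM implementation:

```python
# Hypothetical mock of the refactored wiring; the real manager lives in
# internlm.utils.model_checkpoint.CheckpointManager.
class MockCheckpointManager:
    def __init__(self, ckpt_config, model, optimizer=None, lr_scheduler=None, train_dl=None):
        # The manager owns every restorable state up front, so callers no
        # longer thread optimizer/scheduler/lr through the resume call.
        self.ckpt_config = ckpt_config
        self.states = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": lr_scheduler,
            "sampler": train_dl,
        }

    def try_resume_training(self, train_state, current_time=None):
        # Restore whatever load_ckpt_info selects; no-op placeholder here.
        return train_state
```

Constructing the manager after `initialize_optimizer` also sidesteps the ordering hazard the removed comment warned about, where model weights had to be loaded before ZeRO's fp32 buffers were built so the buffers would not overwrite the loaded weights.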