feat(ckpt): fix checkpoint bugs and add feature enhancements. (#259)

* fix(ckpt): ckpt bug fix and api refactor
1. fix the latest-ckpt query bug
2. add ckpt unit tests
3. fix the storage manager boto3/local client get_fns bug
4. fix the bug where, in the model-only load case, the zero fp32 buffer overwrites the loaded model weights
5. add ckpt_type and a zero-reload ci-test

* fix(ckpt): fix ckpt and trainer bug

* fix and refactor

* fix base on comment

* feat: add legacy api
Guoteng 2023-09-05 17:40:48 +08:00 committed by GitHub
parent 860de0aa46
commit f6e007f95b
15 changed files with 931 additions and 226 deletions

View File

@ -22,13 +22,16 @@ CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, currently only the 'internlm' type is supported.
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
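For quick reference, a few possible `load_ckpt_info` settings that follow the guide above (a non-authoritative sketch; all paths are placeholders):

```python
# Initialize from model weights only; training starts from step 0.
load_ckpt_info = dict(path="local:/path/to/init/ckpt", content=("model",), ckpt_type="internlm")

# Resume model, optimizer and sampler states from a previous run stored on OSS.
load_ckpt_info = dict(path="boto3:s3://bucket/run/llm_ckpts/100", content=("model", "optimizer", "sampler"), ckpt_type="internlm")

# Resume everything that was saved.
load_ckpt_info = dict(path="local:llm_ckpts/100", content=("all",), ckpt_type="internlm")
```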

View File

@ -115,19 +115,19 @@ If you want to load a model checkpoint when starting the training, you can confi
```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step
load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True
# When resuming training from a checkpoint:
# (1) 'path' is the path of the loaded checkpoint.
# (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```
Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
- If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS, as in the example below.
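For instance, the same save folder could point to either backend (a sketch; bucket and endpoint names are placeholders):

```python
SAVE_CKPT_FOLDER = "local:/mnt/ckpts/my_run"                  # local file system
SAVE_CKPT_FOLDER = "boto3:s3://my-bucket.my-endpoint/my_run"  # remote OSS via boto3
```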
The configuration for the model is as follows:

View File

@ -103,18 +103,17 @@ data = dict(
If you want to load a model `checkpoint` when starting the training, you can make the following configuration:
```python
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
ckpt = dict(
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights; only model weights are loaded, not optimizer weights, and training starts from the first step
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the model and optimizer weights for resuming training; training resumes from the specified step
load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True
# When resuming training, 'path' is the path of the checkpoint to load, and training resumes from the specified step
# 'content' indicates which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 'ckpt_type' indicates which type of checkpoint will be loaded, currently supported: "internlm"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
)
```
Note:
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time
- If the path starts with `local:`, it is stored on the local file system; if it starts with `boto3:`, it is stored on the remote OSS
The key model-related parameters are configured as follows:

View File

@ -23,7 +23,15 @@ class TrainState:
train_dl (DataLoader): The DataLoader object used for training.
"""
def __init__(self, config) -> None:
def __init__(self, config, batch_sampler) -> None:
"""
Args:
config (Config): internlm config
batch_sampler (torch.utils.data.Sampler): Because dataloader loading is
asynchronous and prefetched, the batch_sampler state maintained inside the
dataloader runs ahead of the actual training progress, so we keep a copy of the
batch_sampler as the anchor point for ckpt reload.
"""
# The number of batches produced by the data iterator
self.batch_count: int = 0
# Used to store the number of samples consumed in the current epoch
@ -43,9 +51,20 @@ class TrainState:
self.tensorboard_folder = config.tensorboard_folder
def init_batch_sampler(self, train_dl):
# Copy of the batch sampler from the DataLoader
self.batch_sampler = train_dl.batch_sampler.copy()
# learning rate
self.lr = config.adam.lr
# sampler state
if batch_sampler:
self.init_batch_sampler(batch_sampler)
def init_batch_sampler(self, batch_sampler):
"""
Args:
batch_sampler (torch.utils.data.Sampler): sampler.
"""
# make a copy of batch_sampler.
self.batch_sampler = batch_sampler.copy()
# Iterator for the batch sampler
self.batch_sampler_iter = iter(self.batch_sampler)
@ -61,26 +80,22 @@ class TrainState:
return json.dumps(info, indent=4, sort_keys=True)
def load_state_dict(self, other_stuffs, train_dl):
def load_state_dict(self, other_stuffs):
"""
Resumes training from a checkpoint.
Args:
other_stuffs (dict): Other information needed to resume training.
train_dl (DataLoader): The DataLoader object used for training.
"""
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
# compatible with previous checkpoints without this parameter
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
# track the actual updates of sampler when using weighted sampling
if hasattr(self, "batch_sampler"):
self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)
# Because the ckpt save occurs after 'step_count' is updated,
# there is no need to increment 'step_count' here.
# However, 'batch_count' is updated before ckpt storage, so it needs to be incremented by 1 on resume.
self.batch_count = other_stuffs["batch_count"] + 1 # shift forward by one so training resumes from the next batch
self.step_count = other_stuffs.get("step_count", self.batch_count)
# resume tensorboard from older tensorboard_folder
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)

View File

@ -12,7 +12,6 @@ from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import init_storage_manager
logger = get_logger(__file__)
@ -111,7 +110,7 @@ def args_sanity_check():
# processing the checkpoint config
ckpt = gpc.config.ckpt
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", False)
ckpt._add_item("enable_save_ckpt", True)
# Saving checkpoint args.
if ckpt.enable_save_ckpt:
@ -137,9 +136,6 @@ def args_sanity_check():
if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None)
if "snapshot_ckpt_folder" not in ckpt:
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
if "oss_snapshot_freq" not in ckpt:
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
else:
@ -149,44 +145,23 @@ def args_sanity_check():
ckpt._add_item("async_upload", False)
ckpt._add_item("async_upload_tmp_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
# Loading checkpoint args.
if "load_model_only_folder" not in ckpt:
ckpt._add_item("load_model_only_folder", None)
if "load_ckpt_folder" not in ckpt:
ckpt._add_item("load_ckpt_folder", None)
if "load_optimizer" not in ckpt:
ckpt._add_item("load_optimizer", True)
if "stop_file_path" not in ckpt:
ckpt._add_item("stop_file_path", None)
if "load_given_ckpt" not in ckpt:
# If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
if "auto_resume" not in ckpt:
# If 'auto_resume' is not given, we set it to True, so internlm has the opportunity
# to auto-load latest checkpoint.
ckpt._add_item("load_given_ckpt", False)
if ckpt.load_given_ckpt:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
logger.warning(
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
)
ckpt.load_model_only_folder = None
ckpt._add_item("auto_resume", True)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
# initialization storage manager
init_storage_manager(ckpt)
# tensorboard writer config
if "enable_tb" not in gpc.config:
@ -459,3 +434,11 @@ def initialize_distributed_env(
if args_check:
args_sanity_check()
def get_config_value(config, key, default):
try:
value = config[key]
except KeyError:
value = default
return value
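A minimal usage sketch of `get_config_value` (assuming the internlm package is importable; the config values below are made up):

```python
from internlm.initialize.launch import get_config_value

ckpt_cfg = dict(enable_save_ckpt=True)
get_config_value(ckpt_cfg, "enable_save_ckpt", False)  # -> True
get_config_value(ckpt_cfg, "oss_snapshot_freq", 50)    # key missing -> falls back to the default 50
```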

View File

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from internlm.initialize.launch import get_config_value
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
def auto_resume_sanity_check(ckpt_config):
load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
if load_given_ckpt is None:
return True # default value is True
else:
return not load_given_ckpt
def ckpt_info_sanity_check(ckpt_config):
load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
if load_model_only_folder is not None:
assert (
load_ckpt_folder is None
), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
# and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
else:
load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
if isinstance(load_ckpt_folder, str):
if load_optimizer:
return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
else:
return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
elif load_ckpt_folder is None:
return None
else:
assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .hybrid_zero_optim import HybridZeroOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff
__all__ = ["HybridZeroOptimizer"]
__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"]

View File

@ -775,3 +775,17 @@ class HybridZeroOptimizer(BaseOptimizer):
if "zero_devide_optim_plan" in states:
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
def reload_zero_fp32_buff(optimizer):
# If we use the AMP optimizer, we need to update its fp32 master buffer with the newly loaded weight values.
# Otherwise, we must ensure that loading model weights is done before zero is initialized.
if isinstance(optimizer, HybridZeroOptimizer):
for group_id, param_group in enumerate(optimizer.optim.param_groups):
if optimizer.param_group_has_params[group_id]:
# flatten fp16 params have already been updated by 'load_model_checkpoint'
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
optimizer._zero_local_rank, group_id
)
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
param_group["params"][0].copy_(fp16_flat_current_rank.float())

View File

@ -3,37 +3,135 @@
import copy
import fcntl
import inspect
import os
import socket
import time
from enum import Enum
from typing import Dict
from typing import Callable, Dict, Union
import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.initialize.launch import get_config_value
from internlm.initialize.legacy.launch import (
auto_resume_sanity_check,
ckpt_info_sanity_check,
)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
init_storage_manager,
llm_load,
llm_save,
try_get_storage_backend,
)
logger = get_logger(__file__)
class CheckpointType(Enum):
class CheckpointSaveType(Enum):
NORMAL_CHECKPOINT = 1
SNAPSHOT_CHECKPOINT = 2
class CheckpointLoadType(Enum):
INTERNLM = "internlm"
# The load method implemented by internlm by default does not use string representation types,
# but uses enumeration types defined in advance.
LOAD_TYPE_DICT = {
"internlm": CheckpointLoadType.INTERNLM,
}
class CheckpointLoadContent:
MODEL = "model"
SAMPLER = "sampler"
OPIMIZER = "optimizer"
SCHEDULAER = "scheduler"
class CheckpointLoadMethod:
"""The registration class of the checkpoint loading method,
users can define their own custom ckpt loading methods."""
LOAD_FUNC_SIG = None
LOAD_TYPE_FUNC = {}
@staticmethod
def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
if load_type.lower() in LOAD_TYPE_DICT:
# The ckpt load method implemented by internlm by default.
return LOAD_TYPE_DICT[load_type.lower()]
else:
# If it is a user-defined field, we do not do any conversion and represent it as a string.
return load_type
@staticmethod
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
logger.warning(f"{load_type} has aleady been registed!")
return
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
if load_type == CheckpointLoadType.INTERNLM:
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
logger.warning(
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
)
@staticmethod
def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]):
return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]
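A hypothetical registration sketch (the "hf" load type and the loader below are illustrative and not part of this commit; it assumes the internlm package is importable):

```python
from internlm.utils.model_checkpoint import CheckpointLoadMethod

def try_load_hf_ckpt(ckpt_mm, load_info, train_state):
    # Custom loading logic goes here; the function should keep the same signature as
    # the built-in internlm loader and return a string describing what was loaded.
    return "model, "

CheckpointLoadMethod.register_ckpt_load_type("hf", try_load_hf_ckpt)
# A config could then use: load_ckpt_info = dict(path=..., content=("model",), ckpt_type="hf")
```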
class CheckpointLoadMask:
"""
According to the content field in the incoming ckpt_info, decide which components to load.
"""
LOAD_CONTENT_DICT = {
"model": CheckpointLoadContent.MODEL,
"sampler": CheckpointLoadContent.SAMPLER,
"optimizer": CheckpointLoadContent.OPIMIZER,
"scheduler": CheckpointLoadContent.SCHEDULAER,
}
def __init__(self, content: tuple) -> None:
self.load_set = set(map(lambda x: x.lower(), content))
if "all" in self.load_set:
self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values())
else:
self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content))
def need_load(self, content: CheckpointLoadContent):
return content in self.load_set
def not_only_load(self, content: CheckpointLoadContent):
return content in self.load_set and len(self.load_set) > 1
def only_load(self, content: CheckpointLoadContent):
return set((content,)) == self.load_set
def __str__(self) -> str:
return f"{self.load_set}."
def __repr__(self) -> str:
return f"{self.load_set}."
def get_model_topology(model):
"""
Returns:
@ -55,6 +153,66 @@ def get_model_topology(model):
return topos
def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
load_content_str = ""
load_ckpt_folder = load_info["path"]
load_content: CheckpointLoadMask = load_info["content"]
if gpc.is_rank_for_log():
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
if load_content.need_load(CheckpointLoadContent.MODEL):
load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
load_content_str += f"{CheckpointLoadContent.MODEL}, "
if load_content.not_only_load(CheckpointLoadContent.MODEL):
# load training states.
load_context(load_ckpt_folder, train_state)
# load optimizer states.
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")
# load lr scheduler states.
if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
if ckpt_mm.lr_scheduler:
load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state)
load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
# load dataloader sampler states.
if load_content.need_load(CheckpointLoadContent.SAMPLER):
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler)
# track the actual updates of sampler when using weighted sampling
train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler)
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager skip reload 'batch_sampler'")
# reload data state dict.
if hasattr(train_state, "data_state_dict"):
ckpt_mm.train_dl.dataset.load_state_dict(
llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder
)
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
else:
if gpc.is_rank_for_log():
logger.warning(
"CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!"
)
return load_content_str
def save_model_checkpoint(folder, model):
"""
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@ -233,15 +391,16 @@ def load_sampler(ckpt_path: str, sampler):
torch.cuda.empty_cache()
def load_context(ckpt_path: str, train_dl, train_state: TrainState):
def load_context(ckpt_path: str, train_state: TrainState):
context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
train_state.load_state_dict(context_stuffs, train_dl)
train_state.load_state_dict(context_stuffs)
if gpc.is_rank_for_log():
logger.info(f"reload train_state:{train_state}")
torch.cuda.empty_cache()
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState):
learning_rate = train_state.lr
scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
logger.warning(
@ -270,7 +429,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
class CheckpointManager:
"""StorageManagerContext"""
def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
def __init__(
self,
ckpt_config,
model,
train_dl=None,
optimizer=None,
lr_scheduler=None,
model_config=None,
model_config_file=None,
feishu_address=None,
) -> None:
"""
CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
upload mode, you must call wait_async_upload_finish at the end of the program to wait
@ -283,22 +452,44 @@ class CheckpointManager:
lr_scheduler (object): lr_scheduler obj.
model_config (dict): model config.
"""
self.enable_save_ckpt = ckpt_config.enable_save_ckpt
self.checkpoint_every = ckpt_config.checkpoint_every
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.stop_file_path = ckpt_config.stop_file_path
self.load_model_only_folder = ckpt_config.load_model_only_folder
self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100)
self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
if self.save_ckpt_folder:
self.snapshot_ckpt_folder = get_config_value(
ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
)
self.async_upload_tmp_folder = get_config_value(
ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/"
)
else:
self.snapshot_ckpt_folder = None
self.async_upload_tmp_folder = None
self.async_upload = get_config_value(ckpt_config, "async_upload", False)
# initialization storage manager
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)
self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.load_optimizer = gpc.config.ckpt.load_optimizer
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dl = train_dl
self.model_config = model_config
self.model_config_file = model_config_file
# Register the default internlm ckpt load type.
self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
for ckpt_load_type in CheckpointLoadType:
CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
# Init the stop file used by quit_signal_handler.
if self.stop_file_path and gpc.get_global_rank() == 0:
dir_path = os.path.dirname(self.stop_file_path)
if dir_path != "" and not os.path.exists(dir_path):
@ -306,21 +497,35 @@ class CheckpointManager:
with open(self.stop_file_path, "w", encoding="utf-8") as f:
f.write("0")
if ckpt_config.load_given_ckpt is False:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
latest_ckpt_path = self.query_lastest_ckpt()
if latest_ckpt_path:
self.load_ckpt_folder = latest_ckpt_path
else:
# At this time, we have to load model init weights and train from step 0.
self.load_ckpt_folder = self.load_model_only_folder
else:
self.load_ckpt_folder = ckpt_config.load_ckpt_folder
self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
if self.load_ckpt_info is None: # (legacy): Try Compatible with old interfaces
self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
if self.stop_file_path is None:
logger.warning("no set stop_file_path, quit_signal_handler is disable")
# Auto-reload the latest checkpoint; this will override the 'load_ckpt_info' setting.
self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
if self.auto_resume is None: # (legacy): Try Compatible with old interfaces
self.auto_resume = auto_resume_sanity_check(ckpt_config)
if self.auto_resume:
self.load_ckpt_info = self.query_lastest_ckpt()
if self.stop_file_path is None and gpc.is_rank_for_log():
logger.warning("no set stop_file_path, quit_signal_handler is disable")
# convert to internal representation
if self.load_ckpt_info:
assert (
"path" in self.load_ckpt_info
and "content" in self.load_ckpt_info
and "ckpt_type" in self.load_ckpt_info
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
# replace load_ckpt
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
# test storage setting is ok.
if self.enable_save_ckpt:
self.try_ping_storage()
def quit_signal_handler(self, train_state) -> bool:
"""
@ -334,7 +539,7 @@ class CheckpointManager:
Returns:
bool: whether to quit.
"""
now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT
if self.stop_file_path is None:
return now_break, now_save_ckpt, save_type
@ -365,24 +570,29 @@ now step_count is {train_state.step_count}",
return now_break, now_save_ckpt, save_type
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return False
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
if save_ckpts is False:
save_ckpts = singal_save_ckpts
save_type = singal_save_type
return save_ckpts, save_type, now_break
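For illustration, a standalone sketch of the save cadence decided above (the constants and names are made up, and the real method additionally consults `quit_signal_handler`):

```python
def decide_save(step_count, checkpoint_every=4, oss_snapshot_freq=2):
    """Snapshots fire more often; a full checkpoint takes precedence when both are due."""
    save, kind = False, None
    if oss_snapshot_freq > 1 and step_count % oss_snapshot_freq == 0:
        save, kind = True, "snapshot"
    if step_count % checkpoint_every == 0:
        save, kind = True, "normal"
    return save, kind

print([decide_save(s) for s in range(1, 7)])
# [(False, None), (True, 'snapshot'), (False, None), (True, 'normal'), (False, None), (True, 'snapshot')]
```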
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return False
save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
self.storage_manager.wait()
if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
# Snapshot number, with only two snapshots written alternately.
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
@ -412,7 +622,7 @@ now step_count is {train_state.step_count}",
Tuple(str, int): path of the latest ckpt and its step; if not found, None is returned.
"""
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
if len(ckpt_list) == 0:
if ckpt_list is None or len(ckpt_list) == 0:
return None, None
max_normal_step = 0
@ -435,14 +645,16 @@ now step_count is {train_state.step_count}",
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
max_step_0, max_step_1 = 0, 0
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
if ckpt_list_1:
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
if ckpt_list_2:
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
snap_step = max(max_step_0, max_step_1)
@ -452,11 +664,12 @@ now step_count is {train_state.step_count}",
def query_latest_snapshot_step_local(self):
max_step, max_step_path = 0, None
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
for fn in files:
fn = fn.strip("/")
if fn.endswith(".step"):
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
# We assume that both internlm ckpt and snapshot ckpt will store the '.step' file
# as an integrity flag.
step = int(fn.rsplit(".", maxsplit=1)[0])
if max_step < step:
@ -466,99 +679,53 @@ now step_count is {train_state.step_count}",
return max_step_path, max_step
def query_lastest_ckpt(self):
latest_checkpoint = None
latest_ckpt, step = None, -1
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
if self.save_ckpt_folder.startswith("boto3"):
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
elif self.save_ckpt_folder.startswith("local"):
latest_checkpoint, step = self.query_latest_snapshot_step_local()
else:
latest_checkpoint, step = None, 0
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
if backend == "boto3":
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
latest_ckpt = ":".join(["boto3", latest_ckpt])
elif backend == "local":
latest_ckpt, step = self.query_latest_snapshot_step_local()
if latest_ckpt and not latest_ckpt.startswith("local:"):
latest_ckpt = ":".join(["local", latest_ckpt])
if latest_checkpoint is not None:
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
send_alert_message(
address=self.feishu_address,
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
)
else:
if gpc.is_rank_for_log():
send_alert_message(
address=self.feishu_address,
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
)
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")
return latest_checkpoint
return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm")
def try_load_model(self, current_time=""):
model_load_path = None
def try_resume_training(self, train_state: TrainState, current_time=""):
if self.load_ckpt_folder and self.load_model_only_folder:
raise ValueError(
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
)
if self.load_ckpt_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_ckpt_folder
elif self.load_model_only_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_model_only_folder
else:
if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None:
if gpc.is_rank_for_log():
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
else:
load_path = self.load_ckpt_info["path"]
load_content = self.load_ckpt_info["content"]
load_type = self.load_ckpt_info["ckpt_type"]
# Loading model weights must be done before zero is initialized.
if model_load_path is not None:
load_model_checkpoint(folder=model_load_path, model=self.model)
load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
load_content_str = load_func(self, self.load_ckpt_info, train_state)
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
"""Attempt to restore the training state of the last ckpt.
# If we only load model weights, we need to rewrite the zero optimizer's fp32 buffer.
if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer):
reload_zero_fp32_buff(self.optimizer)
Args:
lr_scheduler (_LRScheduler): lr_scheduler object.
optimizer (Optimizer): optimizer object.
lr (float): learning rate.
train_state (dict): training states.
train_dl (DataLoader): training dataloader object
"""
if self.load_ckpt_folder is not None:
# load optimizer states.
if self.load_optimizer:
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
# load lr scheduler states.
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
# load training states.
load_context(self.load_ckpt_folder, train_dl, train_state)
# load dataloader sampler states.
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
if hasattr(train_state, "data_state_dict"):
train_dl.dataset.load_state_dict(
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_info : {self.load_ckpt_info}")
logger.info(
f"===========Resume training from `{load_path}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if load_content_str:
logger.info(f"===========Load contents are: {load_content_str}")
def save_checkpoint(
self,
@ -600,8 +767,10 @@ set load_ckpt_folder or use default value \
)
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if scheduler:
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
@ -631,3 +800,12 @@ set load_ckpt_folder or use default value \
def set_save_folder(self, folder, step):
self.storage_manager.latest_save_folder = folder
self.storage_manager.latest_save_step = step
def try_ping_storage(self):
if gpc.get_global_rank() % 8 == 0:
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
self.storage_manager.save(test_fn, buff)
self.storage_manager.wait()
self.storage_manager.load(test_fn)
del buff

View File

@ -136,6 +136,22 @@ def compute_file_md5_by_chunk(file_name: str):
return hash_md5.hexdigest()
def try_get_storage_backend(path: str):
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
if path.startswith("s3:"):
backend = "boto3"
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' does not start with a backend prefix, assuming the boto3 backend.")
else:
backend = "local"
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' does not start with a backend prefix, assuming the local backend.")
return backend, sre[0]  # no backend prefix given, the whole string is the path
else:
return sre[0], sre[1] # (backend_prefix, splited_path)
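A quick sketch of the expected behaviour (folder names are placeholders; it assumes the internlm package is importable):

```python
from internlm.utils.storage_manager import try_get_storage_backend

try_get_storage_backend("boto3:s3://my-bucket/ckpts")  # -> ("boto3", "s3://my-bucket/ckpts")
try_get_storage_backend("local:/mnt/ckpts")            # -> ("local", "/mnt/ckpts")
# A prefix-less path such as "/mnt/ckpts" falls back to the local backend and logs a warning.
```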
class Boto3Client(StorageClient):
"""
Boto3Client
@ -231,21 +247,34 @@ class Boto3Client(StorageClient):
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
@staticmethod
def is_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
if "Contents" in re:
return len(list(re["Contents"])) > 0
else:
return False
@staticmethod
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
"""
Ref: https://stackoverflow.com/questions/54314563/
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
"""
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
if Boto3Client.is_fp_exists(handler, bucket_name, fp, None):
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
@ -297,9 +326,12 @@ class LocalClient(StorageClient):
@staticmethod
def get_fns(handler, folder):
assert isinstance(handler, LocalClient)
assert os.path.exists(folder), f"folder '{folder}' not exists!"
fns = os.listdir(folder)
return fns
if not os.path.exists(folder):
if gpc.is_rank_for_log():
logger.warning(f"'{folder}' not found!")
return None
else:
return os.listdir(folder)
@staticmethod
def delete_obj(handler, fp: str):
@ -436,10 +468,7 @@ class StorageManager(metaclass=SingletonMeta):
Args:
path (str): _description_
"""
try:
backend, path = path.split(":", maxsplit=1)
except Exception as exc:
raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc
backend, path = try_get_storage_backend(path)
init_args = (None,)
if backend == "local":
@ -594,23 +623,24 @@ class StorageManager(metaclass=SingletonMeta):
if gpc.is_rank_for_log():
self.upload_count += 1
if self.async_mode:
if self.async_mode and self.latest_save_folder:
self.save(
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
saved_obj=dict({"step": self.latest_save_step}),
async_upload=False,
)
self.latest_save_folder = None
storage_manager: StorageManager = None
def init_storage_manager(ckpt_config):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
global storage_manager
storage_manager = StorageManager(
ckpt_config.enable_save_ckpt,
tmp_local_folder=ckpt_config.async_upload_tmp_folder,
async_mode=ckpt_config.async_upload,
enable_save_ckpt,
tmp_local_folder=async_upload_tmp_folder,
async_mode=async_upload,
)

View File

@ -0,0 +1,143 @@
import os
import pytest
import torch
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
# 1B
init_config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
adam=dict(
lr=1e-4,
),
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
checkpoint=False,
num_attention_heads=2,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
resume_tb_folder="",
tensorboard_folder="",
)
)
def init_naive_model():
# let MODEL_INITIALIZER work (the import registers the model class)
import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.registry import MODEL_INITIALIZER
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model))
model = NaiveAMPModel(
model=model,
output_to_fp32=False,
dtype=torch.bfloat16,
sync_buffer=False,
)
return model
def init_naive_optim(model):
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": 0.01}],
lr=1e-4,
betas=(0.9, 0.95),
eps=1e-8,
)
return naive_optimizer
def init_hybrid_optim(model):
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": 0.01}],
lr=1e-4,
betas=(0.9, 0.95),
eps=1e-8,
)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=Config(
dict(
fp16=dict(
initial_scale=2**16,
min_scale=1,
growth_interval=1000,
),
growth_factor=2,
backoff_factor=0.5,
max_scale=2**24,
hysteresis=2,
)
),
zero_cfg=Config(
dict(
overlap_sync_grad=False,
overlap_sync_param=False,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
)
),
param_bcast_sync_handler=None,
)
return optimizer
@pytest.fixture(autouse=True, scope="function")
def reset_singletons():
SingletonMeta._instances = {}
def reset_seed():
from internlm.core.context.random import _SEED_MANAGER
_SEED_MANAGER.reset()
@pytest.fixture(scope="module")
def init_dist_and_model():
from internlm.initialize import initialize_distributed_env
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
# setup
print("set up", flush=True)
model = init_naive_model()
# opim = init_naive_optim(model)
opim = init_hybrid_optim(model)
yield model, opim
# teardown
del model, opim
print("teardown", flush=True)
gpc.destroy()
reset_seed()
def enter_flag(text):
print(f"{text} begin!", flush=True)
yield
print(f"{text} end!", flush=True)

View File

@ -0,0 +1,278 @@
import os
import shutil
from subprocess import PIPE, STDOUT, Popen
import pytest
import torch
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.storage_manager import wait_async_upload_finish
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
init_dist_and_model,
reset_singletons,
)
TOTAL_STEP = 6
CKPT_EVERY = 4
SNPASHOT_EVERY = 2
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
OSS_IP = os.environ["OSS_IP"]
USER = os.environ["USER"]
JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt"
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
ASYNC_TMP_FOLDER = "./async_tmp_folder"
def del_tmp_file():
try:
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
except FileNotFoundError:
pass
try:
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
except FileNotFoundError:
pass
try:
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
results, presults = "", ""
for line in iter(output.stdout.readline, b""):
results += str(line.rstrip())
presults += line.rstrip().decode() + "\n"
print(presults, flush=True)
except FileNotFoundError:
pass
ckpt_config_list = [
# Old interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=BOTO_SAVE_PATH,
load_optimizer=True,
checkpoint_every=CKPT_EVERY,
async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_model_only_folder=None,
load_given_ckpt=False,
load_ckpt_folder=None,
is_old_api=True,
),
# Old interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=LOCAL_SAVE_PATH,
load_optimizer=True,
checkpoint_every=CKPT_EVERY,
async_upload=False,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_model_only_folder=None,
load_given_ckpt=False,
load_ckpt_folder=None,
is_old_api=True,
),
# New interface format
dict(
enable_save_ckpt=True,
save_ckpt_folder=BOTO_SAVE_PATH,
checkpoint_every=CKPT_EVERY,
async_upload=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
is_old_api=False,
auto_resume=True,
),
dict(
enable_save_ckpt=True,
save_ckpt_folder=LOCAL_SAVE_PATH,
checkpoint_every=CKPT_EVERY,
async_upload=False,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
oss_snapshot_freq=SNPASHOT_EVERY,
stop_file_path=None,
load_ckpt_folder=None,
is_old_api=False,
auto_resume=True,
),
]
def overwrite_optim_state(optim, set_value):
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
p.data.fill_(set_value)
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
)
fp16_p.fill_(set_value)
else:
for group in optim.param_groups:
for p in group["params"]:
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
p.data.fill_(set_value)
def compare_optim_state(optim1, optim2):
re = True
if isinstance(optim1, HybridZeroOptimizer):
fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
re &= group_id_1 == group_id_2
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
else:
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
for p1, p2 in zip(group1["params"], group2["params"]):
re &= torch.equal(p1, p2)
return re
def compare_optim_value(optim, value):
re = True
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
)
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
else:
for group in optim.param_groups:
for p in group["params"]:
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
return re
def overwrite_model_value(model, value):
for p in model.parameters():
# p.copy_(torch.full_like(p, value, dtype=p.dtype))
p.data.fill_(value)
def compare_model_value(model, value):
re = True
for p in model.parameters():
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
return re
@pytest.fixture(scope="function")
def del_tmp():
del_tmp_file()
yield
del_tmp_file()
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
ckpt_config = Config(ckpt_config)
assert ckpt_config.checkpoint_every < TOTAL_STEP
assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
model, opim = init_dist_and_model
train_state = TrainState(gpc.config, None)
if isinstance(opim, HybridZeroOptimizer):
print("Is HybridZeroOptimizer!", flush=True)
else:
print("Is naive Adam!", flush=True)
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
latest_ckpt_step = None
for i in range(TOTAL_STEP + 1):
overwrite_model_value(model, i)
overwrite_optim_state(opim, i)
train_state.batch_count = i
train_state.step_count += 1
save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state)
if save_ckpts:
latest_ckpt_step = i
ckpt_mm.try_save_checkpoint(train_state)
wait_async_upload_finish()
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
assert latest_ckpt_info is not None
latest_ckpt = latest_ckpt_info["path"]
if ckpt_mm.save_ckpt_folder.startswith("local"):
assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
else:
assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
del ckpt_mm
SingletonMeta._instances = {}
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
ckpt_mm.try_resume_training(train_state)
assert latest_ckpt_step == 5
assert train_state.step_count == 6
assert train_state.batch_count == 6
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
if ckpt_mm.save_ckpt_folder.startswith("local:"):
ckpt_mm.load_ckpt_info = dict(
path=os.path.join(LOCAL_SAVE_PATH, "4"),
content=CheckpointLoadMask(("all",)),
ckpt_type=CheckpointLoadType.INTERNLM,
)
else:
ckpt_mm.load_ckpt_info = dict(
path=os.path.join(BOTO_SAVE_PATH, "4"),
content=CheckpointLoadMask(("all",)),
ckpt_type=CheckpointLoadType.INTERNLM,
)
ckpt_mm.try_resume_training(train_state)
assert train_state.step_count == 4
assert train_state.batch_count == 4
assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
ckpt_config = Config(ckpt_config)
model, opim = init_dist_and_model
SingletonMeta._instances = {}
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
ckpt_mm.try_ping_storage()
if __name__ == "__main__":
pytest.main()

View File

@ -0,0 +1,26 @@
import pytest
from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
BOTO_SAVE_PATH,
TOTAL_STEP,
ckpt_config_list,
del_tmp_file,
init_dist_and_model,
reset_singletons,
)
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_storage_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
from internlm.utils.storage_manager import get_storage_manager, init_storage_manager
ckpt_config = Config(ckpt_config)
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
async_upload = get_config_value(ckpt_config, "async_upload", False)
init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
get_storage_manager()

View File

@ -35,7 +35,6 @@ from internlm.utils.common import (
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@ -73,7 +72,6 @@ def main(args):
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
get_tflops_func = partial(
get_megatron_flops,
@ -96,21 +94,11 @@ def main(args):
# initialize customed llm logger
uniscale_logger = initialize_llm_logger(start_time=current_time)
# initialize and resume train state
train_state = TrainState(gpc.config)
# initialize model
model = initialize_model()
with open(args.config, "r") as f:
config_lines = f.readlines()
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
)
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@ -118,15 +106,25 @@ def main(args):
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=4)
val_dls = get_validation_data_loader()
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized.
ckpt_manager.try_load_model(current_time)
# initialize and resume train state
train_state = TrainState(gpc.config, train_dl.batch_sampler)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_dl=train_dl,
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.alert_address,
)
# Loading other persistent training states.
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
ckpt_manager.try_resume_training(train_state, current_time)
# initialize customed llm writer
writer = Writer(
@ -197,8 +195,6 @@ def main(args):
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
bench_gpu()
bench_net()
start_time = time.time()
timer("one-batch").start()