mirror of https://github.com/InternLM/InternLM

fix(fsdp): fix conflicts
commit aedd88e5a7
@@ -1,7 +1,8 @@
 JOB_NAME = "7b_train"
+DO_ALERT = False
 
-SEQ_LEN = 2048
-HIDDEN_SIZE = 4096
+SEQ_LEN = 256
+HIDDEN_SIZE = 512
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
 NUM_LAYER = 32
@@ -22,14 +23,22 @@ CHECKPOINT_EVERY = 20
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
-    load_given_ckpt = False,
+    # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training (load weights and scheduler/context states).
     # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
     load_optimizer=True,  # Whether to load optimizer states when continuing training.
+
+    # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
+    load_ckpt_folder="local:llm_ckpts/",
+    # 'load_ckpt_info' setting guide:
+    # 1. the 'path' indicates the ckpt path,
+    # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
+    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
+
     checkpoint_every=CHECKPOINT_EVERY,
     async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
     snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage path.
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
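Note on this hunk: the new `load_ckpt_info` dict bundles what used to be three separate fields. A minimal sketch of how a consumer might dispatch on its `content` tuple is below; only the dict shape comes from the config above, the function and its print stubs are invented for illustration.

```python
# Hedged sketch (not InternLM's implementation) of consuming a
# load_ckpt_info-style dict; the print calls stand in for real loaders.
def apply_load_ckpt_info(info):
    if info is None:
        return  # nothing to load, fresh start
    assert info["ckpt_type"] == "internlm", f"unsupported ckpt_type: {info['ckpt_type']}"
    content = set(info["content"])
    if "all" in content:
        content = {"model", "sampler", "optimizer", "scheduler"}
    for state in ("model", "sampler", "optimizer", "scheduler"):
        if state in content:
            print(f"would load '{state}' state from {info['path']}")

apply_load_ckpt_info(dict(path="local:llm_ckpts/", content=("model",), ckpt_type="internlm"))
```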
@@ -46,7 +55,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=30,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -145,8 +154,17 @@ parallel = dict(
     pipeline=dict(size=1, interleaved_overlap=True),
     tensor=1,
     sequence_parallel=False,
-    use_fsdp = True,
+    use_fsdp=True,
 )
 
 cudnn_deterministic = False
 cudnn_benchmark = False
+
+monitor = dict(
+    # feishu alert configs
+    alert=dict(
+        enable_feishu_alert=DO_ALERT,
+        feishu_alert_address=None,  # feishu webhook to send alert message
+        light_monitor_address=None,  # light_monitor address to send heartbeat
+    ),
+)
@@ -112,19 +112,19 @@ If you want to load a model checkpoint when starting the training, you can confi
 
 ```python
 SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
 MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
 LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
 ckpt = dict(
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save the model and optimizer checkpoints
     checkpoint_every=float("inf"),  # Save a checkpoint every specified number of steps, default value is inf
-    load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step
-    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step
-    load_optimizer=True,  # Whether to load optimizer weights when resuming training, default value is True
+    # When resuming training from a checkpoint:
+    # (1) 'path' is the path of the loaded checkpoint.
+    # (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+    # (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm"
+    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
 )
 ```
 
 Note:
-- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
 - If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS.
 
 The configuration for the model is as follows:
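Side note on the `local:`/`boto3:` convention in the hunk above: a small sketch of how such a prefix could be split off a path is shown here. This is illustrative only; InternLM's actual parsing lives in its storage manager, and the bare-path fallback is an assumption.

```python
# Illustrative only: split a "backend:path" checkpoint locator as described
# in the note above ("local:" = local filesystem, "boto3:" = remote OSS).
def split_ckpt_path(path: str):
    backend, sep, rest = path.partition(":")
    if sep and backend in ("local", "boto3"):
        return backend, rest
    return "local", path  # assumption: bare paths mean the local filesystem

print(split_ckpt_path("local:/path/to/save/ckpt"))  # ('local', '/path/to/save/ckpt')
print(split_ckpt_path("boto3:bucket/ckpt"))         # ('boto3', 'bucket/ckpt')
```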
@@ -101,18 +101,17 @@ data = dict(
 If you want to load a model `checkpoint` when starting the training, you can configure it as follows:
 ```python
 SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
 MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
 LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
 ckpt = dict(
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save the model and optimizer checkpoints
     checkpoint_every=float("inf"),  # Save a checkpoint every specified number of steps, default value is inf
-    load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to load the initial model weights; only model weights are loaded, not optimizer weights, and training starts from the first step
-    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Path to load the model and optimizer weights when resuming training; training resumes from the specified step
-    load_optimizer=True,  # Whether to load optimizer weights when resuming training, default value is True
+    # When resuming training, this is the path for loading the model and optimizer weights; training resumes from the specified step
+    # 'content' indicates which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+    # 'ckpt_type' indicates the type of model to be loaded, currently supported: "internlm"
+    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
 )
 ```
 Note:
-- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time
 - If the path starts with `local:`, the file is stored on the local file system; if it starts with `boto3:`, it is stored on remote OSS
 
 The key model parameters are configured as follows:
@@ -18,6 +18,7 @@ import torch.distributed as dist
 
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 
 from . import process_group_initializer as pgroup_initializer
 from .process_group_initializer import ParallelMode
@@ -36,7 +37,7 @@ class Config(dict):
         config (dict): The dict object to be wrapped.
     """
 
-    def __init__(self, config: dict = None):
+    def __init__(self, config: dict = None):  # pylint: disable=W0231
         if config is not None:
             for k, v in config.items():
                 self._add_item(k, v)
@@ -100,7 +101,7 @@ class Config(dict):
 
         module_name = filepath.stem
         source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
-        module = source_file.load_module()  # pylint: disable=W4902,E1120
+        module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505
 
         # load into config
         config = Config()
@@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
         """
         # initialize the default process group
         init_method = f"tcp://[{host}]:{port}"
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
+        dist.init_process_group(
+            rank=rank,
+            world_size=world_size,
+            backend=backend,
+            init_method=init_method,
+            timeout=LLM_NCCL_TIMEOUT,
+        )
 
         # None will give the default global process group for pytorch dist operations
         ranks = list(range(world_size))
         if use_cpu:
-            cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
+            cpu_group = (
+                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                if dist.get_backend() != "gloo"
+                else None
+            )
         else:
             cpu_group = None
         self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
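`LLM_NCCL_TIMEOUT` is imported from `internlm.utils.timeout` but its definition is not part of this diff. `torch.distributed` expects the `timeout` argument to be a `datetime.timedelta`; a plausible shape, where both the 30-minute default and the `NCCL_TIMEOUT` environment override are assumptions rather than confirmed source, would be:

```python
# Plausible shape of the LLM_NCCL_TIMEOUT constant (not shown in this diff);
# the default value and env-var name are assumptions for illustration.
import datetime
import os

LLM_NCCL_TIMEOUT = datetime.timedelta(minutes=int(os.getenv("NCCL_TIMEOUT", "30")))
```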
@@ -528,6 +539,7 @@ class ParallelContext(metaclass=SingletonMeta):
         if dpseed_with_tpoffset:
             dp_seed = seed + pipeline_offset * 1024
         add_seed(ParallelMode.DATA, dp_seed)
+        add_seed(ParallelMode.DUMMY, dp_seed)
 
         # model parallel seeds are different across ranks
         if self.is_initialized(ParallelMode.TENSOR):
@@ -535,7 +547,11 @@ class ParallelContext(metaclass=SingletonMeta):
             tp_seed = seed + tp_rank + pipeline_offset * 1024
             add_seed(ParallelMode.TENSOR, tp_seed)
 
-        set_mode(ParallelMode.DATA)
+        # we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode
+        # during model construction), this is because the random state will be different in different tensor parallel
+        # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
+        # additional random operations during the RowParallelLinear module building process.
+        set_mode(ParallelMode.DUMMY)
 
         seeds = get_seeds()
         seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
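The DUMMY mode exists because tp_rank 0 draws extra random numbers while building `RowParallelLinear`, as the new comment explains: construction happens under a throwaway RNG stream so the DATA stream stays aligned across tensor-parallel ranks. A toy model of per-mode RNG bookkeeping (independent of InternLM's real `add_seed`/`set_mode`, which also swap CUDA RNG states) is:

```python
# Toy model of per-mode RNG state switching; InternLM's real implementation
# additionally manages torch.cuda RNG states, which is omitted here.
import random

_states, _mode = {}, None

def add_seed(mode, seed):
    _states[mode] = random.Random(seed)

def set_mode(mode):
    global _mode
    _mode = mode

def draw():
    return _states[_mode].random()

add_seed("data", 1024)
add_seed("dummy", 1024)
set_mode("dummy")   # model construction happens under the throwaway mode
_ = draw()          # extra draws here do not disturb the "data" stream
set_mode("data")
print(draw())       # identical across ranks of the same dp group
```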
@@ -9,6 +9,8 @@ from enum import Enum
 
 import torch.distributed as dist
 
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
+
 
 # parallel modes
 class ParallelMode(Enum):
@@ -40,6 +42,9 @@ class ParallelMode(Enum):
     # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp
     ZERO3_DP = "zero3_dp"
 
+    # dummy mode, only used during mode construction
+    DUMMY = "dummy"
+
 
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -111,9 +116,13 @@ class Initializer_Data(ProcessGroupInitializer):
 
         for i in range(self.rank_num_per_dp_group):
             ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
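The same `timeout=` addition repeats in every initializer below. `torch.distributed.new_group` does accept a `timeout` keyword; a minimal standalone demonstration (single-process `gloo` group, so it runs without GPUs) might look like this:

```python
# Minimal runnable demo of new_group with an explicit timeout, using a
# single-process gloo group; InternLM passes LLM_NCCL_TIMEOUT the same way.
import datetime
import torch.distributed as dist

def main():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(minutes=30),
    )
    group = dist.new_group(ranks=[0], timeout=datetime.timedelta(minutes=30))
    print("group created:", group is not None)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```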
@@ -163,9 +172,13 @@ class Initializer_Model(ProcessGroupInitializer):
 
         for i in range(self.num_group):
             ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -223,9 +236,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                     )
                 )
                 pipe_group_size = len(ranks)
-                pipe_group = dist.new_group(ranks)
+                pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                 if use_cpu:
-                    group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else pipe_group
+                    )
                 else:
                     group_cpu = None
 
@@ -273,9 +290,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
 
         for i in range(self.num_tensor_parallel_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -329,9 +350,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
                     i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
                     for k in range(self.zero1_parallel_size)
                 ]
-                group = dist.new_group(ranks)
+                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                 if use_cpu:
-                    group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else group
+                    )
                 else:
                     group_cpu = None
 
@@ -378,9 +403,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
                 rank = i * self.nettest_parallel_size + j
                 if rank < self.world_size:
                     ranks.append(rank)
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -9,6 +9,7 @@ import torch
 
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
+from internlm.utils.timeout import llm_timeout
 
 from .base_scheduler import BaseScheduler, SchedulerHook
 
@@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
 
         return output, loss
 
+    @llm_timeout(func_name="nopp_forward_backward_step")
     def forward_backward_step(
         self,
         engine: Engine,
@@ -15,6 +15,7 @@ from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout
 
 from .base_scheduler import BaseScheduler, SchedulerHook
 
@@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):
 
         return output, label, accum_loss
 
+    @llm_timeout(func_name="nointerleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -1247,6 +1249,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         # 3. Cooldown
         self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
 
+    @llm_timeout(func_name="interleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.
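`llm_timeout` (imported above from `internlm.utils.timeout`) wraps the schedule steps so a hung collective fails loudly instead of stalling forever. Its implementation is not part of this diff; a simplified stand-in with the same call shape, assuming a SIGALRM-based alarm (POSIX-only, main-thread-only), could be:

```python
# Simplified stand-in for llm_timeout (the real one is not shown here).
# SIGALRM is an assumption about mechanism, not InternLM's actual code.
import functools
import signal

def llm_timeout(seconds=1800, func_name=""):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def on_timeout(signum, frame):
                raise TimeoutError(f"{func_name or func.__name__} timed out after {seconds}s")
            old = signal.signal(signal.SIGALRM, on_timeout)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old)
        return wrapper
    return decorator
```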
@@ -23,7 +23,15 @@ class TrainState:
         train_dl (DataLoader): The DataLoader object used for training.
     """
 
-    def __init__(self, config) -> None:
+    def __init__(self, config, batch_sampler) -> None:
+        """
+        Args:
+            config (Config): internlm config
+            batch_sampler (torch.utils.data.Sampler): Because dataloader loading is
+            asynchronous and prefetched, the batch_sampler state maintained inside the
+            dataloader runs ahead of the actual training progress, so we copy the
+            batch_sampler as the anchor point for ckpt reload.
+        """
         # The number of batches produced by the data iterator
         self.batch_count: int = 0
         # Used to store the number of samples consumed in the current epoch
@@ -43,9 +51,20 @@ class TrainState:
 
         self.tensorboard_folder = config.tensorboard_folder
 
-    def init_batch_sampler(self, train_dl):
-        # Copy of the batch sampler from the DataLoader
-        self.batch_sampler = train_dl.batch_sampler.copy()
+        # learning rate
+        self.lr = config.adam.lr
+
+        # sampler state
+        if batch_sampler:
+            self.init_batch_sampler(batch_sampler)
+
+    def init_batch_sampler(self, batch_sampler):
+        """
+        Args:
+            batch_sampler (torch.utils.data.Sampler): sampler.
+        """
+        # make a copy of batch_sampler.
+        self.batch_sampler = batch_sampler.copy()
         # Iterator for the batch sampler
         self.batch_sampler_iter = iter(self.batch_sampler)
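Why copy the sampler? The dataloader's prefetching advances the live `batch_sampler` several batches ahead of the optimizer step, so checkpointing the live object would over-count progress. A toy illustration of the drift, assuming only that the sampler exposes a `copy()` method as used above:

```python
# Toy illustration: a prefetching loader advances the live sampler ahead of
# training; the copied sampler tracks real progress for checkpointing.
import copy

class CountingSampler:
    def __init__(self):
        self.batches_handed_out = 0
    def copy(self):
        return copy.deepcopy(self)
    def next_batch(self):
        self.batches_handed_out += 1

live = CountingSampler()
anchor = live.copy()          # what TrainState keeps
for _ in range(4):            # loader prefetches 4 batches ahead
    live.next_batch()
anchor.next_batch()           # training actually consumed only 1
print(live.batches_handed_out, anchor.batches_handed_out)  # 4 1
```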
@@ -61,26 +80,22 @@ class TrainState:
 
         return json.dumps(info, indent=4, sort_keys=True)
 
-    def load_state_dict(self, other_stuffs, train_dl):
+    def load_state_dict(self, other_stuffs):
         """
         Resumes training from a checkpoint.
 
         Args:
             other_stuffs (dict): Other information needed to resume training.
-            train_dl (DataLoader): The DataLoader object used for training.
         """
-
-        self.batch_count = other_stuffs["batch_count"] + 1  # here you need to shift a batch backward
         self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
         self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
         self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
-        # compatible with previous checkpoints without this parameter
-        self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
-
-        # track the actual updates of sampler when using weighted sampling
-        if hasattr(self, "batch_sampler"):
-            self.batch_sampler = train_dl.batch_sampler.copy()
-            self.batch_sampler_iter = iter(self.batch_sampler)
+        # Because the ckpt save occurs after updating 'step_count',
+        # there is no need to increment 'step_count' here (does our step count start from 0?).
+        # However, 'batch_count' is updated before ckpt storage, so it needs to be incremented by 1 on resume.
+        self.batch_count = other_stuffs["batch_count"] + 1  # here you need to shift a batch backward
+        self.step_count = other_stuffs.get("step_count", self.batch_count)
 
         # resume tensorboard from older tensorboard_folder
         self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
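The off-by-one asymmetry in the new code is worth spelling out: per the comments, `batch_count` is incremented before a checkpoint is written while `step_count` is incremented before it, so only `batch_count` gets the `+ 1` shift on resume. A worked example under that reading:

```python
# Worked example of the resume arithmetic in load_state_dict above.
saved = {"batch_count": 99, "step_count": 100}       # state written at save time

batch_count = saved["batch_count"] + 1               # next batch to produce: 100
step_count = saved.get("step_count", batch_count)    # already post-increment: 100
print(batch_count, step_count)                       # 100 100
```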
@@ -10,9 +10,10 @@ import torch
 
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
+from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
-from internlm.utils.storage_manager import init_storage_manager
+from internlm.utils.timeout import llm_timeout
 
 logger = get_logger(__file__)
 
@@ -122,7 +123,7 @@ def args_sanity_check():
     # processing the checkpoint config
     ckpt = gpc.config.ckpt
     if "enable_save_ckpt" not in ckpt:
-        ckpt._add_item("enable_save_ckpt", False)
+        ckpt._add_item("enable_save_ckpt", True)
 
     # Saving checkpoint args.
     if ckpt.enable_save_ckpt:
@@ -148,9 +149,6 @@ def args_sanity_check():
         if not ckpt.async_upload:
             ckpt._add_item("async_upload_tmp_folder", None)
 
-        if "snapshot_ckpt_folder" not in ckpt:
-            ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
-
         if "oss_snapshot_freq" not in ckpt:
             ckpt._add_item("oss_snapshot_freq", float("inf"))  # if oss_snapshot_freq not given, we disable.
     else:
@@ -160,44 +158,23 @@ def args_sanity_check():
         ckpt._add_item("async_upload", False)
         ckpt._add_item("async_upload_tmp_folder", None)
-        ckpt._add_item("snapshot_ckpt_folder", None)
         ckpt._add_item("snapshot_ckpt_folder", None)
 
-    # Loading checkpoint args.
-    if "load_model_only_folder" not in ckpt:
-        ckpt._add_item("load_model_only_folder", None)
-
     if "load_ckpt_folder" not in ckpt:
         ckpt._add_item("load_ckpt_folder", None)
 
     if "load_optimizer" not in ckpt:
         ckpt._add_item("load_optimizer", True)
 
+    if "stop_file_path" not in ckpt:
+        ckpt._add_item("stop_file_path", None)
+
-    if "load_given_ckpt" not in ckpt:
-        # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
+    if "auto_resume" not in ckpt:
+        # If 'auto_resume' is not given, we set it to True, so internlm can have opportunity
         # to auto-load latest checkpoint.
-        ckpt._add_item("load_given_ckpt", False)
-
-    if ckpt.load_given_ckpt:
-        # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
-        if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
-            logger.warning(
-                "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
-and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
-            )
-            ckpt.load_model_only_folder = None
+        ckpt._add_item("auto_resume", True)
 
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
         logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
         logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
         logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
-        logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
-
-    # initialization storage manager
-    init_storage_manager(ckpt)
 
     # tensorboard writer config
     if "enable_tb" not in gpc.config:
@@ -288,9 +265,22 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
             gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"
 
-    # feishu webhook address for alerting
-    if "alert_address" not in gpc.config:
-        gpc.config._add_item("alert_address", None)
+    # monitoring default config
+    monitor_default_config = {
+        "alert_address": None,  # compatible with old alert config
+        "monitor": {  # new monitoring config
+            "alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
+        },
+    }
+
+    for key, value in monitor_default_config.items():
+        if key not in gpc.config:
+            gpc.config._add_item(key, value)
+
+    alert = gpc.config.monitor.alert
+
+    if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
+        logger.warning("alert is enabled but alert_address is not set")
 
     optim_ckpt = gpc.config.hybrid_zero_optimizer
     if "zero_overlap_communication" in optim_ckpt:
@@ -437,6 +427,7 @@ def launch_from_torch(
     )
 
 
+@llm_timeout(func_name="initialize_distributed_env")
 def initialize_distributed_env(
     config: str,
     launcher: str = "slurm",
@@ -470,3 +461,20 @@ def initialize_distributed_env(
 
     if args_check:
         args_sanity_check()
+
+    # init light monitor client
+    alert_config = gpc.config.monitor.alert
+    if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
+        light_monitor_address = alert_config.light_monitor_address
+        if light_monitor_address:
+            initialize_light_monitor(light_monitor_address)
+        else:
+            logger.warning("monitor address is none, monitor could not be used!")
+
+
+def get_config_value(config, key, default):
+    try:
+        value = config[key]
+    except KeyError:
+        value = default
+    return value
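`get_config_value` is just a KeyError-tolerant lookup, which matters because these config objects otherwise raise on missing keys. Usage against the definition above is straightforward:

```python
# Usage of get_config_value as defined above: fall back to a default
# instead of raising when the key is absent.
ckpt_config = {"load_ckpt_folder": "local:llm_ckpts/"}
print(get_config_value(ckpt_config, "load_ckpt_folder", None))  # local:llm_ckpts/
print(get_config_value(ckpt_config, "load_optimizer", True))    # True (the default)
```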
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from internlm.initialize.launch import get_config_value
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
+
+
+def auto_resume_sanity_check(ckpt_config):
+    load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
+    if load_given_ckpt is None:
+        return True  # default value is True
+    else:
+        return not load_given_ckpt
+
+
+def ckpt_info_sanity_check(ckpt_config):
+    load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
+
+    load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
+
+    if load_model_only_folder is not None:
+        assert load_ckpt_folder is None, (
+            "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, "
+            "and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
+        )
+        return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
+    else:
+        load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
+
+        if isinstance(load_ckpt_folder, str):
+            if load_optimizer:
+                return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
+            else:
+                return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
+        elif load_ckpt_folder is None:
+            return None
+        else:
+            raise TypeError(f"Unsupported data type '{type(load_ckpt_folder)}' for config.ckpt arg 'load_ckpt_folder'")
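Putting the two new helpers together: for a legacy config that only sets `load_ckpt_folder`, `ckpt_info_sanity_check` synthesizes the equivalent `load_ckpt_info` dict, and `auto_resume_sanity_check` maps the old `load_given_ckpt` flag onto the new auto-resume semantics. For example:

```python
# Example: converting a legacy ckpt config to the new load_ckpt_info form
# using the two sanity-check helpers defined above.
legacy = {"load_ckpt_folder": "local:llm_ckpts/", "load_optimizer": True}
print(ckpt_info_sanity_check(legacy))
# {'path': 'local:llm_ckpts/', 'content': ('model', 'sampler', 'optimizer'), 'ckpt_type': 'internlm'}
print(auto_resume_sanity_check(legacy))  # True: load_given_ckpt unset -> auto-resume
```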
@@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.utils import fused_dense_func_torch
 
@@ -195,12 +195,6 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        # need to assign tp attribute so that colossalai know it is tensor parallel module
-
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for name in ["w1", "w2", "w3"]:
-                for param in getattr(self, name).parameters():
-                    setattr(param, IS_TENSOR_PARALLEL, True)
 
     def forward(self, x):
         out = self.w3(F.silu(self.w1(x)) * self.w2(x))
@@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module):
                 device=device,
                 dtype=dtype,
             )
+        for _, param in self.mlp.named_parameters():
+            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+                setattr(param, IS_TENSOR_PARALLEL, True)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -1,4 +1,11 @@
+from .alert import initialize_light_monitor, send_heartbeat
 from .monitor import initialize_monitor_manager, send_alert_message
 from .utils import set_env_var
 
-__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
+__all__ = [
+    "send_alert_message",
+    "initialize_monitor_manager",
+    "set_env_var",
+    "initialize_light_monitor",
+    "send_heartbeat",
+]
@@ -1,8 +1,59 @@
 import json
+import math
+import os
+import re
 import time
+from typing import Dict
 
 import requests
 
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
+
+
+def initialize_light_monitor(monitor_address: str = None):
+    try:
+        from uniscale_monitoring import init_monitor
+
+        init_monitor(monitor_address)
+    except Exception as e:
+        logger.warning(f"init monitor met an error: {e}")
+
+
+def send_heartbeat(msg_type: str, msg: Dict):
+    def nan2none(v):
+        if isinstance(v, float) and math.isnan(v):
+            return None
+        return v
+
+    try:
+        from uniscale_monitoring import send_meta
+
+        data = {}
+        for k, v in msg.items():
+            if isinstance(v, Dict):
+                for k1, v1 in v.items():
+                    new_k = f"{k}_{k1}".split(" ")[0]
+                    new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
+                    data[new_k] = nan2none(v1)
+            else:
+                new_k = k.split(" ")[0]
+                new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
+                data[new_k] = nan2none(v)
+
+        if os.getenv("CLUSTER_NAME"):
+            data.update({"cluster": os.getenv("CLUSTER_NAME")})
+        if msg_type == "train_metrics":
+            data.update({"msg_type": "train_metrics"})
+        elif msg_type == "init_time":
+            data.update({"msg_type": "init_time"})
+        elif msg_type == "stage_time":
+            data.update({"msg_type": "stage_time"})
+        send_meta(data, timeout=0.1)
+    except Exception as e:
+        logger.warning(f"send heartbeat met an error: {e}")
+
+
 def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
     """
@@ -226,9 +226,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
             send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
             yield
         finally:
-            send_alert_message(
-                address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
-            )
+            send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
             monitor_manager.stop_monitor()
     else:
         yield
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
+from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
 
-__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
@@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.timeout import llm_timeout
 
 from .utils import compute_norm
 
@@ -329,6 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
         self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._bucket_in_progress = []
 
         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
@@ -338,6 +340,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
 
+        self._comm_bcast_stream = torch.cuda.Stream()
+
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,
@@ -436,13 +440,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False
 
-        # initialize communication stream for
-        # communication-computation overlapping
-        if self._overlap_sync_grad:
-            self._comm_stream = torch.cuda.Stream()
-        else:
-            self._comm_stream = torch.cuda.current_stream()
-
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
         if self._overlap_sync_grad:
@@ -588,34 +585,41 @@ class HybridZeroOptimizer(BaseOptimizer):
 
     def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
         grad_buckets_by_dtype = split_half_float_double(grads)
-
+        next_bucket_list = []
         # add parameters into bucket for reduction
         for tensor_list in grad_buckets_by_dtype:
             param_bucket = TensorBucket(size=bucket_size)
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
-                if param_bucket.is_full_or_oversized():
-                    self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-                    param_bucket.empty()
             if not param_bucket.is_empty():
                 self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+            next_bucket_list.append(param_bucket)
+
+        # wait for the completion of previous bucket list reduction, and do unflatten_and_copy()
+        # here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
+        for bucket in self._bucket_in_progress:
+            bucket.commu_handle.wait()
+            bucket.unflatten_and_copy()
+            bucket.empty()
+        self._bucket_in_progress = []
+        self._param_store.clear_grads_of_previous_reduced_params()
+
+        # after the completion of bucket list reduction, add new buckets into _bucket_in_progress
+        self._bucket_in_progress = next_bucket_list.copy()
 
     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_sync_grad:
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+        # flatten the tensors and do allreduce
+        bucket.flatten()
+        bucket.commu_handle = reduce_tensor(
+            tensor=bucket.get_flat_tensor(),
+            dtype=None,
+            dst_rank=reduce_rank,
+            parallel_mode=ParallelMode.DATA,
+        )
 
-        with torch.cuda.stream(self._comm_stream):
-            flat = bucket.flatten()
-            reduced_flat = reduce_tensor(
-                tensor=flat,
-                dtype=self.dtype,
-                dst_rank=reduce_rank,
-                parallel_mode=ParallelMode.DATA,
-            )
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._zero_local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
+        # update the reduced tensor
+        if reduce_rank is None or reduce_rank == self._zero_local_rank:
+            bucket.set_unflatten_and_copy_flag(flag=True)
 
     def _has_inf_or_nan(self, tensor):
         try:
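The refactor above swaps a dedicated comm stream for async collectives plus explicit handle waits: `reduce_tensor` now returns the `async_op=True` work handle, which is stored on the bucket and awaited one reduction later via `_bucket_in_progress`, overlapping communication with the next bucket's flatten/memcpy. The bare pattern, runnable with a single-process gloo group:

```python
# The bare async-collective pattern behind commu_handle/_bucket_in_progress:
# kick off the reduce, do other work, then wait before reading the tensor.
# Single-process gloo group so the demo runs without GPUs.
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29502", rank=0, world_size=1
)
flat = torch.ones(8)
handle = dist.all_reduce(flat, async_op=True)  # returns immediately
# ... overlap: flatten/copy the next bucket here ...
handle.wait()                                  # now flat is safe to unflatten
print(flat.sum().item())
dist.destroy_process_group()
```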
@@ -711,6 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return norm
 
+    @llm_timeout(func_name="optim_step")
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -739,10 +744,13 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
 
         # clear reduced grads
-        if self._overlap_sync_grad:
-            # grads in the last bucket is reduced
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+        # grads in the last bucket is reduced
+        for bucket in self._bucket_in_progress:
+            bucket.commu_handle.wait()
+            bucket.unflatten_and_copy()
+            bucket.empty()
+        self._bucket_in_progress = []
+        self._param_store.clear_grads_of_previous_reduced_params()
 
         # compute norm for gradients in the last bucket
         total_norms = {}
@@ -783,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
                 send_alert_message(
-                    address=gpc.config.alert_address,
+                    address=gpc.config.monitor.alert.feishu_alert_address,
                     message="Overflow occurs, please check it.",
                 )
             self._grad_store._averaged_gradients = dict()
@@ -829,7 +837,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
                 self._unscale_and_clip_grads(
-                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                    single_grad_partition_groups,
+                    list(global_norm_groups.values()),
+                    loss_scale,
                 )
 
         # update the parameters
@@ -850,7 +860,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                     fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                     fp16_param.data.copy_(fp32_param)
 
-        self.broadcast_params()
+        torch.cuda.synchronize()
+        with torch.cuda.stream(self._comm_bcast_stream):
+            self.broadcast_params()
 
         timer("step").stop()
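On the new broadcast path: `torch.cuda.synchronize()` makes the fp32-to-fp16 copies on the default stream visible before `broadcast_params()` is enqueued on `_comm_bcast_stream`, letting the broadcast overlap whatever the default stream does next; any later consumer of the params would still need to synchronize with that side stream. A generic sketch of the idiom (not InternLM code; the work functions are placeholders and a CUDA device is required):

```python
# Generic side-stream idiom behind the broadcast change above (sketch only).
import torch

def producer_work(t):  # placeholder, e.g. fp32 -> fp16 param copies
    t.add_(1)

def comm_work(t):      # placeholder, e.g. broadcast_params()
    t.mul_(2)

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    t = torch.ones(4, device="cuda")
    producer_work(t)
    torch.cuda.synchronize()                  # producer results visible everywhere
    with torch.cuda.stream(side_stream):
        comm_work(t)                          # overlaps with default-stream work
    torch.cuda.current_stream().wait_stream(side_stream)  # before consuming t
    print(t)
```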
@@ -976,3 +988,17 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         if "zero_devide_optim_plan" in states:
             self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
+
+
+def reload_zero_fp32_buff(optimizer):
+    # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value.
+    # Or we must ensure that loading model weights must be done before zero is initialized.
+    if isinstance(optimizer, HybridZeroOptimizer):
+        for group_id, param_group in enumerate(optimizer.optim.param_groups):
+            if optimizer.param_group_has_params[group_id]:
+                # flatten fp16 params have already been updated by 'load_model_checkpoint'
+                fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
+                    optimizer._zero_local_rank, group_id
+                )
+                # param_group["params"] is fp32 flatten optimizer states of this zero rank.
+                param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
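The intended call site, per the comments in the function, is right after model weights are (re)loaded into the flattened fp16 shards, so the optimizer's fp32 master copy is refreshed to match. A hedged sketch of that ordering, where `load_model_checkpoint` is a stub standing in for whatever routine rewrites the fp16 shards:

```python
# Sketch of the intended call ordering; load_model_checkpoint is a stub,
# not the real InternLM loader.
def load_model_checkpoint(path):
    print(f"(stub) loading model weights from {path} into fp16 shards")

def resume_weights(optimizer, ckpt_path):
    load_model_checkpoint(ckpt_path)   # 1) fp16 shards now hold the new weights
    reload_zero_fp32_buff(optimizer)   # 2) copy them into the fp32 master buffer
```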
@@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
         if not last_bucket:
             if group_id not in self._former_bucket_reduced_param:
                 return [], []
-            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+            return (
+                self._former_bucket_reduced_param[group_id],
+                self._former_bucket_reduced_grad[group_id],
+            )
         else:
             if group_id not in self._last_bucket_reduced_param:
                 return [], []
-            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+            return (
+                self._last_bucket_reduced_param[group_id],
+                self._last_bucket_reduced_grad[group_id],
+            )
 
     def reset_reduced_data_for_compute_norm(self):
         self._former_bucket_reduced_param = {}
@@ -277,6 +283,9 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._flat_tensor = None
+        self._unflatten_and_copy_flag = False
+        self.commu_handle = None
 
     @property
     def max_size(self):
@@ -292,6 +301,15 @@ class TensorBucket:
     def is_empty(self):
         return len(self._bucket) == 0
 
+    def set_unflatten_and_copy_flag(self, flag):
+        self._unflatten_and_copy_flag = flag
+
+    def get_unflatten_and_copy_flag(self):
+        return self._unflatten_and_copy_flag
+
+    def get_flat_tensor(self):
+        return self._flat_tensor
+
     def add_to_bucket(self, tensor, allow_oversize=False):
         tensor_size = tensor.numel()
 
@@ -312,11 +330,14 @@ class TensorBucket:
     def empty(self):
         self._bucket = []
         self._size = 0
+        self._flat_tensor = None
+        self.commu_handle = None
 
     def flatten(self):
-        return _flatten_dense_tensors(self._bucket)
+        self._flat_tensor = _flatten_dense_tensors(self._bucket)
 
-    def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
-        for old, new in zip(self._bucket, unflattened_tensor_list):
-            old.copy_(new)
+    def unflatten_and_copy(self):
+        if self._unflatten_and_copy_flag:
+            unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
+            for old, new in zip(self._bucket, unflattened_tensor_list):
+                old.copy_(new)
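For readers unfamiliar with the torch internals used here: `_flatten_dense_tensors` concatenates a list of tensors into one contiguous buffer (one collective instead of many), and `_unflatten_dense_tensors` views it back into the original shapes, which is exactly the round trip `TensorBucket` now caches in `_flat_tensor`:

```python
# Round trip of the two torch._utils helpers that TensorBucket builds on.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

bucket = [torch.ones(2, 2), torch.full((3,), 2.0)]
flat = _flatten_dense_tensors(bucket)       # shape (7,), contiguous
flat.mul_(10)                               # stand-in for an all-reduce result
for old, new in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
    old.copy_(new)                          # write results back in place
print(bucket[0], bucket[1], sep="\n")
```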
 | 
			
		||||

@@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
    :type parallel_mode: ParallelMode, optional
    """
    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype
    # if dtype is None:
    assert dtype is None
    dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
        tensor_to_reduce = tensor.to(dtype)
    else:
        tensor_to_reduce = tensor
    # if tensor.dtype != dtype:
    #     tensor_to_reduce = tensor.to(dtype)
    # else:
    #     tensor_to_reduce = tensor

    world_size = gpc.get_world_size(parallel_mode)
    # world_size = gpc.get_world_size(parallel_mode)
    # tensor.div_(world_size)
    group = gpc.get_group(parallel_mode)
    tensor_to_reduce.div_(world_size)

    # if rank is None, all reduce will be used
    # else, reduce is used
    use_all_reduce = dst_rank is None

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
        handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
    else:
        ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
        global_rank = ranks_in_group[dst_rank]
        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
        handle = dist.reduce(
            tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
        )

    # recover the original dtype
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = gpc.get_local_rank(parallel_mode)
        if use_all_reduce or dst_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor
    return handle

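
Since reduce_tensor now launches the collective with async_op=True and averages via ReduceOp.AVG (which needs a backend that supports it, e.g. NCCL), callers receive a work handle rather than the reduced tensor; a hedged usage sketch with illustrative values:

grad = torch.ones(4, device=torch.cuda.current_device())
handle = reduce_tensor(grad, parallel_mode=ParallelMode.DATA)  # all-reduce, since dst_rank is None
# ... do other work while the collective is in flight ...
handle.wait()  # after this, 'grad' holds the data-parallel average in place
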
def has_inf_or_nan(tensor):

@@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.random import set_mode
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
@@ -24,7 +25,7 @@ from internlm.data.packed_dataset import (
    get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.monitor import set_env_var
from internlm.monitor import send_heartbeat, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
@@ -39,6 +40,7 @@ from internlm.utils.parallel import (
    sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -54,6 +56,7 @@ from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlash
logger = get_logger(__file__)


@llm_timeout(func_name="initialize_model")
def initialize_model():
    """
    Initialize model with Automatic Mixed Precision.
@@ -93,6 +96,10 @@ def initialize_model():
    # the same across tensor parallelism.
    sync_model_param_within_tp(model)

    # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
    # state in the same dp group are all the same.
    set_mode(ParallelMode.DATA)

    return model

@@ -114,6 +121,7 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
    return model


@llm_timeout(func_name="initialize_optimizer")
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
    """
    Initialize optimizer.
@@ -158,6 +166,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
    return optimizer, beta2_scheduler, lr_scheduler


@llm_timeout(func_name="get_train_data_loader")
def get_train_data_loader(
    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
@@ -237,6 +246,7 @@ def get_train_data_loader(
    return train_dl, dataset_types


@llm_timeout(func_name="get_validation_data_loader")
def get_validation_data_loader(
    num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
@@ -298,6 +308,7 @@ def get_validation_data_loader(
    return val_dls


@llm_timeout(func_name="load_new_batch")
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
    """
    Load and return the new batch data based on training data loader.
@@ -355,6 +366,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
    )


@llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics(
    get_tflops_func,
    logger,
@@ -440,6 +452,9 @@ def record_current_batch_training_metrics(
            else:
                writer.add_scalar(key=key, value=value, step=train_state.step_count)

        if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
            send_heartbeat("train_metrics", infos)

        if update_panel:
            # metrics shown with dashboard panels
            panel_metrics = {
@@ -465,4 +480,8 @@ def record_current_batch_training_metrics(
            logger.info(line)

        # if loss spike occurs, send alert info to feishu
        mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
        mm.monitor_loss_spike(
            alert_address=gpc.config.monitor.alert.feishu_alert_address,
            step_count=batch_count,
            cur_step_loss=loss.item(),
        )

@@ -84,7 +84,7 @@ def initialize_uniscale_logger(
            job_name and launch_time and file_name
        ), "If file_path is None, job_name, launch_time and file_name must be set."
        log_file_name = file_name
        log_folder = os.path.join(job_name, launch_time, "logs")
        log_folder = os.path.join("RUN", job_name, launch_time, "logs")
        log_dir = os.path.join(log_folder, log_file_name)
        file_path = log_dir

@@ -3,37 +3,136 @@

import copy
import fcntl
import inspect
import os
import socket
import time
from enum import Enum
from typing import Dict
from typing import Callable, Dict, Union

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.initialize.launch import get_config_value
from internlm.initialize.legacy.launch import (
    auto_resume_sanity_check,
    ckpt_info_sanity_check,
)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
    get_fns,
    get_storage_manager,
    init_storage_manager,
    llm_load,
    llm_save,
    try_get_storage_backend,
)
from internlm.utils.timeout import llm_timeout

logger = get_logger(__file__)


class CheckpointType(Enum):
class CheckpointSaveType(Enum):
    NORMAL_CHECKPOINT = 1
    SNAPSHOT_CHECKPOINT = 2


class CheckpointLoadType(Enum):
    INTERNLM = "internlm"


# The load method implemented by internlm by default does not use string representation types,
# but uses enumeration types defined in advance.
LOAD_TYPE_DICT = {
    "internlm": CheckpointLoadType.INTERNLM,
}


class CheckpointLoadContent:
    MODEL = "model"
    SAMPLER = "sampler"
    OPIMIZER = "optimizer"
    SCHEDULAER = "scheduler"


class CheckpointLoadMethod:
    """The registration class of the checkpoint loading method;
    users can define their own custom ckpt loading methods."""

    LOAD_FUNC_SIG = None
    LOAD_TYPE_FUNC = {}

    @staticmethod
    def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
        if load_type.lower() in LOAD_TYPE_DICT:
            # The ckpt load method implemented by internlm by default.
            return LOAD_TYPE_DICT[load_type.lower()]
        else:
            # If it is a user-defined field, we do not do any conversion and represent it as a string.
            return load_type

    @staticmethod
    def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
        if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
            logger.warning(f"{load_type} has already been registered!")
            return

        CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})

        if load_type == CheckpointLoadType.INTERNLM:
            CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
        else:
            if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
                logger.warning(
                    f"registered load model ckpt signature is not the same as: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
                )

    @staticmethod
    def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]):
        return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]

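
As the docstring above says, users can register their own loading methods; a minimal sketch under stated assumptions (the "hf" load type and load_hf_ckpt function are hypothetical, not part of this diff):

def load_hf_ckpt(ckpt_mm, load_info, train_state: TrainState):
    # hypothetical loader; it must keep the same signature as the default internlm loader
    logger.info(f"loading external weights from {load_info['path']}")
    return "model, "

CheckpointLoadMethod.register_ckpt_load_type("hf", load_hf_ckpt)
# afterwards a config may select it via ckpt_type="hf"
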
class CheckpointLoadMask:
    """
    According to the content field in the incoming ckpt_info, decide which components to load.
    """

    LOAD_CONTENT_DICT = {
        "model": CheckpointLoadContent.MODEL,
        "sampler": CheckpointLoadContent.SAMPLER,
        "optimizer": CheckpointLoadContent.OPIMIZER,
        "scheduler": CheckpointLoadContent.SCHEDULAER,
    }

    def __init__(self, content: tuple) -> None:
        self.load_set = set(map(lambda x: x.lower(), content))
        if "all" in self.load_set:
            self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values())
        else:
            self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content))

    def need_load(self, content: CheckpointLoadContent):
        return content in self.load_set

    def not_only_load(self, content: CheckpointLoadContent):
        return content in self.load_set and len(self.load_set) > 1

    def only_load(self, content: CheckpointLoadContent):
        return set((content,)) == self.load_set

    def __str__(self) -> str:
        return f"{self.load_set}."

    def __repr__(self) -> str:
        return f"{self.load_set}."

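
A small illustration of the mask semantics defined above (the tuples are only for demonstration):

mask = CheckpointLoadMask(("model", "optimizer"))
assert mask.need_load(CheckpointLoadContent.MODEL)
assert mask.not_only_load(CheckpointLoadContent.MODEL)   # something besides the model is requested
assert not mask.only_load(CheckpointLoadContent.MODEL)   # would be True for content=("model",)

mask_all = CheckpointLoadMask(("all",))                  # expands to every known component
assert mask_all.need_load(CheckpointLoadContent.SAMPLER)
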
def get_model_topology(model):
    """
    Returns:
@@ -75,6 +174,66 @@ def get_state_dict(model):
    return states


def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
    load_content_str = ""
    load_ckpt_folder = load_info["path"]
    load_content: CheckpointLoadMask = load_info["content"]

    if gpc.is_rank_for_log():
        logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")

    if load_content.need_load(CheckpointLoadContent.MODEL):
        load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
        load_content_str += f"{CheckpointLoadContent.MODEL}, "

    if load_content.not_only_load(CheckpointLoadContent.MODEL):
        # load training states.
        load_context(load_ckpt_folder, train_state)

        # load optimizer states.
        if load_content.need_load(CheckpointLoadContent.OPIMIZER):
            load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
            load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
        else:
            if gpc.is_rank_for_log():
                logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")

        # load lr scheduler states.
        if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
            if ckpt_mm.lr_scheduler:
                load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state)
                load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")

        # load dataloader sampler states.
        if load_content.need_load(CheckpointLoadContent.SAMPLER):
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler)
                # track the actual updates of sampler when using weighted sampling
                train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler)
                load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning("CheckpointManager skip reload 'batch_sampler'")

            # reload data state dict.
            if hasattr(train_state, "data_state_dict"):
                ckpt_mm.train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder
                )
                load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
            else:
                if gpc.is_rank_for_log():
                    logger.warning(
                        "CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!"
                    )
    return load_content_str

def save_model_checkpoint(folder, model):
    """
    Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@@ -257,15 +416,16 @@ def load_sampler(ckpt_path: str, sampler):
    torch.cuda.empty_cache()


def load_context(ckpt_path: str, train_dl, train_state: TrainState):
def load_context(ckpt_path: str, train_state: TrainState):
    context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
    train_state.load_state_dict(context_stuffs, train_dl)
    train_state.load_state_dict(context_stuffs)
    if gpc.is_rank_for_log():
        logger.info(f"reload train_state:{train_state}")
    torch.cuda.empty_cache()


def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState):
    learning_rate = train_state.lr
    scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
    if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
        logger.warning(
@@ -294,7 +454,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
class CheckpointManager:
    """StorageManagerContext"""

    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
    def __init__(
        self,
        ckpt_config,
        model,
        train_dl=None,
        optimizer=None,
        lr_scheduler=None,
        model_config=None,
        model_config_file=None,
        feishu_address=None,
    ) -> None:
        """
        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
        upload mode, you must call wait_async_upload_finish at the end of the program to wait
@@ -307,22 +477,44 @@ class CheckpointManager:
            lr_scheduler (object): lr_scheduler obj.
            model_config (dict): model config.
        """
        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
        self.checkpoint_every = ckpt_config.checkpoint_every
        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
        self.stop_file_path = ckpt_config.stop_file_path
        self.load_model_only_folder = ckpt_config.load_model_only_folder
        self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
        self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100)
        self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
        self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
        self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
        if self.save_ckpt_folder:
            self.snapshot_ckpt_folder = get_config_value(
                ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
            )
            self.async_upload_tmp_folder = get_config_value(
                ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/"
            )
        else:
            self.snapshot_ckpt_folder = None
            self.async_upload_tmp_folder = None

        self.async_upload = get_config_value(ckpt_config, "async_upload", False)

        # initialize the storage manager
        init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)

        self.feishu_address = feishu_address
        self.storage_manager = get_storage_manager()
        self.snapshot_counter = 0
        self.load_optimizer = gpc.config.ckpt.load_optimizer

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dl = train_dl
        self.model_config = model_config
        self.model_config_file = model_config_file

        # Register default internlm ckpt load type.
        self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
        for ckpt_load_type in CheckpointLoadType:
            CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])

        # Init alert file.
        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
            if dir_path != "" and not os.path.exists(dir_path):
@@ -330,21 +522,35 @@ class CheckpointManager:
            with open(self.stop_file_path, "w", encoding="utf-8") as f:
                f.write("0")

        if ckpt_config.load_given_ckpt is False:
            # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
            latest_ckpt_path = self.query_lastest_ckpt()
            if latest_ckpt_path:
                self.load_ckpt_folder = latest_ckpt_path
            else:
                # At this time, we have to load model init weights and train from step 0.
                self.load_ckpt_folder = self.load_model_only_folder
        else:
            self.load_ckpt_folder = ckpt_config.load_ckpt_folder
        self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
        if self.load_ckpt_info is None:  # (legacy): try to stay compatible with old interfaces
            self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)

        if gpc.is_rank_for_log():
            logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
            if self.stop_file_path is None:
                logger.warning("stop_file_path is not set, quit_signal_handler is disabled")
        # Auto-reload latest checkpoint; it will overwrite the setting of 'load_ckpt_info'.
        self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
        if self.auto_resume is None:  # (legacy): try to stay compatible with old interfaces
            self.auto_resume = auto_resume_sanity_check(ckpt_config)
        if self.auto_resume:
            self.load_ckpt_info = self.query_lastest_ckpt()

        if self.stop_file_path is None and gpc.is_rank_for_log():
            logger.warning("stop_file_path is not set, quit_signal_handler is disabled")

        # convert to internal representation
        if self.load_ckpt_info:
            assert (
                "path" in self.load_ckpt_info
                and "content" in self.load_ckpt_info
                and "ckpt_type" in self.load_ckpt_info
            ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"

            # replace load_ckpt
            self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
            self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])

        # test storage setting is ok.
        if self.enable_save_ckpt:
            self.try_ping_storage()

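
__init__ above reads every field through get_config_value with a fallback default; the helper is imported from internlm.initialize.launch and behaves roughly like this sketch (an assumption for illustration, not the actual implementation):

def get_config_value(config, key, default):
    # sketch: return config[key] when present, otherwise the supplied default
    try:
        value = config[key]
    except KeyError:
        value = default
    return value
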
    def quit_signal_handler(self, train_state) -> bool:
        """
@@ -358,7 +564,7 @@ class CheckpointManager:
        Returns:
            bool: whether to quit.
        """
        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
        now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

        if self.stop_file_path is None:
            return now_break, now_save_ckpt, save_type
@@ -389,24 +595,29 @@ now step_count is {train_state.step_count}",

        return now_break, now_save_ckpt, save_type

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
        save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
            save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
        if train_state.step_count % self.checkpoint_every == 0:
            save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
        now_break, signal_save_ckpts, signal_save_type = self.quit_signal_handler(train_state)
        if save_ckpts is False:
            save_ckpts = signal_save_ckpts
            save_type = signal_save_type

        return save_ckpts, save_type, now_break

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)

        if save_ckpts:
            # Wait for the previous round of asynchronous upload storage to complete.
            self.storage_manager.wait()
            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
            if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
                # Snapshot number, with only two snapshots written alternately.
                self.snapshot_counter = (self.snapshot_counter + 1) % 2
                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
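
The alternating counter above means at most two snapshot copies ever exist; illustratively (folder names assume a save_ckpt_folder of "llm_ckpts"):

snapshot_counter = 0
for _ in range(4):
    snapshot_counter = (snapshot_counter + 1) % 2
    print(os.path.join("llm_ckpts/snapshot", f"{snapshot_counter}"))
# prints llm_ckpts/snapshot/1, .../0, .../1, .../0 in turn
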
@@ -436,7 +647,7 @@ now step_count is {train_state.step_count}",
            Tuple(str, int): path of the latest ckpt and its step; if not found, None is returned.
        """
        ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
        if len(ckpt_list) == 0:
        if ckpt_list is None or len(ckpt_list) == 0:
            return None, None

        max_normal_step = 0
@@ -459,14 +670,16 @@ now step_count is {train_state.step_count}",
        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
        max_step_0, max_step_1 = 0, 0
        for ckpt in ckpt_list_1:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        for ckpt in ckpt_list_2:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
        if ckpt_list_1:
            for ckpt in ckpt_list_1:
                ckpt = ckpt.strip("/")
                if ckpt.endswith(".step"):
                    max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        if ckpt_list_2:
            for ckpt in ckpt_list_2:
                ckpt = ckpt.strip("/")
                if ckpt.endswith(".step"):
                    max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))

        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
        snap_step = max(max_step_0, max_step_1)
@@ -476,11 +689,12 @@ now step_count is {train_state.step_count}",

    def query_latest_snapshot_step_local(self):
        max_step, max_step_path = 0, None
        for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
        save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
        for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
            for fn in files:
                fn = fn.strip("/")
                if fn.endswith(".step"):
                    # We assume that both internlm ckpt and snapshot ckpt will store the '.step' file
                    # as an integrity flag.
                    step = int(fn.rsplit(".", maxsplit=1)[0])
                    if max_step < step:
@@ -490,100 +704,55 @@ now step_count is {train_state.step_count}",
        return max_step_path, max_step
    def query_lastest_ckpt(self):
        latest_checkpoint = None
        latest_ckpt, step = None, -1
        # Training was automatically restarted by the process, forcing the latest snapshot to be read.
        if self.save_ckpt_folder:
            if self.save_ckpt_folder.startswith("boto3"):
                latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
            elif self.save_ckpt_folder.startswith("local"):
                latest_checkpoint, step = self.query_latest_snapshot_step_local()
            else:
                latest_checkpoint, step = None, 0
            backend, _ = try_get_storage_backend(self.save_ckpt_folder)
            if backend == "boto3":
                latest_ckpt, step = self.query_latest_snapshot_step_boto3()
                if latest_ckpt and not latest_ckpt.startswith("boto3:"):
                    latest_ckpt = ":".join(["boto3", latest_ckpt])
            elif backend == "local":
                latest_ckpt, step = self.query_latest_snapshot_step_local()
                if latest_ckpt and not latest_ckpt.startswith("local:"):
                    latest_ckpt = ":".join(["local", latest_ckpt])

            if latest_checkpoint is not None:
                if gpc.is_rank_for_log():
                    logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
                    )
            else:
                if gpc.is_rank_for_log():
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
                    )
        if gpc.is_rank_for_log():
            logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")

        return latest_checkpoint
        return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm")

    def try_load_model(self, current_time=""):
        model_load_path = None
    def try_resume_training(self, train_state: TrainState, current_time=""):

        if self.load_ckpt_folder and self.load_model_only_folder:
            raise ValueError(
                "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if it is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
            )

        if self.load_ckpt_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_ckpt_folder
        elif self.load_model_only_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_model_only_folder
        else:
        if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
                    f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                    f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
                )
        else:
            load_path = self.load_ckpt_info["path"]
            load_content = self.load_ckpt_info["content"]
            load_type = self.load_ckpt_info["ckpt_type"]

        # Loading model weights must be done before zero is initialized.
        if model_load_path is not None:
            load_model_checkpoint(folder=model_load_path, model=self.model)
            load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
            load_content_str = load_func(self, self.load_ckpt_info, train_state)

    def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
        """Attempt to restore the training state of the last ckpt.
            # If we only load model weights, we need to rewrite the zero optimizer's fp32 buffer.
            if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer):
                reload_zero_fp32_buff(self.optimizer)

        Args:
            lr_scheduler (_LRScheduler): lr_scheduler object.
            optimizer (Optimizer): optimizer object.
            lr (float): learning rate.
            train_state (dict): training states.
            train_dl (DataLoader): training dataloader object
        """
        if self.load_ckpt_folder is not None:
            # load optimizer states.
            if self.load_optimizer:
                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
            # load lr scheduler states.
                load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
            # load training states.
            load_context(self.load_ckpt_folder, train_dl, train_state)
            # load dataloader sampler states.
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
            if hasattr(train_state, "data_state_dict"):
                train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
            if gpc.is_rank_for_log():
                logger.info(f"load_ckpt_info : {self.load_ckpt_info}")
                logger.info(
                    f"===========Resume training from `{load_path}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
                if load_content_str:
                    logger.info(f"===========Load contents are: {load_content_str}")

    @llm_timeout(func_name="save_checkpoint")
    def save_checkpoint(
        self,
        folder,
@@ -624,8 +793,10 @@ set load_ckpt_folder or use default value \
            )

        if gpc.is_rank_for_log():
            scheduler_states = scheduler.state_dict()
            llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
            if scheduler:
                scheduler_states = scheduler.state_dict()
                llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)

            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
@@ -655,3 +826,12 @@ set load_ckpt_folder or use default value \
    def set_save_folder(self, folder, step):
        self.storage_manager.latest_save_folder = folder
        self.storage_manager.latest_save_step = step

    def try_ping_storage(self):
        if gpc.get_global_rank() % 8 == 0:
            buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
            test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
            self.storage_manager.save(test_fn, buff)
            self.storage_manager.wait()
            self.storage_manager.load(test_fn)
            del buff
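
Taken together, a hedged sketch of how CheckpointManager is meant to be wired into a training entry point; the surrounding model, dataloader, optimizer, and train_state objects are assumed to come from the initialize_* helpers above:

ckpt_manager = CheckpointManager(
    ckpt_config=gpc.config.ckpt,
    model=model,
    train_dl=train_dl,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
ckpt_manager.try_resume_training(train_state)

for _ in range(total_steps):  # illustrative loop
    # ... forward / backward / optimizer step ...
    ckpt_manager.try_save_checkpoint(train_state)

ckpt_manager.wait_async_upload_finish()  # required in async upload mode, per the docstring above
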

@@ -46,12 +46,12 @@ def get_fns(fp: str):
    return storage_manager.get_fns(fp)


def llm_load(fp: str, *args, **kwargs):
    return storage_manager.load(fp, *args, **kwargs)
def llm_load(fp: str, **kwargs):
    return storage_manager.load(fp, **kwargs)


def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
    storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
def llm_save(save_path: str, saved_obj: Any, **kwargs):
    storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)

class StorageClient:
@@ -63,19 +63,23 @@ class StorageClient:
        self.handler = handler

    @staticmethod
    def load(client, load_path: str, *args, **kwargs):
    def load(*args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
    def sync_upload_fileobj(*args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def assert_fp_exists(client):
    def async_upload_fileobj(*args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def get_fns(client):
    def assert_fp_exists(*args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def get_fns(*args, **kwargs):
        raise NotImplementedError

@@ -92,40 +96,65 @@ class Boto3MetaInfo:
        async_upload_fn: callable,
        local_nvme_path=None,
    ) -> None:
        self.is_async = is_async
        # info needed in all cases.
        self.client = handler
        self.bucket_name = bucket_name
        self.endpoint = endpoint
        self.file_path = file_path
        self.async_upload_fn = async_upload_fn
        # info needed only for save.
        self.local_nvme_path = local_nvme_path
        self.is_async = is_async
        self.endpoint = endpoint
        self.async_upload_fn = async_upload_fn

    def __str__(self) -> str:
        return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
local_nvme_path: {self.local_nvme_path}"

    @staticmethod
    def unpack_boto3_save_meta(meta):
        if meta.is_async:
            return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
        else:
            return meta.client, meta.bucket_name, meta.file_path

    @staticmethod
    def unpack_boto3_nosave_meta(meta):
        return meta.client, meta.bucket_name, meta.file_path


class LocalMetaInfo:
    """Local meta info for save/load etc."""

    def __init__(self, handler: StorageClient, dest_path: str) -> None:
        self.is_async = False
        self.client = handler
        self.dest_path = dest_path
    def __init__(self, file_path: str) -> None:
        self.file_path = file_path
        self.async_upload_fn = None
        self.is_async = False

    @staticmethod
    def unpack_local_save_meta(meta):
        return (meta.file_path,)

    @staticmethod
    def unpack_local_nosave_meta(meta):
        return (meta.file_path,)


def unpack_meta(meta):
    args = []
    is_async = meta.is_async
    for k, v in meta.__dict__.items():
        if k in ("endpoint", "async_upload_fn", "is_async"):
            continue
        if not is_async and k in ("local_nvme_path",):
            continue
        args.append(v)
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_save_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_save_meta(meta)
    else:
        raise ValueError(f"unknown meta info: {type(meta)}")

    return args

def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
    if isinstance(meta, Boto3MetaInfo):
        return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
    elif isinstance(meta, LocalMetaInfo):
        return LocalMetaInfo.unpack_local_nosave_meta(meta)
    else:
        raise ValueError(f"unknown meta info: {type(meta)}")

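
A small illustration of the new unpack dispatch (the path is illustrative):

meta = LocalMetaInfo("/tmp/internlm_ckpt/model_tp0_pp0.pt")
(file_path,) = unpack_save_meta(meta)    # local saves need only the destination path
(file_path,) = unpack_nosave_meta(meta)  # same for loads and existence checks
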
def compute_file_md5_by_chunk(file_name: str):

@@ -136,6 +165,22 @@ def compute_file_md5_by_chunk(file_name: str):
    return hash_md5.hexdigest()


def try_get_storage_backend(path: str):
    sre = path.split(":", maxsplit=1)
    if len(sre) == 1:
        if path.startswith("s3:"):
            backend = "boto3"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' does not start with a backend prefix, guessing the boto3 backend.")
        else:
            backend = "local"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' does not start with a backend prefix, guessing the local backend.")
        return backend, sre[0]
    else:
        return sre[0], sre[1]  # (backend_prefix, split_path)

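
For instance, the prefix parsing above behaves as follows (paths illustrative):

backend, rest = try_get_storage_backend("boto3:s3://bucket/ckpts")  # ("boto3", "s3://bucket/ckpts")
backend, rest = try_get_storage_backend("local:llm_ckpts/")         # ("local", "llm_ckpts/")
backend, rest = try_get_storage_backend("llm_ckpts/")               # no prefix, guessed ("local", "llm_ckpts/")
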
class Boto3Client(StorageClient):
    """
    Boto3Client
@@ -189,13 +234,11 @@ class Boto3Client(StorageClient):
        )

    @staticmethod
    def sync_upload_fileobj(
        handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
    ):  # pylint: disable=W0613
    def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
                torch.save(saved_obj, f, *args, **kwargs)
                torch.save(saved_obj, f, **kwargs)
                f.seek(0)
                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
@@ -204,14 +247,7 @@ class Boto3Client(StorageClient):
            ) from exc

    @staticmethod
    def load(
        handler,
        bucket_name: str,
        fp: str,
        local_nvme_path: str,  # pylint: disable=W0613
        *args,
        **kwargs,
    ) -> Dict:
    def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
        """
        Args:
            fp (str): Path to load from, e.g. s3://opennlplab/model_weights/xxx/ddd.pt
@@ -220,7 +256,7 @@ class Boto3Client(StorageClient):
            with io.BytesIO() as f:
                handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                f.seek(0)
                states = torch.load(f, *args, **kwargs)
                states = torch.load(f, **kwargs)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
@@ -228,24 +264,37 @@ class Boto3Client(StorageClient):
        return states

    @staticmethod
    def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str):  # pylint: disable=W0613
    def assert_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
        assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs):  # pylint: disable=W0613
    def is_fp_exists(handler, bucket_name: str, fp: str):  # pylint: disable=W0613
        re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
        if "Contents" in re:
            return len(list(re["Contents"])) > 0
        else:
            return False

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str):
        """
        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
        folder_name_list = []
        for page in pages:
            if "Contents" in page:
                for obj in page["Contents"]:
                    pth: str = obj["Key"]
                    folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
        return list(set(folder_name_list))
        if Boto3Client.is_fp_exists(handler, bucket_name, fp):
            paginator = handler.client.get_paginator("list_objects_v2")
            pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
            folder_name_list = []
            for page in pages:
                if "Contents" in page:
                    for obj in page["Contents"]:
                        pth: str = obj["Key"]
                        folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
            return list(set(folder_name_list))
        else:
            if gpc.is_rank_for_log():
                logger.warning(f"'{fp}' not found!")
            return None

    @staticmethod
    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
			@ -273,37 +322,35 @@ class LocalClient(StorageClient):
 | 
			
		|||
        super().__init__(None)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
 | 
			
		||||
        assert isinstance(handler, LocalClient)
 | 
			
		||||
    def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
 | 
			
		||||
        assert saved_obj is not None
 | 
			
		||||
        fp_dirname = os.path.dirname(fp)
 | 
			
		||||
        if not os.path.exists(fp_dirname):
 | 
			
		||||
            os.makedirs(fp_dirname, exist_ok=True)
 | 
			
		||||
        torch.save(saved_obj, fp, *args, **kwargs)
 | 
			
		||||
        torch.save(saved_obj, fp, **kwargs)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def load(handler, fp: str, *args, **kwargs):  # pylint: disable=W0613
 | 
			
		||||
        assert isinstance(handler, LocalClient)
 | 
			
		||||
        assert os.path.exists(fp), f"{fp} is not found!"
 | 
			
		||||
        with open(fp, "rb") as f:
 | 
			
		||||
            states = torch.load(f, *args, **kwargs)
 | 
			
		||||
    def load(load_path: str, **kwargs):
 | 
			
		||||
        assert os.path.exists(load_path), f"{load_path} is not found!"
 | 
			
		||||
        with open(load_path, "rb") as f:
 | 
			
		||||
            states = torch.load(f, **kwargs)
 | 
			
		||||
        return states
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def assert_fp_exists(handler, folder):
 | 
			
		||||
        assert isinstance(handler, LocalClient)
 | 
			
		||||
    def assert_fp_exists(folder):
 | 
			
		||||
        assert os.path.exists(folder), folder
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_fns(handler, folder):
 | 
			
		||||
        assert isinstance(handler, LocalClient)
 | 
			
		||||
        assert os.path.exists(folder), f"folder '{folder}' not exists!"
 | 
			
		||||
        fns = os.listdir(folder)
 | 
			
		||||
        return fns
 | 
			
		||||
    def get_fns(folder):
 | 
			
		||||
        if not os.path.exists(folder):
 | 
			
		||||
            if gpc.is_rank_for_log():
 | 
			
		||||
                logger.warning(f"'{folder}' not found!")
 | 
			
		||||
            return None
 | 
			
		||||
        else:
 | 
			
		||||
            return os.listdir(folder)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def delete_obj(handler, fp: str):
 | 
			
		||||
        assert isinstance(handler, LocalClient)
 | 
			
		||||
    def delete_obj(fp: str):
 | 
			
		||||
        if not os.path.isdir(fp):
 | 
			
		||||
            os.remove(fp)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -327,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
 | 
			
		|||
    assert match is not None, f"url '{fp}' is not a valid boto3 url"
 | 
			
		||||
    bucket_name, endpoint = match.group(1), match.group(2)
 | 
			
		||||
    endpoint = "http://" + endpoint + ":80"
 | 
			
		||||
    tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
 | 
			
		||||
    if is_async:
 | 
			
		||||
        tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
 | 
			
		||||
    else:
 | 
			
		||||
        tmp_step_file = None
 | 
			
		||||
    return Boto3MetaInfo(
 | 
			
		||||
        is_async=is_async,
 | 
			
		||||
        handler=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -341,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
 | 
			
		|||
 | 
			
		||||
def get_local_meta(fp: str) -> LocalMetaInfo:
 | 
			
		||||
    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
 | 
			
		||||
    return LocalMetaInfo(None, fp)
 | 
			
		||||
    return LocalMetaInfo(fp)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_mount_point_free_size(path: str):
 | 
			
		||||
| 
						 | 
				
			
			@ -427,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
 | 
			
		|||
                logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
 | 
			
		||||
                raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
 | 
			
		||||
 | 
			
		||||
    def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
 | 
			
		||||
    def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
 | 
			
		||||
        """
 | 
			
		||||
        example:
 | 
			
		||||
        local:/path/to/checkpoint
 | 
			
		||||
| 
						 | 
				
			
			@ -436,17 +486,14 @@ class StorageManager(metaclass=SingletonMeta):
 | 
			
		|||
        Args:
 | 
			
		||||
            path (str): _description_
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            backend, path = path.split(":", maxsplit=1)
 | 
			
		||||
        except Exception as exc:
 | 
			
		||||
            raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc
 | 
			
		||||
        backend, path = try_get_storage_backend(path)
 | 
			
		||||
 | 
			
		||||
        init_args = (None,)
 | 
			
		||||
        if backend == "local":
 | 
			
		||||
            meta_info = get_local_meta(path)
 | 
			
		||||
            backend_key = backend
 | 
			
		||||
        elif backend == "boto3":
 | 
			
		||||
            meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
 | 
			
		||||
            meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
 | 
			
		||||
            backend_key = backend + ":" + meta_info.endpoint
 | 
			
		||||
            init_args = (meta_info.endpoint,)
 | 
			
		||||
            if (
 | 
			
		||||
| 
						 | 
				
			
			@ -474,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
 | 
			
		|||
 | 
			
		||||
    def assert_fp_exists(self, folder) -> None:
 | 
			
		||||
        meta = self._get_client(path=folder)
 | 
			
		||||
        meta.client.assert_fp_exists(*unpack_meta(meta))
 | 
			
		||||
        meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
 | 
			
		||||
 | 
			
		||||
    def get_fns(self, folder) -> List[str]:
 | 
			
		||||
        meta = self._get_client(path=folder)
 | 
			
		||||
        return meta.client.get_fns(*unpack_meta(meta))
 | 
			
		||||
        return meta.client.get_fns(*unpack_nosave_meta(meta))
 | 
			
		||||
 | 
			
		||||
    def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
 | 
			
		||||
        meta = self._get_client(path=save_path)
 | 
			
		||||
    def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
 | 
			
		||||
 | 
			
		||||
        if async_upload is None:
 | 
			
		||||
            async_upload = self.async_mode
 | 
			
		||||
 | 
			
		||||
        if not save_path.startswith("boto3:"):
 | 
			
		||||
            async_upload = False
 | 
			
		||||
 | 
			
		||||
        meta = self._get_client(save_path, async_upload)
 | 
			
		||||
 | 
			
		||||
        if async_upload:
 | 
			
		||||
            assert (
 | 
			
		||||
                self.tmp_local_folder
 | 
			
		||||
| 
						 | 
				
			
			@ -492,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
 | 
			
		|||
            tmp_step_file = meta.local_nvme_path
 | 
			
		||||
            self._to_be_del_files.append(tmp_step_file)
 | 
			
		||||
            with open(tmp_step_file, "wb") as f:
 | 
			
		||||
                torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
 | 
			
		||||
            self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
 | 
			
		||||
                torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
 | 
			
		||||
            self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
 | 
			
		||||
            os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
 | 
			
		||||
            self.async_task_peeding = True
 | 
			
		||||
        else:
 | 
			
		||||
            meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
 | 
			
		||||
            meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
 | 
			
		||||
            self.upload_count += 1
 | 
			
		||||
 | 
			
		||||
    def load(self, load_path: str, *args, **kwargs) -> Any:
 | 
			
		||||
    def load(self, load_path: str, **kwargs) -> Any:
 | 
			
		||||
        self.wait()
 | 
			
		||||
        meta = self._get_client(path=load_path)
 | 
			
		||||
        return meta.client.load(*unpack_meta(meta), *args, **kwargs)
 | 
			
		||||
        return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
 | 
			
		||||
 | 
			
		||||
    def delete_obj(self, fp: str):
 | 
			
		||||
        meta = self._get_client(path=fp)
 | 
			
		||||
        meta.client.delete_obj(*unpack_meta(meta))
 | 
			
		||||
        meta.client.delete_obj(*unpack_nosave_meta(meta))
 | 
			
		||||
 | 
			
		||||
    def _del_tmp_folder(self):
 | 
			
		||||
        for fp in self._to_be_del_files:
 | 
			
		||||
| 
						 | 
				
			
			@ -594,23 +646,24 @@ class StorageManager(metaclass=SingletonMeta):
 | 
			
		|||
 | 
			
		||||
        if gpc.is_rank_for_log():
 | 
			
		||||
            self.upload_count += 1
 | 
			
		||||
            if self.async_mode:
 | 
			
		||||
            if self.async_mode and self.latest_save_folder:
 | 
			
		||||
                self.save(
 | 
			
		||||
                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
 | 
			
		||||
                    saved_obj=dict({"step": self.latest_save_step}),
 | 
			
		||||
                    to_save_obj=dict({"step": self.latest_save_step}),
 | 
			
		||||
                    async_upload=False,
 | 
			
		||||
                )
 | 
			
		||||
                self.latest_save_folder = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
storage_manager: StorageManager = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_storage_manager(ckpt_config):
 | 
			
		||||
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
 | 
			
		||||
    global storage_manager
 | 
			
		||||
    storage_manager = StorageManager(
 | 
			
		||||
        ckpt_config.enable_save_ckpt,
 | 
			
		||||
        tmp_local_folder=ckpt_config.async_upload_tmp_folder,
 | 
			
		||||
        async_mode=ckpt_config.async_upload,
 | 
			
		||||
        enable_save_ckpt,
 | 
			
		||||
        tmp_local_folder=async_upload_tmp_folder,
 | 
			
		||||
        async_mode=async_upload,
 | 
			
		||||
    )
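
Taken together, the reworked module API is exercised end to end by the storage-manager test further down in this diff. A minimal usage sketch, assuming the llm_save/llm_load/get_fns/wait_async_upload_finish wrappers that test imports; the paths here are illustrative:

# Minimal usage sketch of the new interface; paths are illustrative.
from internlm.utils.storage_manager import (
    get_fns,
    init_storage_manager,
    llm_load,
    llm_save,
    wait_async_upload_finish,
)

# init_storage_manager now takes three explicit values instead of a ckpt_config object.
init_storage_manager(enable_save_ckpt=True, async_upload_tmp_folder="/dev/shm/tmp_ckpt/", async_upload=True)

states = {"step": 0}
llm_save("local:llm_ckpts/test.pt", states)  # non-"boto3:" paths are forced to synchronous upload
wait_async_upload_finish()                   # only needed after async "boto3:" saves
print(get_fns("local:llm_ckpts/"))           # ["test.pt"]
loaded = llm_load("local:llm_ckpts/test.pt", map_location="cpu")
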
@ -1,4 +1,13 @@
import datetime
import os
import signal
import socket
import traceback
from functools import wraps

from internlm.utils.logger import get_logger

logger = get_logger(__file__)


class Timeout:

@ -24,3 +33,81 @@ class Timeout:

    def __exit__(self, error_type, value, traceback):
        signal.alarm(0)


ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)


timeout_threshold_dict = {
    "initialize_distributed_env": 120,
    "nopp_forward_backward_step": 360,
    "initialize_model": 10,
    "initialize_optimizer": 20,
    "optim_step": 30,
    "get_train_data_loader": 600,
    "get_validation_data_loader": 60,
    "load_new_batch": 10,
    "record_current_batch_training_metrics": 10,
    "save_checkpoint": 1200,
    "interleaved_forward_backward_step": 600,
    "nointerleaved_forward_backward_step": 600,
}

if ENABLE_TIMEOUT is not None:
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
else:
    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
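
The toggle is resolved at import time, so the environment variables must be set before this module is first imported; the timeout test at the end of this diff does exactly that. A small sketch of both branches:

# Sketch: arming the watchdog. Both variables must be set before the first
# import of internlm.utils.timeout, because the dict above is rewritten at import time.
import os

os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"
os.environ["NCCL_TIMEOUT"] = "5"

from internlm.utils.timeout import LLM_NCCL_TIMEOUT, timeout_threshold_dict

assert timeout_threshold_dict["optim_step"] == 30  # thresholds stay armed
assert LLM_NCCL_TIMEOUT.total_seconds() == 5
# With INTERNLM_ENABLE_TIMEOUT unset, every listed threshold becomes 0 (disabled)
# and LLM_NCCL_TIMEOUT falls back to 1800 seconds.
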


def try_get_gpc_rank():
    try:
        from internlm.core.context import global_context as gpc

        rank = gpc.get_global_rank()
    except:  # noqa  # pylint: disable=bare-except
        rank = "unknown"

    return f"host-{socket.gethostname()}-rank-{rank}"


def llm_timeout(seconds=0, func_name=None):
    """timeout decorator. Note that this decorator cannot be reentrant,
    otherwise the signal will be reset.

    Args:
        seconds (int, optional): timeout threshold. Defaults to 0.
        func_name (str, optional): name of the decorated function; defaults to func.__name__.
    """

    def decorator(func):
        nonlocal func_name
        if func_name is None:
            func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle_timeout(signum, frame):
                raise TimeoutError

            nonlocal seconds
            seconds = timeout_threshold_dict.get(func_name, seconds)

            if seconds > 0:
                signal.signal(signal.SIGALRM, _handle_timeout)
                signal.alarm(seconds)

            try:
                result = func(*args, **kwargs)
            except TimeoutError as e:
                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\n {traceback.format_exc()}")
                raise e
            finally:
                signal.alarm(0)

            return result

        return wrapper

    return decorator
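
A minimal usage sketch, mirroring fake_timeout_func from the timeout test below. Note the lookup order: a func_name present in timeout_threshold_dict overrides the positional threshold, while an unknown name (like the hypothetical one here) keeps the fallback:

# Sketch: a 2-second SIGALRM watchdog around a slow function.
# "my_slow_step" is a hypothetical name with no entry in timeout_threshold_dict,
# so the positional fallback of 2 seconds applies even when timeouts are disabled globally.
import time

from internlm.utils.timeout import llm_timeout


@llm_timeout(2, "my_slow_step")
def my_slow_step():
    time.sleep(10)


try:
    my_slow_step()
except TimeoutError:
    print("timed out after 2s; the alarm is cleared in the decorator's finally block")
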
@ -0,0 +1,181 @@
import os
import shutil
from subprocess import PIPE, STDOUT, Popen

import pytest
import torch

from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta

OSS_NAME = os.environ["OSS_BUCKET_NAME"]
OSS_IP = os.environ["OSS_IP"]
USER = os.environ["USER"]
JOB_NAME = "CI_TEST"
LOCAL_SAVE_PATH = "local:local_ckpt"

BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"

ASYNC_TMP_FOLDER = "./async_tmp_folder"


# 1B
init_config = Config(
    dict(
        parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
        model_type="INTERNLM",
        adam=dict(
            lr=1e-4,
        ),
        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            checkpoint=False,
            num_attention_heads=2,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=True,
            hidden_size=1024,
            num_layers=2,
            mlp_ratio=1,
            apply_post_layer_norm=False,
            dtype=torch.bfloat16,
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
    )
)


def init_naive_model():
    # let MODEL_INITIALIZER take effect
    import internlm.model.modeling_internlm  # noqa # pylint: disable=unused-import
    from internlm.core.naive_amp import NaiveAMPModel
    from internlm.utils.registry import MODEL_INITIALIZER

    model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model))
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=False,
        dtype=torch.bfloat16,
        sync_buffer=False,
    )
    return model


def init_naive_optim(model):
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": 0.01}],
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    return naive_optimizer


def init_hybrid_optim(model):
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": 0.01}],
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )
    optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=Config(
            dict(
                fp16=dict(
                    initial_scale=2**16,
                    min_scale=1,
                    growth_interval=1000,
                ),
                growth_factor=2,
                backoff_factor=0.5,
                max_scale=2**24,
                hysteresis=2,
            )
        ),
        zero_cfg=Config(
            dict(
                overlap_sync_grad=False,
                overlap_sync_param=False,
                reduce_bucket_size=512 * 1024 * 1024,
                clip_grad_norm=1.0,
            )
        ),
        param_bcast_sync_handler=None,
    )
    return optimizer


@pytest.fixture(autouse=True, scope="function")
def reset_singletons():
    SingletonMeta._instances = {}


def reset_seed():
    from internlm.core.context.random import _SEED_MANAGER

    _SEED_MANAGER.reset()


@pytest.fixture(scope="module")
def init_dist_and_model(rank=0, world_size=1):
    from internlm.initialize import initialize_distributed_env

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12377"
    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

    # setup
    print("set up", flush=True)
    model = init_naive_model()
    # opim = init_naive_optim(model)
    opim = init_hybrid_optim(model)

    yield model, opim

    # teardown
    del model, opim
    print("teardown", flush=True)
    gpc.destroy()
    reset_seed()


def enter_flag(text):
    print(f"{text} begin!", flush=True)
    yield
    print(f"{text} end!", flush=True)


def del_tmp_file():
    try:
        shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
    except FileNotFoundError:
        pass

    try:
        shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
    except FileNotFoundError:
        pass

    try:
        cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
        with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
            results, presults = "", ""
            for line in iter(output.stdout.readline, b""):
                results += str(line.rstrip())
                presults += line.rstrip().decode() + "\n"
        print(presults, flush=True)
    except FileNotFoundError:
        pass
@ -0,0 +1,247 @@
import os

import pytest
import torch

from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.storage_manager import wait_async_upload_finish
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    ASYNC_TMP_FOLDER,
    BOTO_SAVE_PATH,
    LOCAL_SAVE_PATH,
    del_tmp_file,
    init_dist_and_model,
    reset_singletons,
)

TOTAL_STEP = 6

CKPT_EVERY = 4
SNPASHOT_EVERY = 2


ckpt_config_list = [
    # Old interface format
    dict(
        enable_save_ckpt=True,
        save_ckpt_folder=BOTO_SAVE_PATH,
        load_optimizer=True,
        checkpoint_every=CKPT_EVERY,
        async_upload=True,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
        oss_snapshot_freq=SNPASHOT_EVERY,
        stop_file_path=None,
        load_model_only_folder=None,
        load_given_ckpt=False,
        load_ckpt_folder=None,
        is_old_api=True,
    ),
    # Old interface format
    dict(
        enable_save_ckpt=True,
        save_ckpt_folder=LOCAL_SAVE_PATH,
        load_optimizer=True,
        checkpoint_every=CKPT_EVERY,
        async_upload=False,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
        oss_snapshot_freq=SNPASHOT_EVERY,
        stop_file_path=None,
        load_model_only_folder=None,
        load_given_ckpt=False,
        load_ckpt_folder=None,
        is_old_api=True,
    ),
    # New interface format
    dict(
        enable_save_ckpt=True,
        save_ckpt_folder=BOTO_SAVE_PATH,
        checkpoint_every=CKPT_EVERY,
        async_upload=True,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        oss_snapshot_freq=SNPASHOT_EVERY,
        stop_file_path=None,
        is_old_api=False,
        auto_resume=True,
    ),
    dict(
        enable_save_ckpt=True,
        save_ckpt_folder=LOCAL_SAVE_PATH,
        checkpoint_every=CKPT_EVERY,
        async_upload=False,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        oss_snapshot_freq=SNPASHOT_EVERY,
        stop_file_path=None,
        load_ckpt_folder=None,
        is_old_api=False,
        auto_resume=True,
    ),
]


def overwrite_optim_state(optim, set_value):
    if isinstance(optim, HybridZeroOptimizer):
        for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
            if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
                # p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
                p.data.fill_(set_value)
        for group_id in range(len(optim._fp16_param_groups)):
            if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
                fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
                    rank=optim._zero_local_rank, group_id=group_id
                )
                fp16_p.fill_(set_value)
    else:
        for group in optim.param_groups:
            for p in group["params"]:
                # p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
                p.data.fill_(set_value)


def compare_optim_state(optim1, optim2):
    re = True
    if isinstance(optim1, HybridZeroOptimizer):
        fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank
        fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
        for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
            re &= group_id_1 == group_id_2
            if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
                re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
    else:
        for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
            for p1, p2 in zip(group1["params"], group2["params"]):
                re &= torch.equal(p1, p2)
    return re


def compare_optim_value(optim, value):
    re = True
    if isinstance(optim, HybridZeroOptimizer):
        for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
            if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
                re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
        for group_id in range(len(optim._fp16_param_groups)):
            if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
                fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
                    rank=optim._zero_local_rank, group_id=group_id
                )
                re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
    else:
        for group in optim.param_groups:
            for p in group["params"]:
                re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
    return re


def overwrite_model_value(model, value):
    for p in model.parameters():
        # p.copy_(torch.full_like(p, value, dtype=p.dtype))
        p.data.fill_(value)


def compare_model_value(model, value):
    re = True
    for p in model.parameters():
        re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
    return re


@pytest.fixture(scope="function")
def del_tmp():
    del_tmp_file()
    yield
    del_tmp_file()


@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-import
    from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType

    ckpt_config = Config(ckpt_config)
    assert ckpt_config.checkpoint_every < TOTAL_STEP
    assert ckpt_config.oss_snapshot_freq < TOTAL_STEP

    model, opim = init_dist_and_model
    train_state = TrainState(gpc.config, None)
    if isinstance(opim, HybridZeroOptimizer):
        print("Is HybridZeroOptimizer!", flush=True)
    else:
        print("Is naive Adam!", flush=True)

    ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
    latest_ckpt_step = None
    for i in range(TOTAL_STEP + 1):
        overwrite_model_value(model, i)
        overwrite_optim_state(opim, i)

        train_state.batch_count = i
        train_state.step_count += 1

        save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state)
        if save_ckpts:
            latest_ckpt_step = i

        ckpt_mm.try_save_checkpoint(train_state)

    wait_async_upload_finish()
    latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
    assert latest_ckpt_info is not None
    latest_ckpt = latest_ckpt_info["path"]
    if ckpt_mm.save_ckpt_folder.startswith("local"):
        assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
    else:
        assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt

    del ckpt_mm
    SingletonMeta._instances = {}
    ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
    ckpt_mm.try_resume_training(train_state)
    assert latest_ckpt_step == 5
    assert train_state.step_count == 6
    assert train_state.batch_count == 6
    assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
    assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]

    if ckpt_mm.save_ckpt_folder.startswith("local:"):
        ckpt_mm.load_ckpt_info = dict(
            path=os.path.join(LOCAL_SAVE_PATH, "4"),
            content=CheckpointLoadMask(("all",)),
            ckpt_type=CheckpointLoadType.INTERNLM,
        )
    else:
        ckpt_mm.load_ckpt_info = dict(
            path=os.path.join(BOTO_SAVE_PATH, "4"),
            content=CheckpointLoadMask(("all",)),
            ckpt_type=CheckpointLoadType.INTERNLM,
        )

    ckpt_mm.try_resume_training(train_state)

    assert train_state.step_count == 4
    assert train_state.batch_count == 4
    assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
    assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]


@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-import
    ckpt_config = Config(ckpt_config)

    model, opim = init_dist_and_model
    SingletonMeta._instances = {}
    ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
    ckpt_mm.try_ping_storage()


if __name__ == "__main__":
    pytest.main()
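
For readers checking the assertions in test_ckpt_mm: a short worked sketch of the save schedule, assuming plain modulo semantics (the authoritative logic lives in CheckpointManager.is_now_to_save_ckpt) and a two-slot snapshot rotation implied by the expected ".../snapshot/0" path:

# Worked sketch of the schedule behind the assertions above (assumptions noted in the lead-in).
CKPT_EVERY, SNPASHOT_EVERY, TOTAL_STEP = 4, 2, 6

latest_ckpt_step = None
for i in range(TOTAL_STEP + 1):                   # the test loop: i = 0..6, step_count becomes i + 1
    step_count = i + 1
    snapshot = step_count % SNPASHOT_EVERY == 0   # steps 2, 4, 6 -> i = 1, 3, 5
    full = step_count % CKPT_EVERY == 0           # step 4        -> i = 3
    if snapshot or full:
        latest_ckpt_step = i

assert latest_ckpt_step == 5  # matches the test's expectation
# The three snapshots would land in alternating slots (snapshot/0, snapshot/1, snapshot/0),
# so query_lastest_ckpt() resolving to ".../snapshot/0" is consistent with the last one.
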
@ -0,0 +1,89 @@
import os

import pytest
import torch

from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    ASYNC_TMP_FOLDER,
    BOTO_SAVE_PATH,
    LOCAL_SAVE_PATH,
    del_tmp_file,
    init_dist_and_model,
    reset_singletons,
)

ASYNC_TMP_FOLDER = "./async_tmp_folder"
ckpt_config_list = [
    # async boto
    dict(
        enable_save_ckpt=True,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        async_upload=True,
        save_folder=BOTO_SAVE_PATH,
        test_id=0,
    ),
    # sync local
    dict(
        enable_save_ckpt=True,
        async_upload_tmp_folder=None,
        async_upload=False,
        save_folder=LOCAL_SAVE_PATH,
        test_id=1,
    ),
    # sync boto
    dict(
        enable_save_ckpt=True,
        async_upload_tmp_folder=None,
        async_upload=False,
        save_folder=BOTO_SAVE_PATH,
        test_id=2,
    ),
    # async local
    dict(
        enable_save_ckpt=True,
        async_upload_tmp_folder=ASYNC_TMP_FOLDER,
        async_upload=True,
        save_folder=LOCAL_SAVE_PATH,
        test_id=3,
    ),
]


@pytest.fixture(scope="function")
def del_tmp():
    del_tmp_file()
    yield
    del_tmp_file()


@pytest.mark.usefixtures("del_tmp")
@pytest.mark.usefixtures("reset_singletons")
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_storage_mm_save_load(ckpt_config, init_dist_and_model):  # noqa # pylint: disable=unused-argument
    from internlm.utils.storage_manager import (
        check_folder,
        get_fns,
        init_storage_manager,
        llm_load,
        llm_save,
        wait_async_upload_finish,
    )

    ckpt_config = Config(ckpt_config)
    enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
    async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
    async_upload = get_config_value(ckpt_config, "async_upload", False)

    init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)

    tobj = torch.rand(64, 64)
    save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
    llm_save(save_fn, tobj)
    if ckpt_config.test_id == 0:
        wait_async_upload_finish()
    check_folder(save_fn)
    assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
    load_obj = llm_load(save_fn, map_location="cpu")
    assert 0 == ((load_obj != tobj).sum())
@ -0,0 +1,119 @@
import fcntl
import os
import time
from multiprocessing import Process

import pytest
import torch
import torch.distributed as dist

os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa  # pylint: disable=wrong-import-position
os.environ["NCCL_TIMEOUT"] = "5"
from internlm.utils.timeout import llm_timeout
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    init_config,
)

WORLD_SIZE = 2


@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
    time.sleep(10)


@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
    # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication.
    buff = torch.ones([64, 64]).cuda(rank)
    dist.all_reduce(buff)  # lazy communicator init
    torch.cuda.synchronize()
    if rank == 0:
        dist.all_reduce(buff)
        torch.cuda.synchronize()  # main thread will hang at here.
    else:
        time.sleep(9999)


@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
    if rank == 1:
        time.sleep(5)

    with open(stop_file_path, "r", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 hang.
        if rank == 0:
            time.sleep(99999)  # rank 0 hang.
        f.seek(0)
        f.read()
        fcntl.flock(f, fcntl.LOCK_UN)


def local_timeout(rank, _):

    try:
        fake_timeout_func()
    except TimeoutError as e:
        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
    else:
        assert False, "It should timeout!"


def gpc_timeout(rank, world_size):

    from internlm.initialize import initialize_distributed_env

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12377"
    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

    try:
        nccl_timeout_func(rank)
    except TimeoutError as e:
        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
        time.sleep(5)  # wait rank 0 to be killed
    else:
        time.sleep(5)  # give some time to let Watchdog kill rank 0.
        assert False, "It should timeout!"


def file_lock_timeout(rank, _, stop_file_path):
    if rank == 0:
        with open(stop_file_path, "w"):
            pass
    try:
        try_file_lock(rank, stop_file_path)
    except TimeoutError as e:
        print(e, flush=True)
    else:
        assert False, "It should timeout!"
    finally:
        if rank == 0:
            os.remove(stop_file_path)


timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]


@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
    timeout_func, world_size, other_args = timeout_func_and_args
    procs = []
    for i in range(world_size):
        if other_args is None:
            args = (i, world_size)
        else:
            args = (i, world_size, other_args)
        proc = Process(target=timeout_func, args=args)
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join(15)
        if proc.is_alive():
            proc.terminate()
            proc.join()
train.py (42 changes)

@ -36,7 +36,6 @@ from internlm.utils.common import (
    parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager

@ -74,7 +73,6 @@ def main(args):
    total_steps = gpc.config.data.total_steps
    valid_every = gpc.config.data.valid_every
    label_smoothing = gpc.config.loss.label_smoothing
    lr = gpc.config.adam.lr

    get_tflops_func = partial(
        get_megatron_flops,

@ -97,21 +95,11 @@ def main(args):
    # initialize custom llm logger
    uniscale_logger = initialize_llm_logger(start_time=current_time)

    # initialize and resume train state
    train_state = TrainState(gpc.config)

    # initialize model
    model = initialize_model()

    with open(args.config, "r") as f:
        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.alert_address,
    )

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

@ -119,18 +107,28 @@ def main(args):
    # initialize the train and validation data loader
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
    val_dls = get_validation_data_loader()
    train_state.init_batch_sampler(train_dl)

    # Loading model weights must be done before zero is initialized.
    ckpt_manager.try_load_model(current_time)
    # initialize and resume train state
    train_state = TrainState(gpc.config, train_dl.batch_sampler)

    # if fsdp enabled, wrap the model
    model = warp_FSDP_model(model)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dl=train_dl,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
    )

    # Loading other persistent training states.
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
    ckpt_manager.try_resume_training(train_state, current_time)

    # initialize custom llm writer
    writer = Writer(

@ -201,8 +199,6 @@ def main(args):
        for batch_count in range(train_state.batch_count, total_steps):
            if batch_count % 50 == 0:
                torch.cuda.empty_cache()
                bench_gpu()
                bench_net()

            start_time = time.time()
            timer("one-batch").start()

@ -245,7 +241,7 @@ def main(args):
                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                    logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                    send_alert_message(
                        address=gpc.config.alert_address,
                        address=gpc.config.monitor.alert.feishu_alert_address,
                        message=f"Warning: skip parameter update at step {batch_count}.",
                    )

@ -305,11 +301,15 @@ if __name__ == "__main__":
    assert hasattr(gpc, "config") and gpc.config is not None

    # initialize monitor manager context
    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
    with initialize_monitor_manager(
        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
    ):
        try:
            main(args)
        except Exception:
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
            mm.monitor_exception(
                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
            )
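
Condensed, the new startup order these hunks establish in main() is: model, then data loader, then train state, then FSDP wrapping, then optimizer, and only then the CheckpointManager plus a single resume call. A sketch of the flow above, not additional code:

# Sketch of the reworked startup sequence (all names as used in main() above).
model = initialize_model()
train_dl, dataset_types = get_train_data_loader(num_worker=4)
train_state = TrainState(gpc.config, train_dl.batch_sampler)  # batch sampler now passed at construction

model = warp_FSDP_model(model)  # wraps the model when FSDP is enabled, per the comment above

optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

# CheckpointManager now owns optimizer/lr_scheduler/train_dl, so resuming takes
# only the train state (plus current_time for logging/alerts).
ckpt_manager = CheckpointManager(
    ckpt_config=gpc.config.ckpt,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_dl=train_dl,
    model_config=gpc.config.model,
    model_config_file="".join(config_lines),
    feishu_address=gpc.config.monitor.alert.feishu_alert_address,
)
ckpt_manager.try_resume_training(train_state, current_time)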