diff --git a/configs/7B_sft.py b/configs/7B_sft.py index eeb69ec..8cb1e04 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -1,7 +1,8 @@ JOB_NAME = "7b_train" +DO_ALERT = False -SEQ_LEN = 2048 -HIDDEN_SIZE = 4096 +SEQ_LEN = 256 +HIDDEN_SIZE = 512 NUM_ATTENTION_HEAD = 32 MLP_RATIO = 8 / 3 NUM_LAYER = 32 @@ -22,14 +23,22 @@ CHECKPOINT_EVERY = 20 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states). + # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states). load_given_ckpt = False, # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights. load_optimizer=True, # Wheter to load optimizer states when continuing training. + + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicates the ckpt path, + # 2. the 'content' means which states will be loaded, supported: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the 'ckpt_type' means the type of checkpoint to be loaded, currently only the 'internlm' type is supported. + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
) @@ -46,7 +55,7 @@ data = dict( # defaults to 0, means disable evaluate valid_every=50, pack_sample_into_one=False, - total_steps=50000, + total_steps=30, skip_batches="", rampup_batch_size="", # Datasets with less than 50 rows will be discarded @@ -145,8 +154,17 @@ parallel = dict( pipeline=dict(size=1, interleaved_overlap=True), tensor=1, sequence_parallel=False, - use_fsdp = True, + use_fsdp=True, ) cudnn_deterministic = False cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + ), +) diff --git a/doc/en/usage.md b/doc/en/usage.md index f8809d0..d115fb1 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -112,19 +112,19 @@ If you want to load a model checkpoint when starting the training, you can confi ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt" LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf - load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step - load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step - load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True + # When resuming training from a checkpoint: + # (1) 'path' is the path of the checkpoint to load. + # (2) 'content' indicates which states will be loaded, supported: "model", "sampler", "optimizer", "scheduler", "all" + # (3) 'ckpt_type' indicates which type of ckpt will be loaded, currently supported: "internlm" + load_ckpt_info=dict(path=LOAD_CKPT_FOLDER, content=("model",), ckpt_type="internlm"), ) ``` Note: -- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time. - If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS.
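The example above only reloads model weights. As an illustrative sketch that is not part of the original patch (folder paths and the checkpoint interval are placeholders), a full resume configuration using the new `load_ckpt_info` field could look like this:

```python
# Hypothetical resume setup; adjust paths and frequencies to your environment.
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"

ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder=SAVE_CKPT_FOLDER,
    checkpoint_every=50,
    # "all" restores model, sampler, optimizer and scheduler states together.
    load_ckpt_info=dict(path=LOAD_CKPT_FOLDER, content=("all",), ckpt_type="internlm"),
)
```

Passing `content=("all",)` mirrors what the auto-resume path builds internally, so a manually configured resume behaves the same as an automatic restart.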
The configuration for the model is as follows: diff --git a/doc/usage.md b/doc/usage.md index c00b03e..1b98c10 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -101,18 +101,17 @@ data = dict( 如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置: ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt" LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径 checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf - load_model_only_folder=MODEL_ONLY_FOLDER, # 加载模型初始权重的路径,只加载模型权重,不加载优化器权重,训练将从第一个 step 开始 - load_ckpt_folder=LOAD_CKPT_FOLDER, # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 - load_optimizer=True, # 断点续训时,是否需要加载优化器权重,默认值为 True + # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 + # content 表示哪些状态会被加载,支持: "model", "sampler", "optimizer", "scheduler", "all" + # ckpt_type 表示加载的模型类型,目前支持: "internlm" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), ) ``` 注意: -- `load_model_only_folder`与`load_ckpt_folder`不能同时设置 - 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上 模型相关关键参数配置如下所示: diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 0726fa1..18544a7 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -18,6 +18,7 @@ import torch.distributed as dist from internlm.utils.common import SingletonMeta from internlm.utils.logger import get_logger +from internlm.utils.timeout import LLM_NCCL_TIMEOUT from . import process_group_initializer as pgroup_initializer from .process_group_initializer import ParallelMode @@ -36,7 +37,7 @@ class Config(dict): config (dict): The dict object to be wrapped. """ - def __init__(self, config: dict = None): + def __init__(self, config: dict = None): # pylint: disable=W0231 if config is not None: for k, v in config.items(): self._add_item(k, v) @@ -100,7 +101,7 @@ class Config(dict): module_name = filepath.stem source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) - module = source_file.load_module() # pylint: disable=W4902,E1120 + module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 # load into config config = Config() @@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta): """ # initialize the default process group init_method = f"tcp://[{host}]:{port}" - dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) + dist.init_process_group( + rank=rank, + world_size=world_size, + backend=backend, + init_method=init_method, + timeout=LLM_NCCL_TIMEOUT, + ) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) if use_cpu: - cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None + cpu_group = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else None + ) else: cpu_group = None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) @@ -528,6 +539,7 @@ class ParallelContext(metaclass=SingletonMeta): if dpseed_with_tpoffset: dp_seed = seed + pipeline_offset * 1024 add_seed(ParallelMode.DATA, dp_seed) + add_seed(ParallelMode.DUMMY, dp_seed) # model parallel seeds are different across ranks if self.is_initialized(ParallelMode.TENSOR): @@ -535,7 +547,11 @@ class ParallelContext(metaclass=SingletonMeta): tp_seed = seed + tp_rank + 
pipeline_offset * 1024 add_seed(ParallelMode.TENSOR, tp_seed) - set_mode(ParallelMode.DATA) + # we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode + # during model construction), this is because the random state will be different in different tensor parallel + # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform + # additional random operations during the RowParallelLinear module building process. + set_mode(ParallelMode.DUMMY) seeds = get_seeds() seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 79ce106..194e651 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -9,6 +9,8 @@ from enum import Enum import torch.distributed as dist +from internlm.utils.timeout import LLM_NCCL_TIMEOUT + # parallel modes class ParallelMode(Enum): @@ -40,6 +42,9 @@ class ParallelMode(Enum): # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp ZERO3_DP = "zero3_dp" + # dummy mode, only used during mode construction + DUMMY = "dummy" + class ProcessGroupInitializer(ABC): """An object, knowing the parallelism configuration, that initializes parallel groups. @@ -111,9 +116,13 @@ class Initializer_Data(ProcessGroupInitializer): for i in range(self.rank_num_per_dp_group): ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -163,9 +172,13 @@ class Initializer_Model(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -223,9 +236,13 @@ class Initializer_Pipeline(ProcessGroupInitializer): ) ) pipe_group_size = len(ranks) - pipe_group = dist.new_group(ranks) + pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else pipe_group + ) else: group_cpu = None @@ -273,9 +290,13 @@ class Initializer_Tensor(ProcessGroupInitializer): for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: 
group_cpu = None @@ -329,9 +350,13 @@ class Initializer_Zero1(ProcessGroupInitializer): i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group for k in range(self.zero1_parallel_size) ] - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None @@ -378,9 +403,13 @@ class Initializer_Nettest(ProcessGroupInitializer): rank = i * self.nettest_parallel_size + j if rank < self.world_size: ranks.append(rank) - group = dist.new_group(ranks) + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) else: group_cpu = None diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 0076349..1d8b61e 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -9,6 +9,7 @@ import torch from internlm.core.engine import Engine from internlm.utils.common import conditional_context +from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook @@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler): return output, loss + @llm_timeout(func_name="nopp_forward_backward_step") def forward_backward_step( self, engine: Engine, diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index fd0b23e..e9b6c64 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -15,6 +15,7 @@ from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel from internlm.utils.common import get_current_device, move_to_device from internlm.utils.logger import get_logger +from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook @@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler): return output, label, accum_loss + @llm_timeout(func_name="nointerleaved_forward_backward_step") def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -1247,6 +1249,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): # 3. Cooldown self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) + @llm_timeout(func_name="interleaved_forward_backward_step") def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed. diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 6fd40ce..18a8f6f 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -23,7 +23,15 @@ class TrainState: train_dl (DataLoader): The DataLoader object used for training. 
""" - def __init__(self, config) -> None: + def __init__(self, config, batch_sampler) -> None: + """ + Args: + config (Config): internlm config + batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is + asynchronous and prefetched, the batch_sampler state maintained inside the + dataloader are faster then the actual training progress, so we copy the + batch_sampler as the anchor point of ckpt reload. + """ # The number of batches produced by the data iterator self.batch_count: int = 0 # Used to store the number of samples consumed in the current epoch @@ -43,9 +51,20 @@ class TrainState: self.tensorboard_folder = config.tensorboard_folder - def init_batch_sampler(self, train_dl): - # Copy of the batch sampler from the DataLoader - self.batch_sampler = train_dl.batch_sampler.copy() + # learning rate + self.lr = config.adam.lr + + # smapler state + if batch_sampler: + self.init_batch_sampler(batch_sampler) + + def init_batch_sampler(self, batch_sampler): + """ + Args: + batch_sampler (torch.utils.data.Sampler): sampler. + """ + # make a copy of batch_sampler. + self.batch_sampler = batch_sampler.copy() # Iterator for the batch sampler self.batch_sampler_iter = iter(self.batch_sampler) @@ -61,26 +80,22 @@ class TrainState: return json.dumps(info, indent=4, sort_keys=True) - def load_state_dict(self, other_stuffs, train_dl): + def load_state_dict(self, other_stuffs): """ Resumes training from a checkpoint. Args: other_stuffs (dict): Other information needed to resume training. - train_dl (DataLoader): The DataLoader object used for training. """ - - self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"] self.num_consumed_tokens = other_stuffs["num_consumed_tokens"] self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"] - # compatible with previous checkpoints without this parameter - self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1 - # track the actual updates of sampler when using weighted sampling - if hasattr(self, "batch_sampler"): - self.batch_sampler = train_dl.batch_sampler.copy() - self.batch_sampler_iter = iter(self.batch_sampler) + # Because the ckpt save occurs after updating 'step_count', + # there is no need to increment 'step_count' here (Does our step count start from 0 ?), + # However, 'batch_count' is updating before ckpt storage, so it need to inc 1 when resume. 
+ self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward + self.step_count = other_stuffs.get("step_count", self.batch_count) # resume tensorboard from older tensorboard_folder self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 66c1712..388051a 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -10,9 +10,10 @@ import torch from internlm.core.context import Config from internlm.core.context import global_context as gpc +from internlm.monitor import initialize_light_monitor from internlm.utils.common import get_master_node from internlm.utils.logger import get_logger -from internlm.utils.storage_manager import init_storage_manager +from internlm.utils.timeout import llm_timeout logger = get_logger(__file__) @@ -122,7 +123,7 @@ def args_sanity_check(): # processing the checkpoint config ckpt = gpc.config.ckpt if "enable_save_ckpt" not in ckpt: - ckpt._add_item("enable_save_ckpt", False) + ckpt._add_item("enable_save_ckpt", True) # Saving checkpoint args. if ckpt.enable_save_ckpt: @@ -148,9 +149,6 @@ def args_sanity_check(): if not ckpt.async_upload: ckpt._add_item("async_upload_tmp_folder", None) - if "snapshot_ckpt_folder" not in ckpt: - ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot")) - if "oss_snapshot_freq" not in ckpt: ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable. else: @@ -160,44 +158,23 @@ def args_sanity_check(): ckpt._add_item("async_upload", False) ckpt._add_item("async_upload_tmp_folder", None) ckpt._add_item("snapshot_ckpt_folder", None) - ckpt._add_item("snapshot_ckpt_folder", None) - - # Loading checkpoint args. - if "load_model_only_folder" not in ckpt: - ckpt._add_item("load_model_only_folder", None) if "load_ckpt_folder" not in ckpt: ckpt._add_item("load_ckpt_folder", None) - if "load_optimizer" not in ckpt: - ckpt._add_item("load_optimizer", True) - if "stop_file_path" not in ckpt: ckpt._add_item("stop_file_path", None) - if "load_given_ckpt" not in ckpt: - # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity + if "auto_resume" not in ckpt: + # If 'auto_resume' is not given, we set it to True, so internlm can have opportunity # to auto-load latest checkpoint. 
- ckpt._add_item("load_given_ckpt", False) - - if ckpt.load_given_ckpt: - # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder - if ckpt.load_ckpt_folder and ckpt.load_model_only_folder: - logger.warning( - "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ -and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" - ) - ckpt.load_model_only_folder = None + ckpt._add_item("auto_resume", True) if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}") logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}") logger.info(f"checkpoint_every: {ckpt.checkpoint_every}") - logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}") - - # initialization storage manager - init_storage_manager(ckpt) # tensorboard writer config if "enable_tb" not in gpc.config: @@ -288,9 +265,22 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" - # feishu webhook address for alerting - if "alert_address" not in gpc.config: - gpc.config._add_item("alert_address", None) + # monitoring default config + monitor_default_config = { + "alert_address": None, # compatible with old alert config + "monitor": { # new monitoring config + "alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None} + }, + } + + for key, value in monitor_default_config.items(): + if key not in gpc.config: + gpc.config._add_item(key, value) + + alert = gpc.config.monitor.alert + + if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log(): + logger.warning("alert is enable but alert_address is not set") optim_ckpt = gpc.config.hybrid_zero_optimizer if "zero_overlap_communication" in optim_ckpt: @@ -437,6 +427,7 @@ def launch_from_torch( ) +@llm_timeout(func_name="initialize_distributed_env") def initialize_distributed_env( config: str, launcher: str = "slurm", @@ -470,3 +461,20 @@ def initialize_distributed_env( if args_check: args_sanity_check() + + # init light monitor client + alert_config = gpc.config.monitor.alert + if alert_config.enable_feishu_alert and gpc.is_rank_for_log(): + light_monitor_address = alert_config.light_monitor_address + if light_monitor_address: + initialize_light_monitor(light_monitor_address) + else: + logger.warning("monitor address is none, monitor could not be used!") + + +def get_config_value(config, key, defalut): + try: + value = config[key] + except KeyError: + value = defalut + return value diff --git a/internlm/initialize/legacy/__init__.py b/internlm/initialize/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/internlm/initialize/legacy/launch.py b/internlm/initialize/legacy/launch.py new file mode 100644 index 0000000..8313654 --- /dev/null +++ b/internlm/initialize/legacy/launch.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from internlm.initialize.launch import get_config_value +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def auto_resume_sanity_check(ckpt_config): + load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None) + if load_given_ckpt is None: + return True # default value is True + else: + return not load_given_ckpt + + +def ckpt_info_sanity_check(ckpt_config): + 
load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None) + + load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None) + + if load_model_only_folder is not None: + assert ( + load_ckpt_folder is None + ), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ +and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" + return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm") + else: + load_optimizer = get_config_value(ckpt_config, "load_optimizer", True) + + if isinstance(load_ckpt_folder, str): + if load_optimizer: + return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm") + else: + return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm") + elif load_ckpt_folder is None: + return None + else: + assert False, f"Unsupported data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'" diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 32f29f8..5a3a4eb 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn -from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.utils import fused_dense_func_torch @@ -195,12 +195,6 @@ class FeedForward(nn.Module): device=device, dtype=dtype, ) - # need to assign tp attribute so that colossalai know it is tensor parallel module - - if gpc.get_world_size(ParallelMode.TENSOR) > 1: - for name in ["w1", "w2", "w3"]: - for param in getattr(self, name).parameters(): - setattr(param, IS_TENSOR_PARALLEL, True) def forward(self, x): out = self.w3(F.silu(self.w1(x)) * self.w2(x)) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index ceb4ac3..64ff4de 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module): device=device, dtype=dtype, ) + for _, param in self.mlp.named_parameters(): + if gpc.get_world_size(ParallelMode.TENSOR) > 1: + setattr(param, IS_TENSOR_PARALLEL, True) self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py index b100cde..2501d66 100644 --- a/internlm/monitor/__init__.py +++ b/internlm/monitor/__init__.py @@ -1,4 +1,11 @@ +from .alert import initialize_light_monitor, send_heartbeat from .monitor import initialize_monitor_manager, send_alert_message from .utils import set_env_var -__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"] +__all__ = [ + "send_alert_message", + "initialize_monitor_manager", + "set_env_var", + "initialize_light_monitor", + "send_heartbeat", +] diff --git a/internlm/monitor/alert.py b/internlm/monitor/alert.py index 78b6040..1772e7f 100644 --- a/internlm/monitor/alert.py +++ b/internlm/monitor/alert.py @@ -1,8 +1,59 @@ import json +import math +import os +import re import time +from typing import Dict import requests +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def initialize_light_monitor(monitor_address: str = None): + try: + from uniscale_monitoring import
init_monitor + + init_monitor(monitor_address) + except Exception as e: + logger.warning(f"init monitor meet error: {e}") + + +def send_heartbeat(msg_type: str, msg: Dict): + def nan2none(v): + if isinstance(v, float) and math.isnan(v): + return None + return v + + try: + from uniscale_monitoring import send_meta + + data = {} + for k, v in msg.items(): + if isinstance(v, Dict): + for k1, v1 in v.items(): + new_k = f"{k}_{k1}".split(" ")[0] + new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) + data[new_k] = nan2none(v1) + else: + new_k = k.split(" ")[0] + new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) + data[new_k] = nan2none(v) + + if os.getenv("CLUSTER_NAME"): + data.update({"cluster": os.getenv("CLUSTER_NAME")}) + if msg_type == "train_metrics": + data.update({"msg_type": "train_metrics"}) + elif msg_type == "init_time": + data.update({"msg_type": "init_time"}) + elif msg_type == "stage_time": + data.update({"msg_type": "stage_time"}) + send_meta(data, timeout=0.1) + except Exception as e: + logger.warning(f"send heartbeat meet error: {e}") + def send_feishu_msg_with_webhook(webhook: str, title: str, message: str): """ diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index a8ef5a0..6a3b9dc 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -226,9 +226,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None): send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.") yield finally: - send_alert_message( - address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed." - ) + send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.") monitor_manager.stop_monitor() else: yield diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py index b7178ad..7968e75 100644 --- a/internlm/solver/optimizer/__init__.py +++ b/internlm/solver/optimizer/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer +from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff -__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"] +__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"] \ No newline at end of file diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index e903342..330b696 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import ( from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.timeout import llm_timeout from .utils import compute_norm @@ -329,6 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer): self._param_store = ParameterStore(ParallelMode.ZERO1) self._grad_store = GradientStore(ParallelMode.DATA) self._bucket_store = BucketStore(ParallelMode.DATA) + self._bucket_in_progress = [] # fp16 and fp32 params for mixed precision training self._fp16_param_groups = dict() @@ -338,6 +340,8 @@ class HybridZeroOptimizer(BaseOptimizer): # self._overlap_communication = overlap_communication self._reduce_bucket_size = reduce_bucket_size + self._comm_bcast_stream = torch.cuda.Stream() + # gradient scaler self.grad_scaler = DynamicGradScaler( 
initial_scale=initial_scale, @@ -436,13 +440,6 @@ class HybridZeroOptimizer(BaseOptimizer): # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled. self.skip_grad_reduce = False - # initialize communication stream for - # communication-computation overlapping - if self._overlap_sync_grad: - self._comm_stream = torch.cuda.Stream() - else: - self._comm_stream = torch.cuda.current_stream() - # reduction hook is only used if overlapping communication # if it is stage 1 without overlapping, no hook will be attached if self._overlap_sync_grad: @@ -588,34 +585,41 @@ class HybridZeroOptimizer(BaseOptimizer): def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): grad_buckets_by_dtype = split_half_float_double(grads) - + next_bucket_list = [] + # add parameters into bucket for reduction for tensor_list in grad_buckets_by_dtype: param_bucket = TensorBucket(size=bucket_size) for tensor in tensor_list: param_bucket.add_to_bucket(tensor, allow_oversize=True) - if param_bucket.is_full_or_oversized(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() if not param_bucket.is_empty(): self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) + next_bucket_list.append(param_bucket) + + # wait for the completion of previouce bucket list reduction, and do unflatten_and_copy() + # here we can also overlap the communication with some memcpy operation caused by bucket.flatten() + for bucket in self._bucket_in_progress: + bucket.commu_handle.wait() + bucket.unflatten_and_copy() + bucket.empty() + self._bucket_in_progress = [] + self._param_store.clear_grads_of_previous_reduced_params() + + # after the completion of bucket list reduction, add new buckets into _bucket_in_progress + self._bucket_in_progress = next_bucket_list.copy() def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): - if self._overlap_sync_grad: - self._comm_stream.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() + # flatten the tensors and do allreduce + bucket.flatten() + bucket.commu_handle = reduce_tensor( + tensor=bucket.get_flat_tensor(), + dtype=None, + dst_rank=reduce_rank, + parallel_mode=ParallelMode.DATA, + ) - with torch.cuda.stream(self._comm_stream): - flat = bucket.flatten() - reduced_flat = reduce_tensor( - tensor=flat, - dtype=self.dtype, - dst_rank=reduce_rank, - parallel_mode=ParallelMode.DATA, - ) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._zero_local_rank: - bucket.unflatten_and_copy(reduced_flat) + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._zero_local_rank: + bucket.set_unflatten_and_copy_flag(flag=True) def _has_inf_or_nan(self, tensor): try: @@ -711,6 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer): return norm + @llm_timeout(func_name="optim_step") def step(self, closure=None): """Performs a single optimization step. 
@@ -739,10 +744,13 @@ class HybridZeroOptimizer(BaseOptimizer): groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) # clear reduced grads - if self._overlap_sync_grad: - # grads in the last bucket is reduced - self._comm_stream.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() + # grads in the last bucket is reduced + for bucket in self._bucket_in_progress: + bucket.commu_handle.wait() + bucket.unflatten_and_copy() + bucket.empty() + self._bucket_in_progress = [] + self._param_store.clear_grads_of_previous_reduced_params() # compute norm for gradients in the last bucket total_norms = {} @@ -783,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer): if gpc.is_rank_for_log(): logger.warning("Overflow occurs, please check it.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message="Overflow occurs, please check it.", ) self._grad_store._averaged_gradients = dict() @@ -829,7 +837,9 @@ class HybridZeroOptimizer(BaseOptimizer): if gpc.config.model.dtype is not torch.float32: if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0: self._unscale_and_clip_grads( - single_grad_partition_groups, list(global_norm_groups.values()), loss_scale + single_grad_partition_groups, + list(global_norm_groups.values()), + loss_scale, ) # update the parameters @@ -850,7 +860,9 @@ class HybridZeroOptimizer(BaseOptimizer): fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) - self.broadcast_params() + torch.cuda.synchronize() + with torch.cuda.stream(self._comm_bcast_stream): + self.broadcast_params() timer("step").stop() @@ -976,3 +988,17 @@ class HybridZeroOptimizer(BaseOptimizer): if "zero_devide_optim_plan" in states: self.params_per_rank_id_dict = states["zero_devide_optim_plan"] + + +def reload_zero_fp32_buff(optimizer): + # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value. + # Or we must ensure that loading model weights must be done before zero is initialized. + if isinstance(optimizer, HybridZeroOptimizer): + for group_id, param_group in enumerate(optimizer.optim.param_groups): + if optimizer.param_group_has_params[group_id]: + # flatten fp16 params have already been updated by 'load_model_checkpoint' + fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group( + optimizer._zero_local_rank, group_id + ) + # param_group["params"] is fp32 flatten optimizer states of this zero rank. 
+ param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 05a44d2..adab6c9 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -249,11 +249,17 @@ class ParameterStore(BaseStore): if not last_bucket: if group_id not in self._former_bucket_reduced_param: return [], [] - return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id] + return ( + self._former_bucket_reduced_param[group_id], + self._former_bucket_reduced_grad[group_id], + ) else: if group_id not in self._last_bucket_reduced_param: return [], [] - return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id] + return ( + self._last_bucket_reduced_param[group_id], + self._last_bucket_reduced_grad[group_id], + ) def reset_reduced_data_for_compute_norm(self): self._former_bucket_reduced_param = {} @@ -277,6 +283,9 @@ class TensorBucket: self._max_size = size self._current_size = 0 self._bucket = [] + self._flat_tensor = None + self._unflatten_and_copy_flag = False + self.commu_handle = None @property def max_size(self): @@ -292,6 +301,15 @@ class TensorBucket: def is_empty(self): return len(self._bucket) == 0 + def set_unflatten_and_copy_flag(self, flag): + self._unflatten_and_copy_flag = flag + + def get_unflatten_and_copy_flag(self): + return self._unflatten_and_copy_flag + + def get_flat_tensor(self): + return self._flat_tensor + def add_to_bucket(self, tensor, allow_oversize=False): tensor_size = tensor.numel() @@ -312,11 +330,14 @@ class TensorBucket: def empty(self): self._bucket = [] self._size = 0 + self._flat_tensor = None + self.commu_handle = None def flatten(self): - return _flatten_dense_tensors(self._bucket) + self._flat_tensor = _flatten_dense_tensors(self._bucket) - def unflatten_and_copy(self, flat_tensor): - unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) - for old, new in zip(self._bucket, unflattened_tensor_list): - old.copy_(new) + def unflatten_and_copy(self): + if self._unflatten_and_copy_flag: + unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket) + for old, new in zip(self._bucket, unflattened_tensor_list): + old.copy_(new) diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 38e4560..63c8c25 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode. 
:type parallel_mode: ParallelMode, optional """ # use the original dtype - if dtype is None: - dtype = tensor.dtype + # if dtype is None: + assert dtype is None + dtype = tensor.dtype # cast the data to specified dtype for reduce/all-reduce - if tensor.dtype != dtype: - tensor_to_reduce = tensor.to(dtype) - else: - tensor_to_reduce = tensor + # if tensor.dtype != dtype: + # tensor_to_reduce = tensor.to(dtype) + # else: + # tensor_to_reduce = tensor - world_size = gpc.get_world_size(parallel_mode) + # world_size = gpc.get_world_size(parallel_mode) + # tensor.div_(world_size) group = gpc.get_group(parallel_mode) - tensor_to_reduce.div_(world_size) # if rank is None, all reduce will be used # else, reduce is used use_all_reduce = dst_rank is None if use_all_reduce: - dist.all_reduce(tensor_to_reduce, group=group) + handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True) else: ranks_in_group = gpc.get_ranks_in_group(parallel_mode) global_rank = ranks_in_group[dst_rank] - dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group) + handle = dist.reduce( + tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True + ) - # recover the original dtype - if tensor.dtype != dtype and tensor is not tensor_to_reduce: - local_rank = gpc.get_local_rank(parallel_mode) - if use_all_reduce or dst_rank == local_rank: - tensor.copy_(tensor_to_reduce) - - return tensor + return handle def has_inf_or_nan(tensor): diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index a0dd913..e5b5097 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.context.random import set_mode from internlm.core.naive_amp import NaiveAMPModel from internlm.core.trainer import TrainState from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader @@ -24,7 +25,7 @@ from internlm.data.packed_dataset import ( get_packed_dataset_without_short_length, ) from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data -from internlm.monitor import set_env_var +from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor.monitor import monitor_manager as mm from internlm.solver.beta2_scheduler import Beta2Scheduler from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR @@ -39,6 +40,7 @@ from internlm.utils.parallel import ( sync_model_param_within_tp, ) from internlm.utils.registry import MODEL_INITIALIZER +from internlm.utils.timeout import llm_timeout from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( @@ -54,6 +56,7 @@ from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlash logger = get_logger(__file__) +@llm_timeout(func_name="initialize_model") def initialize_model(): """ Initialize model with Automatic Mixed Precision. @@ -93,6 +96,10 @@ def initialize_model(): # the same across tensor parallelism. sync_model_param_within_tp(model) + # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random + # state in the same dp group are all the same. 
+ set_mode(ParallelMode.DATA) + return model @@ -114,6 +121,7 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]): return model +@llm_timeout(func_name="initialize_optimizer") def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): """ Initialize optimizer. @@ -158,6 +166,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): return optimizer, beta2_scheduler, lr_scheduler +@llm_timeout(func_name="get_train_data_loader") def get_train_data_loader( num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None ): @@ -237,6 +246,7 @@ def get_train_data_loader( return train_dl, dataset_types +@llm_timeout(func_name="get_validation_data_loader") def get_validation_data_loader( num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None ): @@ -298,6 +308,7 @@ def get_validation_data_loader( return val_dls +@llm_timeout(func_name="load_new_batch") def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): """ Load and return the new batch data based on training data loader. @@ -355,6 +366,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None): ) +@llm_timeout(func_name="record_current_batch_training_metrics") def record_current_batch_training_metrics( get_tflops_func, logger, @@ -440,6 +452,9 @@ def record_current_batch_training_metrics( else: writer.add_scalar(key=key, value=value, step=train_state.step_count) + if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0: + send_heartbeat("train_metrics", infos) + if update_panel: # metrics shown with dashboard panels panel_metrics = { @@ -465,4 +480,8 @@ def record_current_batch_training_metrics( logger.info(line) # if loss spike occurs, send alert info to feishu - mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item()) + mm.monitor_loss_spike( + alert_address=gpc.config.monitor.alert.feishu_alert_address, + step_count=batch_count, + cur_step_loss=loss.item(), + ) diff --git a/internlm/utils/logger.py b/internlm/utils/logger.py index 679913a..6111553 100644 --- a/internlm/utils/logger.py +++ b/internlm/utils/logger.py @@ -84,7 +84,7 @@ def initialize_uniscale_logger( job_name and launch_time and file_name ), "If file_path is None, job_name, launch_time and file_name must be setted." 
log_file_name = file_name - log_folder = os.path.join(job_name, launch_time, "logs") + log_folder = os.path.join("RUN", job_name, launch_time, "logs") log_dir = os.path.join(log_folder, log_file_name) file_path = log_dir diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index b36afec..a7a0c16 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -3,37 +3,136 @@ import copy import fcntl +import inspect import os import socket import time from enum import Enum -from typing import Dict +from typing import Callable, Dict, Union import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.initialize.launch import get_config_value +from internlm.initialize.legacy.launch import ( + auto_resume_sanity_check, + ckpt_info_sanity_check, +) from internlm.monitor import send_alert_message -from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.storage_manager import ( get_fns, get_storage_manager, + init_storage_manager, llm_load, llm_save, + try_get_storage_backend, ) +from internlm.utils.timeout import llm_timeout logger = get_logger(__file__) -class CheckpointType(Enum): +class CheckpointSaveType(Enum): NORMAL_CHECKPOINT = 1 SNAPSHOT_CHECKPOINT = 2 +class CheckpointLoadType(Enum): + INTERNLM = "internlm" + + +# The load method implemented by internlm by default does not use string representation types, +# but uses enumeration types defined in advance. +LOAD_TYPE_DICT = { + "internlm": CheckpointLoadType.INTERNLM, +} + + +class CheckpointLoadContent: + MODEL = "model" + SAMPLER = "sampler" + OPIMIZER = "optimizer" + SCHEDULAER = "scheduler" + + +class CheckpointLoadMethod: + """The registration class of the checkpoint loading method, + users can define their own custom ckpt loading methods.""" + + LOAD_FUNC_SIG = None + LOAD_TYPE_FUNC = {} + + @staticmethod + def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]: + if load_type.lower() in LOAD_TYPE_DICT: + # The ckpt load method implemented by internlm by default. + return LOAD_TYPE_DICT[load_type.lower()] + else: + # If it is a user-defined field, we do not do any conversion and represent it as a string. + return load_type + + @staticmethod + def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable): + if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC: + logger.warning(f"{load_type} has aleady been registed!") + return + + CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) + + if load_type == CheckpointLoadType.INTERNLM: + CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) + else: + if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG: + logger.warning( + f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}" + ) + + @staticmethod + def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]): + return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type] + + +class CheckpointLoadMask: + """ + According to the content field in the incoming ckpt_info, decide which components to load. 
+ """ + + LOAD_CONTENT_DICT = { + "model": CheckpointLoadContent.MODEL, + "sampler": CheckpointLoadContent.SAMPLER, + "optimizer": CheckpointLoadContent.OPIMIZER, + "scheduler": CheckpointLoadContent.SCHEDULAER, + } + + def __init__(self, content: tuple) -> None: + self.load_set = set(map(lambda x: x.lower(), content)) + if "all" in self.load_set: + self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values()) + else: + self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content)) + + def need_load(self, content: CheckpointLoadContent): + return content in self.load_set + + def not_only_load(self, content: CheckpointLoadContent): + return content in self.load_set and len(self.load_set) > 1 + + def only_load(self, content: CheckpointLoadContent): + return set((content,)) == self.load_set + + def __str__(self) -> str: + return f"{self.load_set}." + + def __repr__(self) -> str: + return f"{self.load_set}." + + def get_model_topology(model): """ Returns: @@ -75,6 +174,66 @@ def get_state_dict(model): return states +def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): + load_content_str = "" + load_ckpt_folder = load_info["path"] + load_content: CheckpointLoadMask = load_info["content"] + + if gpc.is_rank_for_log(): + logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") + + if load_content.need_load(CheckpointLoadContent.MODEL): + load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) + load_content_str += f"{CheckpointLoadContent.MODEL}, " + + if load_content.not_only_load(CheckpointLoadContent.MODEL): + # load training states. + load_context(load_ckpt_folder, train_state) + + # load optimzier states. + if load_content.need_load(CheckpointLoadContent.OPIMIZER): + load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) + load_content_str += f"{CheckpointLoadContent.OPIMIZER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!") + + # load lr scheduler states. + if load_content.need_load(CheckpointLoadContent.SCHEDULAER): + if ckpt_mm.lr_scheduler: + load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state) + load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!") + + # load dataloader sampler states. + if load_content.need_load(CheckpointLoadContent.SAMPLER): + if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler) + # track the actual updates of sampler when using weighted sampling + train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler) + load_content_str += f"{CheckpointLoadContent.SAMPLER}, " + else: + if gpc.is_rank_for_log(): + logger.warning("CheckpointManager skip reload 'batch_sampler'") + + # reload data state dict. + if hasattr(train_state, "data_state_dict"): + ckpt_mm.train_dl.dataset.load_state_dict( + llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder + ) + load_content_str += f"{CheckpointLoadContent.SAMPLER}, " + else: + if gpc.is_rank_for_log(): + logger.warning( + "CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!" 
+ ) + return load_content_str + + def save_model_checkpoint(folder, model): """ Save the model according to the relationship between tp and dp. The principle is that the data of each tp @@ -257,15 +416,16 @@ def load_sampler(ckpt_path: str, sampler): torch.cuda.empty_cache() -def load_context(ckpt_path: str, train_dl, train_state: TrainState): +def load_context(ckpt_path: str, train_state: TrainState): context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt")) - train_state.load_state_dict(context_stuffs, train_dl) + train_state.load_state_dict(context_stuffs) if gpc.is_rank_for_log(): logger.info(f"reload train_state:{train_state}") torch.cuda.empty_cache() -def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState): +def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState): + learning_rate = train_state.lr scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt")) if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log(): logger.warning( @@ -294,7 +454,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train class CheckpointManager: """StorageManagerContext""" - def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None: + def __init__( + self, + ckpt_config, + model, + train_dl=None, + optimizer=None, + lr_scheduler=None, + model_config=None, + model_config_file=None, + feishu_address=None, + ) -> None: """ CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait @@ -307,22 +477,44 @@ class CheckpointManager: lr_scheduler (object): lr_scheduler obj. model_config (dict): model config. 
""" - self.enable_save_ckpt = ckpt_config.enable_save_ckpt - self.checkpoint_every = ckpt_config.checkpoint_every - self.save_ckpt_folder = ckpt_config.save_ckpt_folder - self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder - self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq - self.stop_file_path = ckpt_config.stop_file_path - self.load_model_only_folder = ckpt_config.load_model_only_folder + self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False) + self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100) + self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None) + self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50) + self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None) + if self.save_ckpt_folder: + self.snapshot_ckpt_folder = get_config_value( + ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot") + ) + self.async_upload_tmp_folder = get_config_value( + ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/" + ) + else: + self.snapshot_ckpt_folder = None + self.async_upload_tmp_folder = None + + self.async_upload = get_config_value(ckpt_config, "async_upload", False) + + # initialization storage manager + init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload) + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 - self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dl = train_dl self.model_config = model_config self.model_config_file = model_config_file + # Register defalut internlm ckpt load type. + self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt} + for ckpt_load_type in CheckpointLoadType: + CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type]) + + # Init alter file. if self.stop_file_path and gpc.get_global_rank() == 0: dir_path = os.path.dirname(self.stop_file_path) if dir_path != "" and not os.path.exists(dir_path): @@ -330,21 +522,35 @@ class CheckpointManager: with open(self.stop_file_path, "w", encoding="utf-8") as f: f.write("0") - if ckpt_config.load_given_ckpt is False: - # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder - latest_ckpt_path = self.query_lastest_ckpt() - if latest_ckpt_path: - self.load_ckpt_folder = latest_ckpt_path - else: - # At this time, we have to load model init weights and train from step 0. - self.load_ckpt_folder = self.load_model_only_folder - else: - self.load_ckpt_folder = ckpt_config.load_ckpt_folder + self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None) + if self.load_ckpt_info is None: # (legacy): Try Compatible with old interfaces + self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config) - if gpc.is_rank_for_log(): - logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'") - if self.stop_file_path is None: - logger.warning("no set stop_file_path, quit_signal_handler is disable") + # Auto-reload latest checkpoint, it will overwrite the setting of 'load_ckpt_info'. 
+ self.auto_resume = get_config_value(ckpt_config, "auto_resume", None) + if self.auto_resume is None: # (legacy): Try Compatible with old interfaces + self.auto_resume = auto_resume_sanity_check(ckpt_config) + if self.auto_resume: + self.load_ckpt_info = self.query_lastest_ckpt() + + if self.stop_file_path is None and gpc.is_rank_for_log(): + logger.warning("no set stop_file_path, quit_signal_handler is disable") + + # convert to internal representation + if self.load_ckpt_info: + assert ( + "path" in self.load_ckpt_info + and "content" in self.load_ckpt_info + and "ckpt_type" in self.load_ckpt_info + ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')" + + # replace load_ckpt + self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) + self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"]) + + # test storage setting is ok. + if self.enable_save_ckpt: + self.try_ping_storage() def quit_signal_handler(self, train_state) -> bool: """ @@ -358,7 +564,7 @@ class CheckpointManager: Returns: bool: whether to quit. """ - now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT if self.stop_file_path is None: return now_break, now_save_ckpt, save_type @@ -389,24 +595,29 @@ now step_count is {train_state.step_count}", return now_break, now_save_ckpt, save_type - def try_save_checkpoint(self, train_state): - if not self.enable_save_ckpt: - return False - - save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT + def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): + save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: - save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT + save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: - save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: save_ckpts = singal_save_ckpts save_type = singal_save_type + return save_ckpts, save_type, now_break + + def try_save_checkpoint(self, train_state): + if not self.enable_save_ckpt: + return False + + save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state) + if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. self.storage_manager.wait() - if save_type == CheckpointType.SNAPSHOT_CHECKPOINT: + if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT: # Snapshot number, with only two snapshots written alternately. self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") @@ -436,7 +647,7 @@ now step_count is {train_state.step_count}", Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. 
""" ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) - if len(ckpt_list) == 0: + if ckpt_list is None or len(ckpt_list) == 0: return None, None max_normal_step = 0 @@ -459,14 +670,16 @@ now step_count is {train_state.step_count}", ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) max_step_0, max_step_1 = 0, 0 - for ckpt in ckpt_list_1: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) - for ckpt in ckpt_list_2: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + if ckpt_list_1: + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + if ckpt_list_2: + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 snap_step = max(max_step_0, max_step_1) @@ -476,11 +689,12 @@ now step_count is {train_state.step_count}", def query_latest_snapshot_step_local(self): max_step, max_step_path = 0, None - for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + save_ckpt_folder = self.save_ckpt_folder.split(":")[1] + for root, _, files in os.walk(save_ckpt_folder, followlinks=True): for fn in files: fn = fn.strip("/") if fn.endswith(".step"): - # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # We assume that both internlm ckpt and snapshot ckpt will store the '.step' file # as an integrity flag. step = int(fn.rsplit(".", maxsplit=1)[0]) if max_step < step: @@ -490,100 +704,55 @@ now step_count is {train_state.step_count}", return max_step_path, max_step def query_lastest_ckpt(self): - latest_checkpoint = None + latest_ckpt, step = None, -1 # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
if self.save_ckpt_folder: - if self.save_ckpt_folder.startswith("boto3"): - latest_checkpoint, step = self.query_latest_snapshot_step_boto3() - elif self.save_ckpt_folder.startswith("local"): - latest_checkpoint, step = self.query_latest_snapshot_step_local() - else: - latest_checkpoint, step = None, 0 + backend, _ = try_get_storage_backend(self.save_ckpt_folder) + if backend == "boto3": + latest_ckpt, step = self.query_latest_snapshot_step_boto3() + if latest_ckpt and not latest_ckpt.startswith("boto3:"): + latest_ckpt = ":".join(["boto3", latest_ckpt]) + elif backend == "local": + latest_ckpt, step = self.query_latest_snapshot_step_local() + if latest_ckpt and not latest_ckpt.startswith("local:"): + latest_ckpt = ":".join(["local", latest_ckpt]) - if latest_checkpoint is not None: - if gpc.is_rank_for_log(): - logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") - send_alert_message( - address=self.feishu_address, - message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", - ) - else: - if gpc.is_rank_for_log(): - send_alert_message( - address=self.feishu_address, - message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", - ) + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...") - return latest_checkpoint + return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm") - def try_load_model(self, current_time=""): - model_load_path = None + def try_resume_training(self, train_state: TrainState, current_time=""): - if self.load_ckpt_folder and self.load_model_only_folder: - raise ValueError( - "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ -if you only need to load model weights (for example starting an SFT task for the first time), \ -set load_model_only_folder path, if you need to resume training from ckpt, \ -set load_ckpt_folder or use default value \ -(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" - ) - - if self.load_ckpt_folder: - if gpc.is_rank_for_log(): - logger.info( - f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_ckpt_folder - elif self.load_model_only_folder: - if gpc.is_rank_for_log(): - logger.info( - f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_model_only_folder - else: + if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None: if gpc.is_rank_for_log(): logger.info( f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) + else: + load_path = self.load_ckpt_info["path"] + load_content = self.load_ckpt_info["content"] + load_type = self.load_ckpt_info["ckpt_type"] - # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=self.model) + load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type) + load_content_str = load_func(self, self.load_ckpt_info, train_state) - def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl): - """Attempt to restore the training state of the last ckpt. 
+ # If we only load model weight, we need rewrite zero optim's fp32 buffer. + if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer): + reload_zero_fp32_buff(self.optimizer) - Args: - lr_scheduler (_LRScheduler): lr_scheduler object. - optimizer (Optimizer): optimizer object. - lr (float): learning rate. - train_state (dict): traing states. - train_dl (DataLoader): traning dataloader object - """ - if self.load_ckpt_folder is not None: - # load optimzier states. - if self.load_optimizer: - load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) - # load lr scheduler states. - load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(self.load_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. - if hasattr(train_state, "batch_sampler") and not isinstance( - train_state.batch_sampler, torch.utils.data.sampler.BatchSampler - ): - load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) - if hasattr(train_state, "data_state_dict"): - train_dl.dataset.load_state_dict( - llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder + if gpc.is_rank_for_log(): + logger.info(f"load_ckpt_info : {self.load_ckpt_info}") + logger.info( + f"===========Resume training from `{load_path}` {current_time} on host:" + f"{socket.gethostname()}===========" ) - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler + if load_content_str: + logger.info(f"===========Load contents are: {load_content_str}") + @llm_timeout(func_name="save_checkpoint") def save_checkpoint( self, folder, @@ -624,8 +793,10 @@ set load_ckpt_folder or use default value \ ) if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if scheduler: + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if hasattr(train_state, "batch_sampler") and not isinstance( train_state.batch_sampler, torch.utils.data.sampler.BatchSampler ): @@ -655,3 +826,12 @@ set load_ckpt_folder or use default value \ def set_save_folder(self, folder, step): self.storage_manager.latest_save_folder = folder self.storage_manager.latest_save_step = step + + def try_ping_storage(self): + if gpc.get_global_rank() % 8 == 0: + buff = torch.ones((1, 64, 64), dtype=torch.bfloat16) + test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping") + self.storage_manager.save(test_fn, buff) + self.storage_manager.wait() + self.storage_manager.load(test_fn) + del buff diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index c7b71f4..36bd105 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -46,12 +46,12 @@ def get_fns(fp: str): return storage_manager.get_fns(fp) -def llm_load(fp: str, *args, **kwargs): - return storage_manager.load(fp, *args, **kwargs) +def llm_load(fp: str, **kwargs): + return storage_manager.load(fp, **kwargs) -def llm_save(save_path: str, saved_obj: Any, *args, **kwargs): - storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs) +def llm_save(save_path: str, saved_obj: Any, **kwargs): + storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs) class StorageClient: @@ -63,19 +63,23 @@ class StorageClient: self.handler = handler @staticmethod - def load(client, load_path: str, *args, **kwargs): + def load(*args, 
**kwargs): raise NotImplementedError @staticmethod - def sync_upload_fileobj(*args, saved_obj=None, **kwargs): + def sync_upload_fileobj(*args, **kwargs): raise NotImplementedError @staticmethod - def assert_fp_exists(client): + def async_upload_fileobj(*args, **kwargs): raise NotImplementedError @staticmethod - def get_fns(client): + def assert_fp_exists(*args, **kwargs): + raise NotImplementedError + + @staticmethod + def get_fns(*args, **kwargs): raise NotImplementedError @@ -92,40 +96,65 @@ class Boto3MetaInfo: async_upload_fn: callable, local_nvme_path=None, ) -> None: - self.is_async = is_async + # all need info. self.client = handler self.bucket_name = bucket_name - self.endpoint = endpoint self.file_path = file_path - self.async_upload_fn = async_upload_fn + # only save need info. self.local_nvme_path = local_nvme_path + self.is_async = is_async + self.endpoint = endpoint + self.async_upload_fn = async_upload_fn def __str__(self) -> str: return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \ local_nvme_path: {self.local_nvme_path}" + @staticmethod + def unpack_boto3_save_meta(meta): + if meta.is_async: + return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path + else: + return meta.client, meta.bucket_name, meta.file_path + + @staticmethod + def unpack_boto3_nosave_meta(meta): + return meta.client, meta.bucket_name, meta.file_path + class LocalMetaInfo: """Local meta info for save/load etc.""" - def __init__(self, handler: StorageClient, dest_path: str) -> None: - self.is_async = False - self.client = handler - self.dest_path = dest_path + def __init__(self, file_path: str) -> None: + self.file_path = file_path self.async_upload_fn = None + self.is_async = False + + @staticmethod + def unpack_local_save_meta(meta): + return (meta.file_path,) + + @staticmethod + def unpack_local_nosave_meta(meta): + return (meta.file_path,) -def unpack_meta(meta): - args = [] - is_async = meta.is_async - for k, v in meta.__dict__.items(): - if k in ("endpoint", "async_upload_fn", "is_async"): - continue - if not is_async and k in ("local_nvme_path",): - continue - args.append(v) +def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]): + if isinstance(meta, Boto3MetaInfo): + return Boto3MetaInfo.unpack_boto3_save_meta(meta) + elif isinstance(meta, LocalMetaInfo): + return LocalMetaInfo.unpack_local_save_meta(meta) + else: + raise ValueError(f"unkonwn meta info: {type(meta)}") - return args + +def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]): + if isinstance(meta, Boto3MetaInfo): + return Boto3MetaInfo.unpack_boto3_nosave_meta(meta) + elif isinstance(meta, LocalMetaInfo): + return LocalMetaInfo.unpack_local_nosave_meta(meta) + else: + raise ValueError(f"unkonwn meta info: {type(meta)}") def compute_file_md5_by_chunk(file_name: str): @@ -136,6 +165,22 @@ def compute_file_md5_by_chunk(file_name: str): return hash_md5.hexdigest() +def try_get_storage_backend(path: str): + sre = path.split(":", maxsplit=1) + if len(sre) == 1: + if path.startswith("s3:"): + backend = "boto3" + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") + else: + backend = "local" + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") + return backend, sre + else: + return sre[0], sre[1] # (backend_prefix, splited_path) + + class Boto3Client(StorageClient): """ Boto3Client @@ -189,13 +234,11 @@ class 
Boto3Client(StorageClient): ) @staticmethod - def sync_upload_fileobj( - handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs - ): # pylint: disable=W0613 + def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs): assert saved_obj is not None, "saved_obj is None!" try: with io.BytesIO() as f: - torch.save(saved_obj, f, *args, **kwargs) + torch.save(saved_obj, f, **kwargs) f.seek(0) handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config) except handler.botocore.exceptions.EndpointConnectionError as exc: @@ -204,14 +247,7 @@ class Boto3Client(StorageClient): ) from exc @staticmethod - def load( - handler, - bucket_name: str, - fp: str, - local_nvme_path: str, # pylint: disable=W0613 - *args, - **kwargs, - ) -> Dict: + def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict: """ Args: fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt @@ -220,7 +256,7 @@ class Boto3Client(StorageClient): with io.BytesIO() as f: handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config) f.seek(0) - states = torch.load(f, *args, **kwargs) + states = torch.load(f, **kwargs) except handler.botocore.exceptions.EndpointConnectionError as exc: raise RuntimeError( f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}" @@ -228,24 +264,37 @@ class Boto3Client(StorageClient): return states @staticmethod - def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613 + def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613 assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp @staticmethod - def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613 + def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613 + re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp) + if "Contents" in re: + return len(list(re["Contents"])) > 0 + else: + return False + + @staticmethod + def get_fns(handler, bucket_name: str, fp: str): """ Ref: https://stackoverflow.com/questions/54314563/ how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2 """ - paginator = handler.client.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] - for page in pages: - if "Contents" in page: - for obj in page["Contents"]: - pth: str = obj["Key"] - folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) - return list(set(folder_name_list)) + if Boto3Client.is_fp_exists(handler, bucket_name, fp): + paginator = handler.client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) + folder_name_list = [] + for page in pages: + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) + else: + if gpc.is_rank_for_log(): + logger.warning(f"'{fp}' not found!") + return None @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -273,37 +322,35 @@ class LocalClient(StorageClient): super().__init__(None) @staticmethod - def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs): - assert isinstance(handler, LocalClient) + def sync_upload_fileobj(fp: str, saved_obj=None, 
**kwargs): assert saved_obj is not None fp_dirname = os.path.dirname(fp) if not os.path.exists(fp_dirname): os.makedirs(fp_dirname, exist_ok=True) - torch.save(saved_obj, fp, *args, **kwargs) + torch.save(saved_obj, fp, **kwargs) @staticmethod - def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613 - assert isinstance(handler, LocalClient) - assert os.path.exists(fp), f"{fp} is not found!" - with open(fp, "rb") as f: - states = torch.load(f, *args, **kwargs) + def load(load_path: str, **kwargs): + assert os.path.exists(load_path), f"{load_path} is not found!" + with open(load_path, "rb") as f: + states = torch.load(f, **kwargs) return states @staticmethod - def assert_fp_exists(handler, folder): - assert isinstance(handler, LocalClient) + def assert_fp_exists(folder): assert os.path.exists(folder), folder @staticmethod - def get_fns(handler, folder): - assert isinstance(handler, LocalClient) - assert os.path.exists(folder), f"folder '{folder}' not exists!" - fns = os.listdir(folder) - return fns + def get_fns(folder): + if not os.path.exists(folder): + if gpc.is_rank_for_log(): + logger.warning(f"'{folder}' not found!") + return None + else: + return os.listdir(folder) @staticmethod - def delete_obj(handler, fp: str): - assert isinstance(handler, LocalClient) + def delete_obj(fp: str): if not os.path.isdir(fp): os.remove(fp) @@ -327,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI assert match is not None, f"url '{fp}' is not a valid boto3 url" bucket_name, endpoint = match.group(1), match.group(2) endpoint = "http://" + endpoint + ":80" - tmp_step_file = get_tmp_file_name(tmp_local_folder, fp) + if is_async: + tmp_step_file = get_tmp_file_name(tmp_local_folder, fp) + else: + tmp_step_file = None return Boto3MetaInfo( is_async=is_async, handler=None, @@ -341,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI def get_local_meta(fp: str) -> LocalMetaInfo: assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path" - return LocalMetaInfo(None, fp) + return LocalMetaInfo(fp) def get_mount_point_free_size(path: str): @@ -427,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta): logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!') raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}") - def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]: + def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]: """ example: local:/path/to/checkpoint @@ -436,17 +486,14 @@ class StorageManager(metaclass=SingletonMeta): Args: path (str): _description_ """ - try: - backend, path = path.split(":", maxsplit=1) - except Exception as exc: - raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc + backend, path = try_get_storage_backend(path) init_args = (None,) if backend == "local": meta_info = get_local_meta(path) backend_key = backend elif backend == "boto3": - meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode) + meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode) backend_key = backend + ":" + meta_info.endpoint init_args = (meta_info.endpoint,) if ( @@ -474,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta): def assert_fp_exists(self, folder) -> None: meta = self._get_client(path=folder) - meta.client.assert_fp_exists(*unpack_meta(meta)) + 
meta.client.assert_fp_exists(*unpack_nosave_meta(meta)) def get_fns(self, folder) -> List[str]: meta = self._get_client(path=folder) - return meta.client.get_fns(*unpack_meta(meta)) + return meta.client.get_fns(*unpack_nosave_meta(meta)) - def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs): - meta = self._get_client(path=save_path) + def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs): if async_upload is None: async_upload = self.async_mode + + if not save_path.startswith("boto3:"): + async_upload = False + + meta = self._get_client(save_path, async_upload) + if async_upload: assert ( self.tmp_local_folder @@ -492,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta): tmp_step_file = meta.local_nvme_path self._to_be_del_files.append(tmp_step_file) with open(tmp_step_file, "wb") as f: - torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) - self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) + torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) + self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) self.async_task_peeding = True else: - meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) + meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs) self.upload_count += 1 - def load(self, load_path: str, *args, **kwargs) -> Any: + def load(self, load_path: str, **kwargs) -> Any: self.wait() meta = self._get_client(path=load_path) - return meta.client.load(*unpack_meta(meta), *args, **kwargs) + return meta.client.load(*unpack_nosave_meta(meta), **kwargs) def delete_obj(self, fp: str): meta = self._get_client(path=fp) - meta.client.delete_obj(*unpack_meta(meta)) + meta.client.delete_obj(*unpack_nosave_meta(meta)) def _del_tmp_folder(self): for fp in self._to_be_del_files: @@ -594,23 +646,24 @@ class StorageManager(metaclass=SingletonMeta): if gpc.is_rank_for_log(): self.upload_count += 1 - if self.async_mode: + if self.async_mode and self.latest_save_folder: self.save( os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), - saved_obj=dict({"step": self.latest_save_step}), + to_save_obj=dict({"step": self.latest_save_step}), async_upload=False, ) + self.latest_save_folder = None storage_manager: StorageManager = None -def init_storage_manager(ckpt_config): +def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload): global storage_manager storage_manager = StorageManager( - ckpt_config.enable_save_ckpt, - tmp_local_folder=ckpt_config.async_upload_tmp_folder, - async_mode=ckpt_config.async_upload, + enable_save_ckpt, + tmp_local_folder=async_upload_tmp_folder, + async_mode=async_upload, ) diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py index 07a0911..7a96841 100644 --- a/internlm/utils/timeout.py +++ b/internlm/utils/timeout.py @@ -1,4 +1,13 @@ +import datetime +import os import signal +import socket +import traceback +from functools import wraps + +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) class Timeout: @@ -24,3 +33,81 @@ class Timeout: def __exit__(self, error_type, value, traceback): signal.alarm(0) + + +ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None) + + +timeout_threshold_dict = { + "initialize_distributed_env": 120, + "nopp_forward_backward_step": 360, + "initialize_model": 10, + "initialize_optimizer": 20, + "optim_step": 30, + 
"get_train_data_loader": 600, + "get_validation_data_loader": 60, + "load_new_batch": 10, + "record_current_batch_training_metrics": 10, + "save_checkpoint": 1200, + "interleaved_forward_backward_step": 600, + "nointerleaved_forward_backward_step": 600, +} + +if ENABLE_TIMEOUT is not None: + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60)))) +else: + timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0) + LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800) + + +def try_get_gpc_rank(): + try: + from internlm.core.context import global_context as gpc + + rank = gpc.get_global_rank() + except: # noqa # pylint: disable=bare-except + rank = "unknown" + + return f"host-{socket.gethostname()}-rank-{rank}" + + +def llm_timeout(seconds=0, func_name=None): + """timeout decorator, Note that this decorator cannot be reentrant, + otherwise the signal will be reset. + + Args: + seconds (int, optional): timeout threshold. Defaults to 300. + func_name (str, optional): the func who is been waited to timeout. + """ + + def decorator(func): + nonlocal func_name + if func_name is None: + func_name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + def _handle_timeout(signum, frame): + raise TimeoutError + + nonlocal seconds + seconds = timeout_threshold_dict.get(func_name, seconds) + + if seconds > 0: + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(seconds) + + try: + result = func(*args, **kwargs) + except TimeoutError as e: + logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}") + raise e + finally: + signal.alarm(0) + + return result + + return wrapper + + return decorator diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py new file mode 100644 index 0000000..d6a19b6 --- /dev/null +++ b/tests/test_utils/common_fixture.py @@ -0,0 +1,181 @@ +import os +import shutil +from subprocess import PIPE, STDOUT, Popen + +import pytest +import torch + +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.utils.common import SingletonMeta + +OSS_NAME = os.environ["OSS_BUCKET_NAME"] +OSS_IP = os.environ["OSS_IP"] +USER = os.environ["USER"] +JOB_NAME = "CI_TEST" +LOCAL_SAVE_PATH = "local:local_ckpt" + +BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" +BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + +ASYNC_TMP_FOLDER = "./async_tmp_folder" + + +# 1B +init_config = Config( + dict( + parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1), + model_type="INTERNLM", + adam=dict( + lr=1e-4, + ), + data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999), + model=dict( + checkpoint=False, + num_attention_heads=2, + embed_split_hidden=True, + vocab_size=103168, + embed_grad_scale=1, + parallel_output=True, + hidden_size=1024, + num_layers=2, + mlp_ratio=1, + apply_post_layer_norm=False, + dtype=torch.bfloat16, + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, + ), + resume_tb_folder="", + tensorboard_folder="", + ) +) + + +def init_naive_model(): + # let MODEL_INITIALIZER to work + import 
internlm.model.modeling_internlm # noqa # pylint: disable=unused-import + from internlm.core.naive_amp import NaiveAMPModel + from internlm.utils.registry import MODEL_INITIALIZER + + model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model)) + model = NaiveAMPModel( + model=model, + output_to_fp32=False, + dtype=torch.bfloat16, + sync_buffer=False, + ) + return model + + +def init_naive_optim(model): + naive_optimizer = torch.optim.AdamW( + params=[{"params": model.parameters(), "weight_decay": 0.01}], + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + return naive_optimizer + + +def init_hybrid_optim(model): + naive_optimizer = torch.optim.AdamW( + params=[{"params": model.parameters(), "weight_decay": 0.01}], + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=Config( + dict( + fp16=dict( + initial_scale=2**16, + min_scale=1, + growth_interval=1000, + ), + growth_factor=2, + backoff_factor=0.5, + max_scale=2**24, + hysteresis=2, + ) + ), + zero_cfg=Config( + dict( + overlap_sync_grad=False, + overlap_sync_param=False, + reduce_bucket_size=512 * 1024 * 1024, + clip_grad_norm=1.0, + ) + ), + param_bcast_sync_handler=None, + ) + return optimizer + + +@pytest.fixture(autouse=True, scope="function") +def reset_singletons(): + SingletonMeta._instances = {} + + +def reset_seed(): + from internlm.core.context.random import _SEED_MANAGER + + _SEED_MANAGER.reset() + + +@pytest.fixture(scope="module") +def init_dist_and_model(rank=0, world_size=1): + from internlm.initialize import initialize_distributed_env + + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12377" + initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + + # setup + print("set up", flush=True) + model = init_naive_model() + # opim = init_naive_optim(model) + opim = init_hybrid_optim(model) + + yield model, opim + + # teardown + del model, opim + print("teardown", flush=True) + gpc.destroy() + reset_seed() + + +def enter_flag(text): + print(f"{text} begin!", flush=True) + yield + print(f"{text} end!", flush=True) + + +def del_tmp_file(): + try: + shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True) + except FileNotFoundError: + pass + + try: + shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True) + except FileNotFoundError: + pass + + try: + cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / " + with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output: + results, presults = "", "" + for line in iter(output.stdout.readline, b""): + results += str(line.rstrip()) + presults += line.rstrip().decode() + "\n" + print(presults, flush=True) + except FileNotFoundError: + pass diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py new file mode 100644 index 0000000..bd93436 --- /dev/null +++ b/tests/test_utils/test_model_checkpoint.py @@ -0,0 +1,247 @@ +import os + +import pytest +import torch + +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.core.trainer import TrainState +from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.utils.common import SingletonMeta +from internlm.utils.model_checkpoint import CheckpointManager +from 
internlm.utils.storage_manager import wait_async_upload_finish +from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + ASYNC_TMP_FOLDER, + BOTO_SAVE_PATH, + LOCAL_SAVE_PATH, + del_tmp_file, + init_dist_and_model, + reset_singletons, +) + +TOTAL_STEP = 6 + +CKPT_EVERY = 4 +SNPASHOT_EVERY = 2 + + +ckpt_config_list = [ + # Old interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=BOTO_SAVE_PATH, + load_optimizer=True, + checkpoint_every=CKPT_EVERY, + async_upload=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_model_only_folder=None, + load_given_ckpt=False, + load_ckpt_folder=None, + is_old_api=True, + ), + # Old interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=LOCAL_SAVE_PATH, + load_optimizer=True, + checkpoint_every=CKPT_EVERY, + async_upload=False, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_model_only_folder=None, + load_given_ckpt=False, + load_ckpt_folder=None, + is_old_api=True, + ), + # New interface format + dict( + enable_save_ckpt=True, + save_ckpt_folder=BOTO_SAVE_PATH, + checkpoint_every=CKPT_EVERY, + async_upload=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + is_old_api=False, + auto_resume=True, + ), + dict( + enable_save_ckpt=True, + save_ckpt_folder=LOCAL_SAVE_PATH, + checkpoint_every=CKPT_EVERY, + async_upload=False, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + oss_snapshot_freq=SNPASHOT_EVERY, + stop_file_path=None, + load_ckpt_folder=None, + is_old_api=False, + auto_resume=True, + ), +] + + +def overwrite_optim_state(optim, set_value): + if isinstance(optim, HybridZeroOptimizer): + for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + # p.copy_(torch.full_like(p, set_value, dtype=p.dtype)) + p.data.fill_(set_value) + for group_id in range(len(optim._fp16_param_groups)): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( + rank=optim._zero_local_rank, group_id=group_id + ) + fp16_p.fill_(set_value) + else: + for group in optim.param_groups: + for p in group["params"]: + # p.copy_(torch.full_like(p, set_value, dtype=p.dtype)) + p.data.fill_(set_value) + + +def compare_optim_state(optim1, optim2): + re = True + if isinstance(optim1, HybridZeroOptimizer): + fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank + fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank + for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2): + re &= group_id_1 == group_id_2 + if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]: + re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2]) + else: + for group1, group2 in zip(optim1.param_groups, optim2.param_groups): + for p1, p2 in zip(group1["params"], group2["params"]): + re &= torch.equal(p1, p2) + return re + + +def compare_optim_value(optim, value): + re = True + if isinstance(optim, HybridZeroOptimizer): + for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items(): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + re &= torch.equal(p, torch.full_like(p, value, 
dtype=p.dtype)) + for group_id in range(len(optim._fp16_param_groups)): + if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]: + fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group( + rank=optim._zero_local_rank, group_id=group_id + ) + re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype)) + else: + for group in optim.param_groups: + for p in group["params"]: + re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) + return re + + +def overwrite_model_value(model, value): + for p in model.parameters(): + # p.copy_(torch.full_like(p, value, dtype=p.dtype)) + p.data.fill_(value) + + +def compare_model_value(model, value): + re = True + for p in model.parameters(): + re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype)) + return re + + +@pytest.fixture(scope="function") +def del_tmp(): + del_tmp_file() + yield + del_tmp_file() + + +@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import + from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType + + ckpt_config = Config(ckpt_config) + assert ckpt_config.checkpoint_every < TOTAL_STEP + assert ckpt_config.oss_snapshot_freq < TOTAL_STEP + + model, opim = init_dist_and_model + train_state = TrainState(gpc.config, None) + if isinstance(opim, HybridZeroOptimizer): + print("Is HybridZeroOptimizer!", flush=True) + else: + print("Is naive Adam!", flush=True) + + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + latest_ckpt_step = None + for i in range(TOTAL_STEP + 1): + overwrite_model_value(model, i) + overwrite_optim_state(opim, i) + + train_state.batch_count = i + train_state.step_count += 1 + + save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state) + if save_ckpts: + latest_ckpt_step = i + + ckpt_mm.try_save_checkpoint(train_state) + + wait_async_upload_finish() + latest_ckpt_info = ckpt_mm.query_lastest_ckpt() + assert latest_ckpt_info is not None + latest_ckpt = latest_ckpt_info["path"] + if ckpt_mm.save_ckpt_folder.startswith("local"): + assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt + else: + assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt + + del ckpt_mm + SingletonMeta._instances = {} + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + ckpt_mm.try_resume_training(train_state) + assert latest_ckpt_step == 5 + assert train_state.step_count == 6 + assert train_state.batch_count == 6 + assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0] + assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0] + + if ckpt_mm.save_ckpt_folder.startswith("local:"): + ckpt_mm.load_ckpt_info = dict( + path=os.path.join(LOCAL_SAVE_PATH, "4"), + content=CheckpointLoadMask(("all",)), + ckpt_type=CheckpointLoadType.INTERNLM, + ) + else: + ckpt_mm.load_ckpt_info = dict( + path=os.path.join(BOTO_SAVE_PATH, "4"), + content=CheckpointLoadMask(("all",)), + ckpt_type=CheckpointLoadType.INTERNLM, + ) + + ckpt_mm.try_resume_training(train_state) + + assert train_state.step_count == 4 + assert train_state.batch_count == 4 + assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0] + assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0] + + 
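
The test above verifies checkpoints with a simple trick: fill every parameter with a known constant (the step index), save, reload, and check that the restored tensors equal that constant. Here is a self-contained miniature of the same round-trip check, using plain `torch.save`/`torch.load` instead of the project's `CheckpointManager`:

```python
import os
import tempfile

import torch
from torch import nn


def fill_params(model: nn.Module, value: float) -> None:
    # Overwrite every parameter with a known constant so a reload is easy to verify.
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(value)


def params_equal(model: nn.Module, value: float) -> bool:
    return all(torch.equal(p, torch.full_like(p, value)) for p in model.parameters())


if __name__ == "__main__":
    model = nn.Linear(4, 4)
    fill_params(model, 5.0)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.save(model.state_dict(), path)
        fill_params(model, 0.0)  # clobber the weights ...
        model.load_state_dict(torch.load(path, map_location="cpu"))
    assert params_equal(model, 5.0)  # ... and confirm the checkpoint restored them
```
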
+@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import + ckpt_config = Config(ckpt_config) + + model, opim = init_dist_and_model + SingletonMeta._instances = {} + ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim) + ckpt_mm.try_ping_storage() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py new file mode 100644 index 0000000..32f905b --- /dev/null +++ b/tests/test_utils/test_storage_manager.py @@ -0,0 +1,89 @@ +import os + +import pytest +import torch + +from internlm.core.context.parallel_context import Config +from internlm.initialize.launch import get_config_value +from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import + ASYNC_TMP_FOLDER, + BOTO_SAVE_PATH, + LOCAL_SAVE_PATH, + del_tmp_file, + init_dist_and_model, + reset_singletons, +) + +ASYNC_TMP_FOLDER = "./async_tmp_folder" +ckpt_config_list = [ + # async boto + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + async_upload=True, + save_folder=BOTO_SAVE_PATH, + test_id=0, + ), + # sync local + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=None, + async_upload=False, + save_folder=LOCAL_SAVE_PATH, + test_id=1, + ), + # sync boto + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=None, + async_upload=False, + save_folder=BOTO_SAVE_PATH, + test_id=2, + ), + # async local + dict( + enable_save_ckpt=True, + async_upload_tmp_folder=ASYNC_TMP_FOLDER, + async_upload=True, + save_folder=LOCAL_SAVE_PATH, + test_id=3, + ), +] + + +@pytest.fixture(scope="function") +def del_tmp(): + del_tmp_file() + yield + del_tmp_file() + + +@pytest.mark.usefixtures("del_tmp") +@pytest.mark.usefixtures("reset_singletons") +@pytest.mark.parametrize("ckpt_config", ckpt_config_list) +def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument + from internlm.utils.storage_manager import ( + check_folder, + get_fns, + init_storage_manager, + llm_load, + llm_save, + wait_async_upload_finish, + ) + + ckpt_config = Config(ckpt_config) + enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False) + async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False) + async_upload = get_config_value(ckpt_config, "async_upload", False) + + init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload) + + tobj = torch.rand(64, 64) + save_fn = os.path.join(ckpt_config.save_folder, "test.pt") + llm_save(save_fn, tobj) + if ckpt_config.test_id == 0: + wait_async_upload_finish() + check_folder(save_fn) + assert get_fns(ckpt_config.save_folder)[0] == "test.pt" + load_obj = llm_load(save_fn, map_location="cpu") + assert 0 == ((load_obj != tobj).sum()) diff --git a/tests/test_utils/test_timeout.py b/tests/test_utils/test_timeout.py new file mode 100644 index 0000000..a3f15f9 --- /dev/null +++ b/tests/test_utils/test_timeout.py @@ -0,0 +1,119 @@ +import fcntl +import os +import time +from multiprocessing import Process + +import pytest +import torch +import torch.distributed as dist + +os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1" # noqa # pylint: disable=wrong-import-position +os.environ["NCCL_TIMEOUT"] = "5" +from internlm.utils.timeout import llm_timeout +from tests.test_utils.common_fixture import ( # noqa # 
pylint: disable=unused-import + init_config, +) + +WORLD_SIZE = 2 + + +@llm_timeout(2, "fake_timeout_func") +def fake_timeout_func(): + time.sleep(10) + + +@llm_timeout(10, "nccl_timeout_func") +def nccl_timeout_func(rank): + # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880 + # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication. + buff = torch.ones([64, 64]).cuda(rank) + dist.all_reduce(buff) # lazy communicator init + torch.cuda.synchronize() + if rank == 0: + dist.all_reduce(buff) + torch.cuda.synchronize() # main thread will hang at here. + else: + time.sleep(9999) + + +@llm_timeout(10, "try_file_lock") +def try_file_lock(rank, stop_file_path): + if rank == 1: + time.sleep(5) + + with open(stop_file_path, "r", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) # rank 1 hang. + if rank == 0: + time.sleep(99999) # rank 0 hang. + f.seek(0) + f.read() + fcntl.flock(f, fcntl.LOCK_UN) + + +def local_timeout(rank, _): + + try: + fake_timeout_func() + except TimeoutError as e: + print(f"local_timeout, rank:{rank}, e:{e}", flush=True) + else: + assert False, "It should timeout!" + + +def gpc_timeout(rank, world_size): + + from internlm.initialize import initialize_distributed_env + + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12377" + initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + + try: + nccl_timeout_func(rank) + except TimeoutError as e: + print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True) + time.sleep(5) # wait rank 0 to be killed + else: + time.sleep(5) # give some time to let Watchdog kill rank 0. + assert False, "It should timeout!" + + +def file_lock_timeout(rank, _, stop_file_path): + if rank == 0: + with open(stop_file_path, "w"): + pass + try: + try_file_lock(rank, stop_file_path) + except TimeoutError as e: + print(e, flush=True) + else: + assert False, "It should timeout!" 
+ finally: + if rank == 0: + os.remove(stop_file_path) + + +timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")] + + +@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list) +def test_timeout(timeout_func_and_args): + timeout_func, world_size, other_args = timeout_func_and_args + procs = [] + for i in range(world_size): + if other_args is None: + args = (i, world_size) + else: + args = (i, world_size, other_args) + proc = Process(target=timeout_func, args=args) + proc.start() + procs.append(proc) + + for proc in procs: + proc.join(15) + if proc.is_alive(): + proc.terminate() + proc.join() diff --git a/train.py b/train.py index 6b0ffcf..35e9612 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,6 @@ from internlm.utils.common import ( parse_args, ) from internlm.utils.evaluation import evaluate_on_val_dls -from internlm.utils.gputest import bench_gpu, bench_net from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.model_checkpoint import CheckpointManager @@ -74,7 +73,6 @@ def main(args): total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every label_smoothing = gpc.config.loss.label_smoothing - lr = gpc.config.adam.lr get_tflops_func = partial( get_megatron_flops, @@ -97,21 +95,11 @@ def main(args): # initialize customed llm logger uniscale_logger = initialize_llm_logger(start_time=current_time) - # initialize and resume train state - train_state = TrainState(gpc.config) - # initialize model model = initialize_model() with open(args.config, "r") as f: config_lines = f.readlines() - ckpt_manager = CheckpointManager( - ckpt_config=gpc.config.ckpt, - model=model, - model_config=gpc.config.model, - model_config_file="".join(config_lines), - feishu_address=gpc.config.alert_address, - ) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -119,18 +107,28 @@ def main(args): # initialize the train and validation data loader train_dl, dataset_types = get_train_data_loader(num_worker=4) val_dls = get_validation_data_loader() - train_state.init_batch_sampler(train_dl) - # Loading model weights must be done before zero is initialized. - ckpt_manager.try_load_model(current_time) + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) # if fsdp enabled, warp the model model = warp_FSDP_model(model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + # Loading other persistent training states. 
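
The `try_resume_training` call just below now takes only the train state and the timestamp; which persistent states actually get restored is governed by the `content` tuple carried in `load_ckpt_info` (for example `("model",)` versus `("all",)`). The toy mask below illustrates that gating idea only; it is not the project's `CheckpointLoadMask`, and the state names simply mirror what the old resume path loaded (model, optimizer, scheduler, sampler).

```python
from typing import Iterable, Tuple

KNOWN_CONTENT = ("model", "optimizer", "scheduler", "sampler")


class LoadContentMask:
    """Toy stand-in for a load-content mask: ("all",) expands to every known state."""

    def __init__(self, content: Iterable[str]) -> None:
        content = tuple(content)
        self.content: Tuple[str, ...] = KNOWN_CONTENT if "all" in content else content

    def need(self, name: str) -> bool:
        return name in self.content

    def only_load(self, name: str) -> bool:
        return self.content == (name,)


if __name__ == "__main__":
    mask = LoadContentMask(("model",))
    if mask.need("model"):
        print("would restore model weights")
    if not mask.need("optimizer"):
        print("would skip optimizer states")
    print("model-only load:", mask.only_load("model"))
```
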
- ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl) + ckpt_manager.try_resume_training(train_state, current_time) # initialize customed llm writer writer = Writer( @@ -201,8 +199,6 @@ def main(args): for batch_count in range(train_state.batch_count, total_steps): if batch_count % 50 == 0: torch.cuda.empty_cache() - bench_gpu() - bench_net() start_time = time.time() timer("one-batch").start() @@ -245,7 +241,7 @@ def main(args): if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case logger.warning(f"Warning: skip parameter update at step {batch_count}.") send_alert_message( - address=gpc.config.alert_address, + address=gpc.config.monitor.alert.feishu_alert_address, message=f"Warning: skip parameter update at step {batch_count}.", ) @@ -305,11 +301,15 @@ if __name__ == "__main__": assert hasattr(gpc, "config") and gpc.config is not None # initialize monitor manager context - with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address): + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): try: main(args) except Exception: logger.error( f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", ) - mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + )
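
One more utility introduced earlier in this patch, `internlm/utils/timeout.py`, wraps long-running training stages (data loading, forward/backward steps, checkpoint saves) in an `llm_timeout` decorator driven by `SIGALRM`. For reference, below is a stripped-down, self-contained sketch of the same signal-based pattern, with a fixed threshold instead of the per-function threshold table and without the rank/host reporting; like the original, it is Unix-only and not reentrant.

```python
import signal
import time
from functools import wraps


def simple_timeout(seconds: int):
    """Raise TimeoutError if the wrapped call runs longer than `seconds` (SIGALRM-based, Unix only)."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle_timeout(signum, frame):
                raise TimeoutError(f"{func.__name__} exceeded {seconds}s")

            old_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # always cancel the pending alarm
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator


@simple_timeout(1)
def slow_step():
    time.sleep(5)


if __name__ == "__main__":
    try:
        slow_step()
    except TimeoutError as exc:
        print("caught:", exc)
```
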