diff --git a/configs/moe_cfg.py b/configs/moe_cfg.py new file mode 100644 index 0000000..89e1a96 --- /dev/null +++ b/configs/moe_cfg.py @@ -0,0 +1,152 @@ +JOB_NAME = "7b_train" + +SEQ_LEN = 2048 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 16 +VOCAB_SIZE = 103168 + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states). + # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights. + load_optimizer=True, # Wheter to load optimizer states when continuing training. + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki" +VALID_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + packed_length = 2 * SEQ_LEN, + micro_bsz=2, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=50000, + pack_sample_into_one=False, + total_steps=50000, + skip_batches="", + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + zero_overlap_communication=True, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, + moe_loss_coeff=0.1, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +model = dict( + checkpoint=False, + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + sequence_parallel=False, + num_experts=4, + moe_use_residual=False, +) +""" +zero1 parallel: + 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. + 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +tensor parallel: tensor parallel size, usually the number of GPUs per node. +""" +parallel = dict( + # zero1=4, + pipeline=dict(size=4, interleaved_overlap=False), + # tensor=dict(size=4), +) + +cudnn_deterministic = False +cudnn_benchmark = False diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index daab2bb..c1e8830 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler): loss = self._call_engine_criterion(engine, output, label) self._call_hooks("after_criterion", loss) moe_loss = sum(moe_losses) * moe_loss_coeff - loss += moe_loss + moe_loss /= scale_loss loss /= scale_loss + loss += moe_loss # backward if not forward_only: @@ -127,7 +128,7 @@ class NonPipelineScheduler(BaseScheduler): if not return_loss: loss = None - return output, loss + return output, loss, moe_loss def forward_backward_step( self, @@ -166,6 +167,7 @@ class NonPipelineScheduler(BaseScheduler): data, label = batch_data loss = 0 if return_loss else None + moe_loss = 0 if return_loss else None outputs = [] labels = [] @@ -180,12 +182,14 @@ class NonPipelineScheduler(BaseScheduler): _data, _label = self._load_accum_batch(data, label) - _output, _loss = self._train_one_batch( + _output, _loss, _moe_loss = self._train_one_batch( _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff ) if return_loss: loss += _loss + moe_loss += _moe_loss + if return_output_label: outputs.append(_output) labels.append(_label) @@ -193,4 +197,4 @@ class NonPipelineScheduler(BaseScheduler): if not return_output_label: outputs, labels = None, None - return outputs, labels, loss + return outputs, labels, loss, moe_loss diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index ebdb374..dd3268f 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Callable, List, Optional, Tuple, Union import torch.cuda +import torch.distributed as dist import internlm.core.communication as comm from internlm.core.context import ParallelMode @@ -130,7 +131,7 @@ class PipelineScheduler(BaseScheduler): self.dtype = dtype self._hooks = scheduler_hooks - self.tensor_shape = ( + self._tensor_shape = ( tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape) ) @@ -146,6 +147,14 @@ class PipelineScheduler(BaseScheduler): # cache for the batch data self.batch_data = None + @property + def tensor_shape(self) -> torch.Size: + return self._tensor_shape + + @tensor_shape.setter + def tensor_shape(self, tensor_shape: torch.Size): + self._tensor_shape = tensor_shape + def pre_processing(self, engine): types = set() @@ -239,7 +248,16 @@ class PipelineScheduler(BaseScheduler): """ return step_id - def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): + def _forward_step( + self, + engine, + input_obj, + return_tensors, + return_output_label=True, + accum_loss=None, + accum_moe_loss=None, + moe_loss_coeff=1.0, + ): """ Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. @@ -251,6 +269,7 @@ class PipelineScheduler(BaseScheduler): return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. return_output_label (bool, optional): Whether returns output labels. accum_loss (optional): Where accumulated loss stores. + accum_moe_loss (optional): Where accumulated moe loss stores. Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. @@ -259,7 +278,7 @@ class PipelineScheduler(BaseScheduler): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - output_obj = self._call_engine(engine.model, data) + output_obj, moe_losses = self._call_engine(engine.model, data) self._call_hooks("after_forward", output_obj) if gpc.is_last_rank(ParallelMode.PIPELINE): @@ -275,9 +294,13 @@ class PipelineScheduler(BaseScheduler): accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - return output_obj + moe_loss = sum(moe_losses) * moe_loss_coeff + moe_loss /= self.num_microbatches + accum_moe_loss.add_(moe_loss.detach()) - def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad): + return output_obj, moe_loss + + def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None): """ Backward step through the passed-in output tensor. If it is the last stage, the output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. @@ -311,10 +334,18 @@ class PipelineScheduler(BaseScheduler): self._call_hooks("before_backward", output_obj, output_obj_grad) with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync): - if output_obj_grad is None: - engine.backward(output_obj) + if moe_loss is None: + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) else: - engine.backward_by_grad(output_obj, output_obj_grad) + if output_obj_grad is None: + engine.backward(output_obj + moe_loss) + else: + # scale the latent loss + moe_loss = moe_loss * engine.optimizer.loss_scale + engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None]) # Collect the grad of the input_obj. input_obj_grad = None @@ -329,7 +360,7 @@ class PipelineScheduler(BaseScheduler): return input_obj_grad - def _forward_only_step(self, engine, return_loss=True, return_output_label=True): + def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0): """ This function performs forward only computation process. The scheduling of microbatches is similar to the warmup phase, where each microbatch first receives the forward input from the previous stage, then performs @@ -356,6 +387,7 @@ class PipelineScheduler(BaseScheduler): if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) + accum_moe_loss = torch.zeros(1, device=get_current_device()) # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -376,12 +408,14 @@ class PipelineScheduler(BaseScheduler): input_obj = None # Perform forward computation - output_obj = self._forward_step( + output_obj, _ = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if not gpc.is_last_rank(ParallelMode.PIPELINE): @@ -392,10 +426,14 @@ class PipelineScheduler(BaseScheduler): comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) + dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - return output, label, accum_loss + if accum_loss is not None: + accum_loss += accum_moe_loss - def _forward_backward_step(self, engine, return_loss=True, return_output_label=True): + return output, label, accum_loss, accum_moe_loss + + def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0): """ This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner. It consists of three stages: warmup, 1F1B, and cooldown. @@ -441,12 +479,14 @@ class PipelineScheduler(BaseScheduler): # Input, output tensors only need to be saved when doing backward passes input_objs = [] output_objs = [] + moe_losses = [] return_tensors = [] accum_loss = ( torch.zeros(1, device=get_current_device()) if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) + accum_moe_loss = torch.zeros(1, device=get_current_device()) # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -468,12 +508,14 @@ class PipelineScheduler(BaseScheduler): input_obj = None # Perform forward computation - output_obj = self._forward_step( + output_obj, moe_loss = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if not gpc.is_last_rank(ParallelMode.PIPELINE): @@ -493,6 +535,7 @@ class PipelineScheduler(BaseScheduler): input_objs.append(input_obj) output_objs.append(output_obj) + moe_losses.append(moe_loss) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to @@ -512,12 +555,14 @@ class PipelineScheduler(BaseScheduler): # Run 1F1B in steady state. for i in range(num_1f1b_micropairs): # Perform forward computation - output_obj = self._forward_step( + output_obj, moe_loss = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if gpc.is_last_rank(ParallelMode.PIPELINE): @@ -533,13 +578,15 @@ class PipelineScheduler(BaseScheduler): # Add input_obj and output_obj to end of list. input_objs.append(input_obj) output_objs.append(output_obj) + moe_losses.append(moe_loss) # Pop output_obj and output_obj from the start of the list for # the backward pass. input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) - input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad) + input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss) if i == (num_1f1b_micropairs - 1): input_obj = None @@ -563,6 +610,7 @@ class PipelineScheduler(BaseScheduler): for i in range(num_warmup_microsteps): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) if not gpc.is_last_rank(ParallelMode.PIPELINE): output_obj_grad = comm.recv_backward( @@ -574,17 +622,25 @@ class PipelineScheduler(BaseScheduler): output_obj_grad = None input_obj_grad = self._backward_step( - engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad + engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss ) if not gpc.is_first_rank(ParallelMode.PIPELINE): comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}") + output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) + dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - return output, label, accum_loss + if accum_loss is not None: + accum_loss += accum_moe_loss - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + return output, label, accum_loss, accum_moe_loss + + def forward_backward_step( + self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0 + ): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -596,7 +652,7 @@ class PipelineScheduler(BaseScheduler): return_loss (bool, optional): Whether returns the loss value. Default is true. return_output_label (bool, optional): If False, the output and label won't be returned. Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None. """ assert ( @@ -607,9 +663,9 @@ class PipelineScheduler(BaseScheduler): self.load_batch(engine, data_iter) if forward_only: - return self._forward_only_step(engine, return_loss, return_output_label) + return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff) else: - return self._forward_backward_step(engine, return_loss, return_output_label) + return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff) class InterleavedPipelineScheduler(PipelineScheduler): @@ -676,21 +732,35 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(num_chunks)] self._output_objs = [[] for _ in range(num_chunks)] self._output_obj_grads = [[] for _ in range(num_chunks)] + self._moe_losses = [[] for _ in range(num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)] self._output_obj_shapes = [None for _ in range(num_chunks)] self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)] + @property + def tensor_shape(self) -> torch.Size: + return self._tensor_shape + + @tensor_shape.setter + def tensor_shape(self, tensor_shape: torch.Size): + self._tensor_shape = tensor_shape + self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)] + self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)] + def _clear_state(self) -> None: self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(self._num_chunks)] self._output_objs = [[] for _ in range(self._num_chunks)] self._output_obj_grads = [[] for _ in range(self._num_chunks)] + self._moe_losses = [[] for _ in range(self._num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)] self._output_obj_shapes = [None for _ in range(self._num_chunks)] @@ -712,7 +782,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): self.microbatch_offset[model_chunk_id] += self.microbatch_size return move_to_device(micro_batch_data) - def _forward_step(self, engine, chunk_id): + def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -734,7 +804,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - output_obj = self._call_engine(engine.model[chunk_id], data) + output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data) # Convert output_obj to fp32 when last model chunk of last stage if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel): output_obj = engine.model[chunk_id].convert_to_fp32(output_obj) @@ -754,7 +824,14 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced + moe_loss = sum(moe_losses) * moe_loss_coeff + moe_loss /= self.num_microbatches + + if self._accum_moe_loss is not None: + self._accum_moe_loss.add_(moe_loss.detach()) + self._output_objs[chunk_id].append(output_obj) + self._moe_losses[chunk_id].append(moe_loss) return output_obj @@ -780,8 +857,9 @@ class InterleavedPipelineScheduler(PipelineScheduler): input_obj = self._input_objs[chunk_id].pop(0) output_obj = self._output_objs[chunk_id].pop(0) output_obj_grad = self._output_obj_grads[chunk_id].pop(0) + moe_loss = self._moe_losses[chunk_id].pop(0) - input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad) + input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss) return input_obj_grad @@ -813,6 +891,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, receive_extra_backward: bool = False, forward_only: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the warm-up loop and prepare data for the 1F1B stage. @@ -850,12 +929,13 @@ class InterleavedPipelineScheduler(PipelineScheduler): for k in range(num_warmup_microsteps): chunk_id = self._get_chunk_by_microbatch(k) - output_obj = self._forward_step(engine, chunk_id) + output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff) if forward_only: # when forward-only, no need to save tensors for a backward pass self._input_objs[chunk_id].pop() self._output_objs[chunk_id].pop() + self._moe_losses[chunk_id].pop() if not gpc.is_pipeline_last_stage(): if isinstance(output_obj, torch.Tensor): @@ -931,6 +1011,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, num_1f1b_micropairs: int, all_warmup_microsteps: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the 1F1B loop with overlap. @@ -960,7 +1041,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True) # 1. Forward pass. - output_obj = self._forward_step(engine, forward_chunk_id) + output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff) # 2. Check if the backward input is ready. if backward_async_communicator is not None: @@ -1045,6 +1126,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, num_1f1b_micropairs: int, all_warmup_microsteps: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the 1F1B loop without overlap. @@ -1066,7 +1148,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): # Forward pass. forward_microstep_id = k + num_warmup_microsteps forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id) - output_obj = self._forward_step(engine, forward_chunk_id) + output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff) # Backward pass. backward_microstep_id = k @@ -1171,7 +1253,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): ) ) - def _forward_only_step(self, engine: Engine): + def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0): num_microsteps = self.num_microbatches * self._num_chunks num_warmup_microsteps = num_microsteps @@ -1181,9 +1263,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps, receive_extra_backward=False, forward_only=True, + moe_loss_coeff=moe_loss_coeff, ) - def _forward_backward_step(self, engine: Engine): + def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0): # Compute number of warmup and remaining microbatches. all_warmup_microsteps = False num_microsteps = self.num_microbatches * self._num_chunks @@ -1217,6 +1300,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_microsteps, num_warmup_steps, receive_extra_backward=receive_extra_backward, + moe_loss_coeff=moe_loss_coeff, ) # 2. 1F1B @@ -1225,12 +1309,15 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_steps, num_1f1b_micropairs=num_1f1b_micropairs, all_warmup_microsteps=all_warmup_microsteps, + moe_loss_coeff=moe_loss_coeff, ) # 3. Cooldown self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + def forward_backward_step( + self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0 + ): """Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed. @@ -1250,24 +1337,36 @@ class InterleavedPipelineScheduler(PipelineScheduler): forward_only or return_loss ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + gpc.set_virtual_pipeline_parallel_rank(0) + self.load_batch(engine, data_iter) if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): self._accum_loss = torch.zeros(1, device=get_current_device()) + self._accum_moe_loss = torch.zeros(1, device=get_current_device()) + if return_output_label: self._return_tensors = [] if forward_only: - self._forward_only_step(engine) + self._forward_only_step(engine, moe_loss_coeff) else: - self._forward_backward_step(engine) + self._forward_backward_step(engine, moe_loss_coeff) if return_output_label and len(self._return_tensors) > 0: output, label = pack_return_tensors(self._return_tensors) else: output, label = (None, None) + + logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}") + + dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + accum_moe_loss = self._accum_moe_loss + accum_loss = self._accum_loss + if accum_loss is not None: + accum_loss += self._accum_moe_loss self._clear_state() - return output, label, accum_loss + return output, label, accum_loss, accum_moe_loss diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 536c1aa..fd899ba 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -155,5 +155,5 @@ class Trainer: Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). """ - output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) - return output, label, loss + output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) + return output, label, loss, moe_loss diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 986d1f7..014278e 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -108,67 +108,96 @@ def args_sanity_check(): logger.info(f"valid_every: {data.valid_every}") # processing the checkpoint config - if "enable_save_ckpt" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("enable_save_ckpt", False) + ckpt = gpc.config.ckpt + if "enable_save_ckpt" not in ckpt: + ckpt._add_item("enable_save_ckpt", False) - if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0: - gpc.config.ckpt._add_item("checkpoint_every", float("inf")) + # Saving checkpoint args. + if ckpt.enable_save_ckpt: + assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" + assert ckpt.checkpoint_every > 0 + assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!" - if "load_optimizer" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_optimizer", True) + if "async_upload" not in ckpt: + ckpt._add_item("async_upload", False) # async defalut is False. + else: + if ckpt.async_upload: + assert "save_ckpt_folder" in ckpt + if "boto3:" not in ckpt.save_ckpt_folder: + if gpc.is_rank_for_log(): + logger.warning( + "Storing ckpt on file system does not support asynchronous storage, will use sync save!" + ) + ckpt.async_upload = False + else: + if "async_upload_tmp_folder" not in ckpt: + ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") - if "save_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("save_ckpt_folder", None) + if not ckpt.async_upload: + ckpt._add_item("async_upload_tmp_folder", None) - if "load_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_ckpt_folder", None) + if "snapshot_ckpt_folder" not in ckpt: + ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot")) - if "load_model_only_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_model_only_folder", None) + if "oss_snapshot_freq" not in ckpt: + ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable. + else: + ckpt._add_item("checkpoint_every", float("inf")) + ckpt._add_item("oss_snapshot_freq", float("inf")) + ckpt._add_item("save_ckpt_folder", None) + ckpt._add_item("async_upload", False) + ckpt._add_item("async_upload_tmp_folder", None) + ckpt._add_item("snapshot_ckpt_folder", None) + ckpt._add_item("snapshot_ckpt_folder", None) - if "async_upload" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("async_upload", False) + # Loading checkpoint args. + if "load_model_only_folder" not in ckpt: + ckpt._add_item("load_model_only_folder", None) - if "async_upload_tmp_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") + if "load_ckpt_folder" not in ckpt: + ckpt._add_item("load_ckpt_folder", None) - if gpc.config.ckpt.async_upload: - assert "save_ckpt_folder" in gpc.config.ckpt - if "boto3:" not in gpc.config.ckpt.save_ckpt_folder: - if gpc.is_rank_for_log(): - logger.warning("Storing ckpt on file system does not support asynchronous storage, will use sync save!") - gpc.config.ckpt.async_upload = False + if "load_optimizer" not in ckpt: + ckpt._add_item("load_optimizer", True) - if "snapshot_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot")) + if "stop_file_path" not in ckpt: + ckpt._add_item("stop_file_path", None) - if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"): - gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2) - assert gpc.config.ckpt.oss_snapshot_freq > 0 + if "load_given_ckpt" not in ckpt: + # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity + # to auto-load latest checkpoint. + ckpt._add_item("load_given_ckpt", False) - assert not ( - gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None - ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time." + if ckpt.load_given_ckpt: + # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder + if ckpt.load_ckpt_folder and ckpt.load_model_only_folder: + logger.warning( + "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ +and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" + ) + ckpt.load_model_only_folder = None if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 - logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}") - logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}") - logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}") - logger.info(f"async_upload: {gpc.config.ckpt.async_upload}") - if gpc.config.ckpt.async_upload: - logger.info(f"async_upload_tmp_folder: {gpc.config.ckpt.async_upload_tmp_folder}") + logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}") + logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}") + logger.info(f"checkpoint_every: {ckpt.checkpoint_every}") + logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}") # initialization storage manager - init_storage_manager(gpc.config.ckpt) + init_storage_manager(ckpt) # tensorboard writer config if "enable_tb" not in gpc.config: gpc.config._add_item("enable_tb", True) if "tensorboard_folder" not in gpc.config: - gpc.config._add_item("tensorboard_folder", None) + gpc.config._add_item( + "tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None + ) if "resume_tb_folder" not in gpc.config: - gpc.config._add_item("resume_tb_folder", None) + gpc.config._add_item( + "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None + ) # cudnn torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False) @@ -236,11 +265,13 @@ def args_sanity_check(): # process the model config if "use_flash_attn" not in gpc.config.model: gpc.config.model._add_item("use_flash_attn", True) - if "sequence_parallel" not in gpc.config.model: - gpc.config.model._add_item("sequence_parallel", False) + + # process the parallel config + if "sequence_parallel" not in gpc.config.parallel: + gpc.config.parallel._add_item("sequence_parallel", False) else: assert not ( - gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False + gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" # feishu webhook address for alerting diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py index d35b9c1..8c59aaf 100644 --- a/internlm/model/embedding.py +++ b/internlm/model/embedding.py @@ -7,6 +7,7 @@ import rotary_emb import torch import torch.nn.functional as F from einops import rearrange +from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_ from torch import Tensor, nn @@ -56,7 +57,7 @@ class Embedding1D(nn.Module): output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1) - if gpc.config.model.sequence_parallel: + if gpc.config.parallel.sequence_parallel: output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1) return output @@ -111,6 +112,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply +legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply class RotaryEmbedding(torch.nn.Module): diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 50b4bf0..32f29f8 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear): weight, self.bias, process_group=self.process_group, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, ) @@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear): weight, self.bias, process_group=self.process_group, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, ) @@ -173,7 +173,7 @@ class FeedForward(nn.Module): hidden_features, process_group, bias, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, device=device, dtype=dtype, ) @@ -182,7 +182,7 @@ class FeedForward(nn.Module): hidden_features, process_group, bias, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, device=device, dtype=dtype, ) @@ -191,7 +191,7 @@ class FeedForward(nn.Module): out_features, process_group, bias=bias, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, device=device, dtype=dtype, ) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index d03bb8f..3783f66 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -393,7 +393,7 @@ class PackedFlashInternLm1D(nn.Module): max_position_embeddings=-1, process_group=gpc.get_group(ParallelMode.TENSOR), padding_idx=None, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, device=device, dtype=dtype, ) diff --git a/internlm/model/moe.py b/internlm/model/moe.py index 75beb14..1504838 100644 --- a/internlm/model/moe.py +++ b/internlm/model/moe.py @@ -74,7 +74,7 @@ class MoE(torch.nn.Module): drop_tokens: bool = True, use_rts: bool = True, using_default_moe: bool = True, - use_residual=True, + use_residual=False, residual_mlp=None, ): diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 2b213ec..48deef5 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -82,7 +82,7 @@ class MHA(nn.Module): 3 * embed_dim, process_group, bias=True, - sequence_parallel=gpc.config.model.sequence_parallel, + sequence_parallel=gpc.config.parallel.sequence_parallel, **factory_kwargs, ) # according to https://spaces.ac.cn/archives/9577 @@ -95,7 +95,11 @@ class MHA(nn.Module): # output projection always have the bias (for now) self.out_proj = RowParallelLinearTorch( - embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs + embed_dim, + embed_dim, + process_group, + sequence_parallel=gpc.config.parallel.sequence_parallel, + **factory_kwargs, ) # need to assign tp attribute so that internlm know it is tensor parallel module if gpc.get_world_size(ParallelMode.TENSOR) > 1: diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py index 1ae68e0..3bd529b 100644 --- a/internlm/moe/sharded_moe.py +++ b/internlm/moe/sharded_moe.py @@ -356,6 +356,8 @@ class TopKGate(Module): # Only top-1 and top-2 are supported at the moment. if k not in (1, 2): raise ValueError("Only top-1 and top-2 gatings are supported.") + # TODO: can we use tensor parallel here? + # Deepspeed's mechisms, alway use fp32 self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.k = k self.capacity_factor = capacity_factor diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 853071c..b1ff4a6 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -3,6 +3,7 @@ import math from functools import partial +from itertools import product import torch import torch.distributed as dist @@ -20,6 +21,7 @@ from internlm.solver.optimizer.store import ( ) from internlm.solver.optimizer.utils import ( DynamicGradScaler, + ParamBcastSyncHandler, flatten, get_grad_accumulate_object, has_inf_or_nan, @@ -88,10 +90,10 @@ class HybridZeroOptimizer(BaseOptimizer): self, optimizer: Optimizer, cpu_offload=False, - overlap_broadcast=False, grad_scal_cfg: Config = None, zero_cfg: Config = None, has_moe: bool = False, + param_bcast_sync_handler: ParamBcastSyncHandler = None, ): # DynamicGradScaler related args if gpc.config.model.dtype is torch.float32: @@ -163,7 +165,9 @@ class HybridZeroOptimizer(BaseOptimizer): + f"zo-{self._zero_local_rank}.pt" ) self.params_per_rank_id_dict = [] - self.overlap_broadcast = overlap_broadcast + self._param_bcast_sync_handler = param_bcast_sync_handler + if self._overlap_communication: + assert self._param_bcast_sync_handler is not None # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -238,6 +242,8 @@ class HybridZeroOptimizer(BaseOptimizer): # communication-computation overlapping if self._overlap_communication: self._comm_stream = torch.cuda.Stream() + else: + self._comm_stream = torch.cuda.current_stream() # reduction hook is only used if overlapping communication # if it is stage 1 without overlapping, no hook will be attached @@ -284,8 +290,10 @@ class HybridZeroOptimizer(BaseOptimizer): global_id = str(i) for j in range(len(param.size())): global_id = "_".join([global_id, str(param.size()[j])]) - - rank_to_go = numel_per_rank.index(min(numel_per_rank)) + if self._overlap_communication: + rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param) + else: + rank_to_go = numel_per_rank.index(min(numel_per_rank)) params_per_rank[rank_to_go].append(param) self.params_per_rank_id_dict[-1][rank_to_go].append(global_id) numel_per_rank[rank_to_go] += param.numel() @@ -322,7 +330,9 @@ class HybridZeroOptimizer(BaseOptimizer): self._grad_store.add_accumulate_grad_object(accum_grad_obj) reduction_func = partial( - self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank + self._store_and_try_reduce_grads_by_bucket, + param=param, + reduce_rank=reduce_rank, ) # define hook @@ -416,16 +426,16 @@ class HybridZeroOptimizer(BaseOptimizer): def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode): if self._overlap_communication: - stream = self._comm_stream - stream.synchronize() + self._comm_stream.synchronize() self._param_store.clear_grads_of_previous_reduced_params() - else: - stream = torch.cuda.current_stream() - with torch.cuda.stream(stream): + with torch.cuda.stream(self._comm_stream): flat = bucket.flatten() reduced_flat = reduce_tensor( - tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode + tensor=flat, + dtype=self.dtype, + dst_rank=reduce_rank, + parallel_mode=dp_parallel_mode, ) # update the reduced tensor @@ -616,7 +626,10 @@ class HybridZeroOptimizer(BaseOptimizer): if found_inf: if gpc.is_rank_for_log(): logger.warning("Overflow occurs, please check it.") - send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.") + send_alert_message( + address=gpc.config.alert_address, + message="Overflow occurs, please check it.", + ) self._grad_store._averaged_gradients = dict() self.zero_grad() return False, None @@ -678,37 +691,42 @@ class HybridZeroOptimizer(BaseOptimizer): fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) - # TODO: support broadcast overlap - self.broadcast_params(overlap=False) + with torch.cuda.stream(self._comm_stream): + self.broadcast_params() timer("step").stop() + # update gradients may not be needed here, because the sync_params function is used in initialization, # so synchronization is maintained return True, [global_norm / loss_scale for global_norm in global_norm_groups] - def broadcast_params(self, overlap=False): + def broadcast_params(self): handles = [] - for group_id in range(self.num_param_groups): + for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)): if self._is_moe_group(self.optim.param_groups[group_id]): continue - for rank in range(self._zero_world_size): - # The following operations are performed only on the rank to which parameters are assigned. - if rank not in self.param_group_no_params_ranks[group_id]: - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) - # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank - # assert grank == rank, f"{grank} == {rank}" - g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank] - handle = dist.broadcast( - fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True - ) - handles.append(handle) + # The following operations are performed only on the rank to which parameters are assigned. + if rank in self.param_group_no_params_ranks[group_id]: + continue + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + # grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank + # assert grank == rank, f"{grank} == {rank}" + g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank] + handle = dist.broadcast( + fp16_param, + src=g_rank, + group=gpc.get_group(ParallelMode.ZERO1), + async_op=True, + ) - if not overlap: - for handle in handles: - handle.wait() - else: - return handles + if self._overlap_communication: + self._param_bcast_sync_handler.add_bcast_handle(rank, handle) + else: + handles.append(handle) + + for handle in handles: + handle.wait() ################## # FP16 Utilities # @@ -726,7 +744,11 @@ class HybridZeroOptimizer(BaseOptimizer): if avg_grad is not None and has_inf_or_nan(avg_grad): self._found_overflow.fill_(1.0) break - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL)) + dist.all_reduce( + self._found_overflow, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.GLOBAL), + ) return self._found_overflow.item() > 0 diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 5a752ef..38e4560 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -3,15 +3,18 @@ import math from abc import ABC, abstractmethod -from typing import Dict, Optional +from collections import OrderedDict +from functools import partial +from typing import Any, Dict, Optional, Union import torch import torch.distributed as dist -from torch import Tensor +from torch import Tensor, nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import NaiveAMPModel from internlm.utils.common import get_tensor_norm, move_norm_to_cuda from internlm.utils.logger import get_logger from internlm.utils.parallel import is_model_parallel_parameter @@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor): def split_half_float_double(tensor_list): - dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] - buckets = [] - for _, dtype in enumerate(dtypes): - bucket = [t for t in tensor_list if t.type() == dtype] - if bucket: - buckets.append(bucket) + dtype_buckets = { + "torch.cuda.HalfTensor": [], + "torch.cuda.FloatTensor": [], + "torch.cuda.DoubleTensor": [], + "torch.cuda.BFloat16Tensor": [], + } + + for t in tensor_list: + dtype = t.type() + if dtype in dtype_buckets: + dtype_buckets[dtype].append(t) + + buckets = [bucket for bucket in dtype_buckets.values() if bucket] return buckets @@ -184,7 +194,10 @@ def calc_l2_norm(grads): if APEX_AVAILABLE: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [grads], + False, # no per-parameter norm ) else: norm, _ = multi_tensor_l2norm_torch(grads, False) @@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no # Take max across all model-parallel GPUs. if gpc.get_world_size(ParallelMode.MODEL) > 1: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL)) + dist.all_reduce( + total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL), + ) total_norm = total_norm_cuda[0].item() else: tensor_parallel_grads = [] @@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no # Sum across all model-parallel GPUs. if gpc.is_initialized(ParallelMode.MODEL): - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL)) + dist.all_reduce( + total_norm, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.MODEL), + ) # This is because we use zero1, so we need to use this reduction. # TODO: Check zero group to be a subset of dp group. @@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler): self._scale = self._scale.fill_(state_dict["_scale"]) self._growth_step = state_dict["_growth_step"] self._hysteresis_step = state_dict["_hysteresis_step"] + + +class ParamBcastSyncHandler: + """ + Model Partition Handler for overlap broadcast with forward + """ + + def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None: + self._block_to_param = OrderedDict() # + self._param_to_rank = dict() # + self._block_to_rank = dict() # + self._bcast_handles = dict() # + + zero1_size = gpc.get_world_size(ParallelMode.ZERO1) + total_param_num = sum(p.numel() for p in model.parameters()) + avg_param_num = total_param_num * 1.0 // zero1_size + + # just want to share same for loop for ModuleList and Module + if not isinstance(model, nn.ModuleList): + model = [model] + + # record the parameters to transformer/embeding/head/norm block + for _chunk in model: + if isinstance(_chunk, NaiveAMPModel): + _chunk = _chunk.model + + for _, children in _chunk.named_children(): + # should be the transformer block definaton in modeling_xxx.py + if isinstance(children, nn.ModuleList): + # record the block that a parameter belongs to + for _, block in enumerate(children): + # self._block_to_param[f"{name}.{idx}"] = list(block.parameters()) + self._block_to_param[block] = list(block.parameters()) + else: + # record the block that a parameter belongs to + # self._block_to_param[name] = list(children.parameters()) + self._block_to_param[children] = list(children.parameters()) + + alloc_num = 0 + rank_to_go = 0 + + # process the parameters in block_to_param sequencially, + # allocate each parameter to a local rank of ParallelMode.ZERO1, + # NOTE that we do NOT consider following scenarios: + # 1) whether a parameter is trainable; + # 2) paramters maybe in different optimizer group + for block, params in self._block_to_param.items(): + # allocate a model block to a local rank of ParallelMode.ZERO1 + self._block_to_rank[block] = [rank_to_go] + for p in params: + alloc_num = alloc_num + p.numel() + # in this case, allocate the param to next rank if possible + if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1: + rank_to_go = rank_to_go + 1 + alloc_num = 0 + self._block_to_rank[block].append(rank_to_go) + # allocate a parameter to a local rank of ParallelMode.ZERO1 + self._param_to_rank[p] = rank_to_go + + # initialize an empty list for _bcast_handles of each rank + for rank in range(gpc.get_world_size(ParallelMode.ZERO1)): + self._bcast_handles[rank] = [] + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + self._register_sync_parameters_hook() + + def _register_sync_parameters_hook(self) -> None: + def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613 + bcast_handles = [] + # gather all required broadcast hanles into a list + for rank in self._block_to_rank[model]: + bcast_handles.extend(self._bcast_handles[rank]) + # need to clear _bcast_handles since they would be processed later + self._bcast_handles[rank] = [] + # wait all required broadcast handles to be completed + for handle in bcast_handles: + handle.wait() + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + for block, _ in self._block_to_rank.items(): + block.register_forward_pre_hook(partial(_pre_forward_hook)) + + def get_rank_by_param(self, param) -> int: + return self._param_to_rank[param] + + def add_bcast_handle(self, rank, handle) -> None: + self._bcast_handles[rank].append(handle) diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py new file mode 100644 index 0000000..7a0eddb --- /dev/null +++ b/internlm/train/__init__.py @@ -0,0 +1,21 @@ +from .training_internlm import ( + get_train_data_loader, + get_validation_data_loader, + initialize_distributed_env, + initialize_llm_profile, + initialize_model, + initialize_optimizer, + load_new_batch, + record_current_batch_training_metrics, +) + +__all__ = [ + "get_train_data_loader", + "get_validation_data_loader", + "initialize_distributed_env", + "initialize_llm_profile", + "initialize_model", + "initialize_optimizer", + "load_new_batch", + "record_current_batch_training_metrics", +] diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py new file mode 100644 index 0000000..3fee265 --- /dev/null +++ b/internlm/train/training_internlm.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import time +from functools import partial +from typing import Callable, Iterable, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import ConcatDataset, DataLoader + +import internlm +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import NaiveAMPModel +from internlm.core.trainer import TrainState +from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader +from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn +from internlm.data.dataset import get_dataset_dict +from internlm.data.dummy_dataset import RandomDataset +from internlm.data.packed_dataset import ( + PackedDataset, + PackedDatasetWithoutCuSeqlen, + get_packed_dataset_without_short_length, +) +from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data +from internlm.model.moe import create_moe_param_groups, has_moe_layers +from internlm.monitor import set_env_var +from internlm.monitor.monitor import monitor_manager as mm +from internlm.solver.beta2_scheduler import Beta2Scheduler +from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR +from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.solver.optimizer.utils import ParamBcastSyncHandler +from internlm.utils.common import DummyProfile, get_master_node +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import ( + is_no_pp_or_last_stage, + sync_model_param_with_ep, + sync_model_param_within_tp, +) +from internlm.utils.registry import MODEL_INITIALIZER + +logger = get_logger(__file__) + + +def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024): + """ + Initialize distributed environment for distributed training. + + Args: + config (str): Config file path. + launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default. + master_port (str): The master port for distributed training. 8888 by default. + seed (int, optional): Specified random seed for every process. 1024 by default. + """ + + torch.cuda.empty_cache() + + if launcher == "torch": + internlm.launch_from_torch(config=config, seed=seed) + elif launcher == "slurm": + internlm.launch_from_slurm( + config=config, + host=get_master_node(), + port=master_port, + seed=seed, + ) + else: + assert launcher in ["slurm", "torch"], "launcher only support slurm or torch" + + +def initialize_model(): + """ + Initialize model. + + Returns: The neural network model to be trained or evaluated. + """ + + model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) + if isinstance(model, nn.ModuleList): + model = nn.ModuleList( + [ + NaiveAMPModel( + model=_m, + output_to_fp32=False, # manually controlled by interleaved pipleline scheduler + dtype=gpc.config.model.get("dtype", torch.half), + sync_buffer=False, + ) + for _m in model + ] + ) + else: + model = NaiveAMPModel( + model=model, + output_to_fp32=is_no_pp_or_last_stage(), + dtype=gpc.config.model.get("dtype", torch.half), + sync_buffer=False, + ) + + # This sync is very important, cause the model weights kept in optimizer are copied + # from the origin parameters in the memory, so we should make sure the dp sync + # does not influence the model weights in optimizer be different with the origin parameters. + sync_model_param_with_ep(model) + + # This function is needed to make sure parameters that are not splitted by tensor parallelism are + # the same across tensor parallelism. + sync_model_param_within_tp(model) + + return model + + +def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): + """ + Initialize optimizer. + + Args: + model (torch.nn.Module): Your model instance to be trained or evaluated. + + Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler). + """ + param_bcast_sync_handler = ParamBcastSyncHandler(model) + adam_cfg = gpc.config.adam + if gpc.config.model.num_experts > 1: + params = create_moe_param_groups(model, adam_cfg.weight_decay) + else: + params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}] + naive_optimizer = torch.optim.AdamW( + params=params, + lr=adam_cfg.lr, + betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), + eps=adam_cfg.adam_eps, + ) + + has_moe = has_moe_layers(model) + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=gpc.config.grad_scaler, + zero_cfg=gpc.config.hybrid_zero_optimizer, + has_moe=has_moe, + param_bcast_sync_handler=param_bcast_sync_handler, + ) + + beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) + + lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) + + return optimizer, beta2_scheduler, lr_scheduler + + +def get_train_data_loader( + num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None +): + """ + Generate and return the training data loader. + + Returns: A tuple of (train_dl, dataset_types). + """ + + # Get the dataset types + dataset_types = None + dataset_types = list(DATASET_TYPE_IDS_MAP.keys()) + data_cfg = gpc.config.data + + # Get the sample weight dictionary + train_folder = data_cfg.train_folder + + if not train_folder: + train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len) + if data_cfg.pack_sample_into_one: + train_ds = PackedDatasetWithoutCuSeqlen( + train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length + ) + else: + train_ds = PackedDataset( + train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length + ) + else: + if dataset_generate_func is not None: + train_ds = dataset_generate_func() + else: + train_ds = get_packed_dataset_without_short_length( + folder=data_cfg.train_folder, + packed_length=data_cfg.packed_length, + max_length_per_sample=data_cfg.seq_len, + show_progress=dist.get_rank() == 0, + min_length=data_cfg.min_length, + min_length_dict=data_cfg.get("min_length_dict", {}), + pack_into_one_sample=data_cfg.pack_sample_into_one, + ) + + if dataset_generate_func is None or not train_folder: + # partition already completed + assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset)) + # Create the training dataset sampler + train_sampler = StaticBatchSampler( + train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds], + batch_size=data_cfg.micro_num, + rampup_batch_size=data_cfg.rampup_batch_size, + micro_bsz=data_cfg.micro_bsz, + seed=1024, + drop_last=True, + data_rank=gpc.get_local_rank(ParallelMode.DATA), + data_world_size=gpc.get_world_size(ParallelMode.DATA), + ) + + if dataset_generate_func is None or not train_folder: + train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length) + + # Create the training data loader + train_dl = DataLoader( + dataset=train_ds, + batch_sampler=train_sampler, + num_workers=num_worker, + pin_memory=True, + collate_fn=train_collate_fn, + persistent_workers=num_worker > 0, + ) + + return train_dl, dataset_types + + +def get_validation_data_loader( + num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None +): + """Generate and return the validation data loader.""" + + data_cfg = gpc.config.data + + if not data_cfg.valid_folder: + val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len) + else: + if dataset_generate_func is not None: + assert val_collate_fn and dataloader_func is not None + val_ds = dataset_generate_func() + else: + val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="") + + if not isinstance(val_ds, dict): + val_ds = {"val": val_ds} + + if val_collate_fn is None or not data_cfg.valid_folder: + val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len) + + val_dls = {} + for val_name, ds in val_ds.items(): + if dataloader_func and data_cfg.valid_folder is not None: + val_dls[val_name] = dataloader_func(dataset=ds, collate_fn=val_collate_fn) + if gpc.is_rank_for_log(): + logger.info( + f"load validation dataset {val_name} with valid batch size {str(data_cfg.valid_micro_num)} and " + f"{ds.size} Byte samples." + ) + else: + # making the batch_size of validate larger can speed up the evaluation, but it should not be too large, + # otherwise too much data may be dropped + batch_size = min( + data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA) + ) + batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz + + if batch_size == 0 and gpc.is_rank_for_log(): + logger.info(f"skip validate {val_name}.") + continue + + val_dls[val_name] = get_dpsampler_dataloader( + ds, + shuffle=False, + num_workers=num_worker, + batch_size=batch_size, + collate_fn=val_collate_fn, + drop_last=True, + ) # drop_last=True, otherwise it may cause problems in the last batch + + if gpc.is_rank_for_log(): + logger.info( + f"load validation dataset {val_name} with valid batch size {str(batch_size)} and " + f"samples {str(len(val_dls[val_name]))}." + ) + + return val_dls + + +def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): + """ + Load and return the new batch data based on training data loader. + + Args: + train_dl (torch.utils.data.DataLoader): Dataloader for training. + train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). + train_state (TrainState): Current training state. + + Returns: A batch data and the updated train_iter. + """ + + timer("batch-gen").start() + try: + batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor) + if hasattr(train_state, "batch_sampler_iter"): + next(train_state.batch_sampler_iter) + except StopIteration: + train_iter = iter(train_dl) + batch = next(train_iter) + train_state.num_consumed_samples_in_epoch = 0 + if hasattr(train_state, "batch_sampler"): + train_state.batch_sampler_iter = iter(train_state.batch_sampler) + next(train_state.batch_sampler_iter) + timer("batch-gen").stop() + + if batch[0].get("type_ids", None) is not None: + # if use_flash_attn is False, we need to unpack type_ids + if not gpc.config.model.use_flash_attn: + batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"]) + + return batch, train_iter + + +def initialize_llm_profile(profiling: bool = False, start_time: str = None): + """Initialize and return the profiler context manager instance.""" + + if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + llm_profile = torch.profiler.profile + logger.info(f"Do profiling in rank {gpc.get_global_rank()}!") + else: + llm_profile = DummyProfile + + return llm_profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_" + + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_" + + f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}", + ), + with_stack=True, + with_modules=True, + ) + + +def record_current_batch_training_metrics( + get_tflops_func, + logger, + writer, + success_update, + batch_count, + batch, + train_state, + optimizer, + beta2_scheduler, + trainer, + start_time, + loss, + moe_loss, + grad_norm, + metric, + update_panel, +): + """ + Print some training metrics of current batch. + """ + + set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time())) + + if success_update in (0, True): + train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) + if is_no_pp_or_last_stage(): + acc_perplex = metric.get_metric() + + if success_update and gpc.is_rank_for_log(): + lr = optimizer.param_groups[0]["lr"] + if hasattr(trainer.engine.optimizer, "grad_scaler"): + scaler = trainer.engine.optimizer.grad_scaler._scale.item() + elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"): + scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item() + + num_tokens_in_batch = batch[1].nelement() + num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) + max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + + tk_per_gpu = 0 + tk_per_gpu = round( + num_tokens_in_batch + * gpc.get_world_size(ParallelMode.DATA) + / gpc.get_world_size(ParallelMode.GLOBAL) + / (time.time() - start_time), + 2, + ) + + tflops = get_tflops_func((time.time() - start_time)) + + infos = { + "tflops": tflops, + "step": batch_count, + "loss": loss.item(), + "moe_loss": moe_loss.item(), + "tgs (tokens/gpu/second)": tk_per_gpu, + "lr": lr, + "loss_scale": scaler, + "grad_norm": grad_norm, + } + + infos["micro_num"] = len(batch[1]) + infos["num_consumed_tokens"] = train_state.num_consumed_tokens + infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches + infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples + infos["largest_length"] = max_length_in_batch # the longest input + infos["largest_batch"] = max_samples_in_batch # the batch with the most samples + infos["smallest_batch"] = min_samples_in_batch + infos["adam_beta2"] = beta2_scheduler.get_beta2() + + fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) + infos["fwd_bwd_time"] = fwd_bwd_time + + for key, value in acc_perplex.items(): + infos[key] = value + + line = "" + for key, value in infos.items(): + line += f"{key}={value} " + writer.add_scalar(key=key, value=value, step=train_state.step_count) + + if update_panel: + logger.info( + line, + extra={ + "step": batch_count, + "lr": lr, + "num_consumed_tokens": train_state.num_consumed_tokens, + "grad_norm": grad_norm, + "loss": loss.item(), + "moe_loss": moe_loss.item(), + "flops": tflops, + "tgs": tk_per_gpu, + "acc": acc_perplex["acc"], + "perplexity": acc_perplex["perplexity"], + "fwd_bwd_time": fwd_bwd_time, + }, + ) + else: + logger.info(line) + + # if loss spike occurs, send alert info to feishu + mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item()) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index d479284..f3b58c0 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -218,3 +218,21 @@ def get_megatron_flops( tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12)) return tflops + + +class DummyProfile: + """ + Dummy Profile. + """ + + def __init__(self, *args, **kwargs) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, a, b, c): + pass + + def step(self): + pass diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index d10f0c1..a5c8607 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -50,6 +50,16 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape trainer.schedule._hooks = prev_metric_hooks +@contextmanager +def switch_sequence_parallel_mode(): + prev_mode = gpc.config.parallel.sequence_parallel + try: + gpc.config.parallel.sequence_parallel = False + yield + finally: + gpc.config.parallel.sequence_parallel = prev_mode + + def evaluate_on_val_dls( trainer, val_dls, @@ -57,110 +67,102 @@ def evaluate_on_val_dls( logger, step_count, update_panel: bool = False, + streaming: bool = False, ): - torch.cuda.empty_cache() - trainer.eval() - verbose = gpc.is_rank_for_log() - data_cfg = gpc.config.data + with switch_sequence_parallel_mode(): + torch.cuda.empty_cache() + trainer.eval() + verbose = gpc.is_rank_for_log() + data_cfg = gpc.config.data - for val_name, val_dl in val_dls.items(): - if len(val_dl) == 0 and verbose: - logger.info(f"Validation dataset: {val_name} is empty") - continue + for val_name, val_dl in val_dls.items(): + if len(val_dl) == 0 and verbose and not streaming: + logger.info(f"Validation dataset: {val_name} is empty") + continue - val_metric = AccPerplex( - device=torch.cuda.current_device(), - tp_pg=gpc.get_group(ParallelMode.TENSOR), - dp_pg=gpc.get_group(ParallelMode.DATA), - ) - val_sche_metric_hook = SchedulerMetricHook(metric=val_metric) + val_metric = AccPerplex( + device=torch.cuda.current_device(), + tp_pg=gpc.get_group(ParallelMode.TENSOR), + dp_pg=gpc.get_group(ParallelMode.DATA), + ) + val_sche_metric_hook = SchedulerMetricHook(metric=val_metric) - val_loss = 0 - val_idx = -1 - for val_idx, batch in tqdm( - enumerate(val_dl), - desc="Val.", - total=len(val_dl), - position=1, - disable=not verbose, - leave=False, - ): - with torch.inference_mode(): - if gpc.is_using_pp(): - total_val_bsz = len(batch[1]) - assert total_val_bsz % data_cfg.micro_bsz == 0 - num_microbatches = total_val_bsz // data_cfg.micro_bsz - tensor_shape = torch.Size( - [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] + val_loss = 0 + val_idx = -1 + for val_idx, batch in tqdm( + enumerate(val_dl), + desc="Val.", + total=len(val_dl) if not streaming else None, + position=1, + disable=not verbose, + leave=False, + ): + with torch.inference_mode(): + if gpc.is_using_pp(): + total_val_bsz = len(batch[1]) + assert total_val_bsz % data_cfg.micro_bsz == 0 + num_microbatches = total_val_bsz // data_cfg.micro_bsz + tensor_shape = torch.Size( + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] + ) + + with switch_evaluation_pipeline_scheduler( + trainer=trainer, + num_microbatches=num_microbatches, + tensor_shape=tensor_shape, + metric_hook_list=[val_sche_metric_hook], + ): + _, _, loss, _ = trainer.execute_schedule( + batch, forward_only=True, return_loss=True, return_output_label=False + ) + else: + total_val_bsz = len(batch[1]) + assert total_val_bsz % data_cfg.micro_bsz == 0 + grad_accum_size = total_val_bsz // data_cfg.micro_bsz + grad_accum_batch_size = data_cfg.micro_bsz + with switch_evaluation_no_pipeline_scheduler( + trainer=trainer, + grad_accum_size=grad_accum_size, + grad_accum_batch_size=grad_accum_batch_size, + metric_hook_list=[val_sche_metric_hook], + ): + _, _, loss, _ = trainer.execute_schedule( + batch, forward_only=True, return_loss=True, return_output_label=False + ) + if verbose: + val_loss += loss.item() + + assert val_idx != -1 + dist.barrier() + + val_res = val_metric.get_metric() + if verbose and len(val_dl) != 0: + val_loss = val_loss / (val_idx + 1 + 1e-6) + infos = { + "step": step_count, + f"val/{val_name}_loss": val_loss, + f"val/{val_name}_acc": val_res["acc"], + f"val/{val_name}_plex": val_res["perplexity"], + } + + for key, value in infos.items(): + writer.add_scalar(key=key, value=value, step=step_count) + + if update_panel: + logger.info( + f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]), + extra={ + "step": step_count, + "val_loss": val_loss, + "val_acc": val_res["acc"], + "val_perplexity": val_res["perplexity"], + }, + ) + else: + logger.info( + f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]) ) - with switch_evaluation_pipeline_scheduler( - trainer=trainer, - num_microbatches=num_microbatches, - tensor_shape=tensor_shape, - metric_hook_list=[val_sche_metric_hook], - ): - _, _, loss = trainer.execute_schedule( - batch, forward_only=True, return_loss=True, return_output_label=False - ) - else: - total_val_bsz = len(batch[1]) - assert total_val_bsz % data_cfg.micro_bsz == 0 - grad_accum_size = total_val_bsz // data_cfg.micro_bsz - grad_accum_batch_size = data_cfg.micro_bsz - with switch_evaluation_no_pipeline_scheduler( - trainer=trainer, - grad_accum_size=grad_accum_size, - grad_accum_batch_size=grad_accum_batch_size, - metric_hook_list=[val_sche_metric_hook], - ): - _, _, loss = trainer.execute_schedule( - batch, forward_only=True, return_loss=True, return_output_label=False - ) - if verbose: - val_loss += loss.item() - - assert val_idx != -1 + trainer.train() + torch.cuda.empty_cache() dist.barrier() - - val_res = val_metric.get_metric() - if verbose and len(val_dl) != 0: - val_loss = val_loss / (val_idx + 1 + 1e-6) - infos = { - "step": step_count, - f"val/{val_name}_loss": val_loss, - f"val/{val_name}_acc": val_res["acc"], - f"val/{val_name}_plex": val_res["perplexity"], - } - - for key, value in infos.items(): - writer.add_scalar(key=key, value=value, step=step_count) - - if update_panel: - logger.info( - f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]), - extra={ - "step": step_count, - "val_loss": val_loss, - "val_acc": val_res["acc"], - "val_perplexity": val_res["perplexity"], - }, - ) - else: - logger.info( - f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]) - ) - - trainer.train() - torch.cuda.empty_cache() - dist.barrier() - - -@contextmanager -def switch_sequence_parallel_mode(): - prev_mode = gpc.config.model.sequence_parallel - try: - gpc.config.model.sequence_parallel = False - yield - finally: - gpc.config.model.sequence_parallel = prev_mode diff --git a/internlm/utils/megatron_timers.py b/internlm/utils/megatron_timers.py index 6c4ed11..e319a80 100644 --- a/internlm/utils/megatron_timers.py +++ b/internlm/utils/megatron_timers.py @@ -14,18 +14,19 @@ class _Timer: self.elapsed_ = 0.0 self.started_ = False self.start_time = time.time() + self.stream = torch.cuda.current_stream() def start(self): """Start the timer.""" assert not self.started_, "timer has already been started" - torch.cuda.synchronize() + self.stream.synchronize() self.start_time = time.time() self.started_ = True def stop(self): """Stop the timer.""" assert self.started_, "timer is not started" - torch.cuda.synchronize() + self.stream.synchronize() self.elapsed_ += time.time() - self.start_time self.started_ = False diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 3fe29cc..2c7a8f4 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -2,8 +2,12 @@ # -*- encoding: utf-8 -*- import copy +import fcntl import os +import re +import socket import time +from collections import defaultdict from enum import Enum from typing import Dict @@ -12,6 +16,8 @@ import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.model.moe import MoE +from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -25,8 +31,6 @@ from internlm.utils.storage_manager import ( logger = get_logger(__file__) -quit_signal_handler = None - class CheckpointType(Enum): NORMAL_CHECKPOINT = 1 @@ -69,6 +73,8 @@ def save_model_checkpoint(folder, model): """ states = model.state_dict() + # get non-moe parameters + states = get_non_moe_state_dict(states) topo = get_model_topology(model) if folder is not None: @@ -92,6 +98,9 @@ def save_model_checkpoint(folder, model): topo_fp = os.path.join(folder, topo_fn) llm_save(topo_fp, saved_obj=topo) + # move the judgement logic into save_moe_checkpoint(.) + try_save_moe_checkpoint(folder, model) + torch.distributed.barrier() @@ -128,6 +137,18 @@ def load_model_checkpoint(folder, model): fp = os.path.join(folder, should_load_name) states = llm_load(fp, map_location=get_current_device()) + """ + # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in + # gate.weight. The conversion will also be done when doing forward. so we can just comment it out. this make + # the gate parameters to be float16 before forward. + for key in list(states.keys()): + if 'moe_layer.gate.wg.weight' in key: + states[key] = states[key].float() + print("load: ", states[key].float(),flush=True) + """ + + try_load_moe_checkpoint(folder, model, states) + missing_k, unexpected_keys = model.load_state_dict(states, strict=False) if len(missing_k) != 0: logger.warning(f"Warning: missing keys {missing_k}") @@ -139,6 +160,58 @@ def load_model_checkpoint(folder, model): torch.cuda.empty_cache() +def try_save_moe_checkpoint(folder, model): + # Using layer_#_expert_# to save the model's expert state_dict,a hack. + moe_layer_id = 0 + for n_module, module in model.named_modules(): + if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0: + num_local_experts = module.num_local_experts + expp_rank = gpc.get_local_rank(ParallelMode.EXPERT) + + # get all moe parameters + moe_state_dict = {} + for n, p in module.state_dict().items(): + if "expert" in n and "moe_layer.gate.wg.weight" not in n: + moe_state_dict[n_module + "." + n] = p + moe_str_prefix = ".moe_layer.experts.experts." + # Reorder the moe name rank, so that each checkpoint only has one expert + experts_state_dict = defaultdict(dict) + for key in list(moe_state_dict.keys()): + m = re.match(f".*{moe_str_prefix}([0-9]+).*", key) + + local_expert_id = None + if not m: + logger.warning(f"No expert found in key {key}.") + else: + local_expert_id = m.group(1) + + global_expert_id = expp_rank * num_local_experts + int(local_expert_id) + expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}") + + # truncating extra tensor (shared) storage + truncated = moe_state_dict.pop(key).clone().detach() + experts_state_dict[str(global_expert_id)][expert_key] = truncated + + # let save the moe parameters + for global_expert_id, expert_state_dict in experts_state_dict.items(): + # save the moe parameters + fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt" + fp = os.path.join(folder, fn) + llm_save(fp, saved_obj=expert_state_dict) + moe_layer_id += 1 + + +def get_non_moe_state_dict(full_state_dict): + """ + Get the state dict of the non-moe layers + """ + for key in list(full_state_dict.keys()): + if "expert" in key and "moe_layer.gate.wg.weight" not in key: + full_state_dict.pop(key) + + return full_state_dict + + def save_optimizer_checkpoint(optim, state_path): """Store the state of the optimizer to the local file system or remote OSS. @@ -167,42 +240,25 @@ def save_optimizer_checkpoint(optim, state_path): llm_save(os.path.join(state_path, fp), states) -def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): - """ - Save checkpoint to the given folder path. - """ - - start = time.time() - torch.distributed.barrier() - folder = os.path.join(folder, str(train_state.step_count)) - logger.info( - f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..." - ) - - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() - - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() - - if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) - - sampler_state = train_state.batch_sampler.state_dict() - llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) - llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) - - if model_config is not None: - llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) - - torch.distributed.barrier() - - if gpc.is_rank_for_log(): - timer.log(["save-model", "save-optimizer"], logger=logger) - logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") +def try_load_moe_checkpoint(folder, model, state_dict): + moe_layer_id = 0 + for _, module in model.named_modules(): + if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0: + num_local_experts = module.num_local_experts + expp_rank = gpc.get_local_rank(ParallelMode.EXPERT) + # loop all local_experts + for local_expert_id in range(num_local_experts): + global_expert_id = expp_rank * num_local_experts + local_expert_id + fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt" + fp = os.path.join(folder, fn) + expert_state_dict = llm_load(fp, map_location=get_current_device()) + # Updating global -> local expert ids + moe_str_prefix = ".moe_layer.experts.experts." + for key in list(expert_state_dict.keys()): + local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}") + expert_state_dict[local_key] = expert_state_dict.pop(key) + state_dict.update(expert_state_dict) + moe_layer_id += 1 def load_optimizer_checkpoint(folder, optim): @@ -304,19 +360,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train logger.info(f"reload load_scheduler:{lr_scheduler}") -class CheckpointSaveManager: +class CheckpointManager: """StorageManagerContext""" - def __init__( - self, - ckpt_config, - model, - optimizer, - lr_scheduler, - model_config, - ) -> None: + def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None: """ - CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous + CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait for the asynchronous ckpt upload to complete. @@ -332,26 +381,95 @@ class CheckpointSaveManager: self.save_ckpt_folder = ckpt_config.save_ckpt_folder self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq + self.stop_file_path = ckpt_config.stop_file_path + self.load_model_only_folder = ckpt_config.load_model_only_folder + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 + self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.model_config = model_config + if self.stop_file_path and gpc.get_global_rank() == 0: + dir_path = os.path.dirname(self.stop_file_path) + if dir_path != "" and not os.path.exists(dir_path): + os.makedirs(dir_path) + with open(self.stop_file_path, "w", encoding="utf-8") as f: + f.write("0") + + if ckpt_config.load_given_ckpt is False: + # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder + latest_ckpt_path = self.query_lastest_ckpt() + if latest_ckpt_path: + self.load_ckpt_folder = latest_ckpt_path + else: + # At this time, we have to load model init weights and train from step 0. + self.load_ckpt_folder = self.load_model_only_folder + else: + self.load_ckpt_folder = ckpt_config.load_ckpt_folder + + if gpc.is_rank_for_log(): + logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'") + if self.stop_file_path is None: + logger.warning("no set stop_file_path, quit_signal_handler is disable") + + def quit_signal_handler(self, train_state) -> bool: + """ + Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, + all ranks will save ckpt and exit. + Negative integer step means save ckpt. + Positive integer step means save ckpt and quit. + + Args: + train_state (TrainState): + Returns: + bool: whether to quit. + """ + now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + + if self.stop_file_path is None: + return now_break, now_save_ckpt, save_type + + with open(self.stop_file_path, "a+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0) + msg = f.read() + fcntl.flock(f, fcntl.LOCK_UN) + action_step = int(msg) + + if action_step < 0 and abs(action_step) == train_state.step_count: + now_save_ckpt = True + + if action_step > 0 and action_step == train_state.step_count: + now_break, now_save_ckpt = True, True + + if action_step != 0 and gpc.is_rank_for_log(): + msg = "Stop" if action_step > 0 else "Save" + action_step = abs(action_step) + if train_state.step_count <= action_step: + if self.feishu_address: + send_alert_message( + address=self.feishu_address, + message=f"training will {msg} at step_count {action_step}!\ +now step_count is {train_state.step_count}", + ) + + return now_break, now_save_ckpt, save_type + def try_save_checkpoint(self, train_state): if not self.enable_save_ckpt: - return + return False save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: - if quit_signal_handler is not None: - save_ckpts, save_type = quit_signal_handler(train_state) + save_ckpts = singal_save_ckpts + save_type = singal_save_type if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. @@ -361,9 +479,9 @@ class CheckpointSaveManager: self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") else: - save_ckpt_folder = self.save_ckpt_folder + save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count)) - save_checkpoint( + self.save_checkpoint( folder=save_ckpt_folder, model=self.model, optimizer=self.optimizer, @@ -372,7 +490,221 @@ class CheckpointSaveManager: model_config=self.model_config, ) + return now_break + def wait_async_upload_finish(self): """wait for all checkpoint uploads to be completed""" self.storage_manager.wait() torch.distributed.barrier() + + def query_latest_snapshot_step_boto3(self): + """query_latest_snapshot_step_boto3 + Returns: + Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. + """ + ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) + if len(ckpt_list) == 0: + return None, None + + max_normal_step = 0 + ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list)) + ckpt_list.sort(reverse=True) + for ckpt in ckpt_list: + fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) + for fn in fns_list: + if fn.endswith(".step"): + max_normal_step = ckpt + break + if max_normal_step != 0: + break + + max_normal_step = ckpt_list[0] + load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step)) + + snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0") + snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1") + ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) + ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) + max_step_0, max_step_1 = 0, 0 + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + + snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 + snap_step = max(max_step_0, max_step_1) + load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path + load_step = max(snap_step, max_normal_step) + return load_path, load_step + + def query_latest_snapshot_step_local(self): + max_step, max_step_path = 0, None + for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + for fn in files: + fn = fn.strip("/") + if fn.endswith(".step"): + # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # as an integrity flag. + step = int(fn.rsplit(".", maxsplit=1)[0]) + if max_step < step: + max_step = step + max_step_path = root + + return max_step_path, max_step + + def query_lastest_ckpt(self): + latest_checkpoint = None + # Training was automatically restarted by the process, forcing the latest snapshot to be read. + if self.save_ckpt_folder: + if self.save_ckpt_folder.startswith("boto3"): + latest_checkpoint, step = self.query_latest_snapshot_step_boto3() + elif self.save_ckpt_folder.startswith("local"): + latest_checkpoint, step = self.query_latest_snapshot_step_local() + else: + latest_checkpoint, step = None, 0 + + if latest_checkpoint is not None: + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") + send_alert_message( + address=self.feishu_address, + message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", + ) + else: + if gpc.is_rank_for_log(): + send_alert_message( + address=self.feishu_address, + message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", + ) + + return latest_checkpoint + + def try_load_model(self, current_time=""): + model_load_path = None + + if self.load_ckpt_folder and self.load_model_only_folder: + raise ValueError( + "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ +if you only need to load model weights (for example starting an SFT task for the first time), \ +set load_model_only_folder path, if you need to resume training from ckpt, \ +set load_ckpt_folder or use default value \ +(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" + ) + + if self.load_ckpt_folder: + if gpc.is_rank_for_log(): + logger.info( + f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_ckpt_folder + elif self.load_model_only_folder: + if gpc.is_rank_for_log(): + logger.info( + f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_model_only_folder + else: + if gpc.is_rank_for_log(): + logger.info( + f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," + f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," + f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" + ) + + # Loading model weights must be done before zero is initialized. + if model_load_path is not None: + load_model_checkpoint(folder=model_load_path, model=self.model) + + def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl): + """Attempt to restore the training state of the last ckpt. + + Args: + lr_scheduler (_LRScheduler): lr_scheduler object. + optimizer (Optimizer): optimizer object. + lr (float): learning rate. + train_state (dict): traing states. + train_dl (DataLoader): traning dataloader object + """ + if self.load_ckpt_folder is not None: + # load optimzier states. + if self.load_optimizer: + load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) + # load lr scheduler states. + load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) + # load training states. + load_context(self.load_ckpt_folder, train_dl, train_state) + # load dataloader sampler states. + if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) + if hasattr(train_state, "data_state_dict"): + train_dl.dataset.load_state_dict( + llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder + ) + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): + """ + Save checkpoint to the given folder path. + """ + + start = time.time() + self.set_save_folder(folder, train_state.step_count) + torch.cuda.synchronize() + torch.distributed.barrier() + if gpc.is_rank_for_log(): + logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") + + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() + + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + + if ( + hasattr(train_state, "data_state_dict") + and gpc.get_local_rank(ParallelMode.TENSOR) == 0 + and gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + ): + llm_save( + os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"), + saved_obj=train_state.data_state_dict, + ) + + if gpc.is_rank_for_log(): + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + sampler_state = train_state.batch_sampler.state_dict() + llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) + llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) + + if model_config is not None: + llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) + + torch.distributed.barrier() + + if gpc.is_rank_for_log(): + timer.log(["save-model", "save-optimizer"], logger=logger) + logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") + if self.storage_manager.async_mode is False: + llm_save( + os.path.join(folder, f"{train_state.step_count}.step"), + saved_obj=dict({"step": train_state.step_count}), + ) + + def set_save_folder(self, folder, step): + self.storage_manager.latest_save_folder = folder + self.storage_manager.latest_save_step = step diff --git a/internlm/utils/simple_memory_profiler.py b/internlm/utils/simple_memory_profiler.py index 4ca6679..9caf0a2 100644 --- a/internlm/utils/simple_memory_profiler.py +++ b/internlm/utils/simple_memory_profiler.py @@ -1,15 +1,13 @@ import os import time from collections import OrderedDict -from functools import partial +from functools import partial, reduce from typing import Any, Dict, List, Tuple import pyecharts import torch -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.solver.pipeline_utils import partition_uniform +from internlm.core.naive_amp import NaiveAMPModel mb = 1024 * 1024 @@ -107,6 +105,8 @@ class SimpleMemState: """ Update the total memory usage of the model and sub-models. """ + self._total_mem = self._layer_mem + for stat in self.sub_model_stats.values(): # Update sub-model status first. stat.update_total_memory() @@ -169,6 +169,39 @@ class SimpleMemState: return {"name": self.layer_name, "children": children} +class ActivationMemState: + """ + Activation Memory State + """ + + def __init__(self, num_chunks: int) -> None: + self._num_chunks = num_chunks + + self.inited: List[bool] = [False for _ in range(num_chunks)] + self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)] + + @property + def total_mem(self) -> int: + return sum(state.total_mem for state in self.states) + + def dump(self, prefix: str = "") -> str: + return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states]) + + def to_json(self, base: int = 1024 * 1024) -> List: + return [state.to_json(base) for state in self.states] + + +def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]: + num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1 + + if num_chunks > 1: + model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model]) + else: + model = model.model if isinstance(model, NaiveAMPModel) else model + + return model, num_chunks + + class SimpleMemoryProfiler: """ A memory profiler for a llm model. @@ -177,7 +210,7 @@ class SimpleMemoryProfiler: model (torch.nn.Module): The model to profile. optimizer (torch.optim.Optimizer): The optimizer used for training the model. log_file (str): The file to write the memory state information to. - activation_config (List[str], optional): The list of activation layers to track. Defaults to None. + total_steps: number of steps to trace. """ def __init__( @@ -186,9 +219,8 @@ class SimpleMemoryProfiler: optimizer: torch.optim.Optimizer, log_folder: str, total_steps: int = 5, - activation_config: List[str] = None, ): - self._model = model + self._model, self._num_model_chunks = _unpack_naive_wrapper(model) self._optimizer = optimizer self._log_folder = log_folder self._remaining_steps = total_steps @@ -197,17 +229,20 @@ class SimpleMemoryProfiler: self._record_start_time = time.time() # For activation memory state. - self._activation_config = activation_config - self._activation_mem_inited: bool = False + self._activation_mem: int = 0 - self._activation_max_count = 0 - self._activation_base_mem: SimpleMemState = SimpleMemState("activations") + self._activation_mem_max: int = 0 + self._activation_base_mems = ActivationMemState(self._num_model_chunks) # Check or create log folder os.makedirs(self._log_folder, exist_ok=True) # Register activation memory tracking hooks - self._register_activation_trace_hooks() + if self._num_model_chunks > 1: + for chunk_id in range(self._num_model_chunks): + self._register_activation_trace_hooks(chunk_id, self._model[chunk_id]) + else: + self._register_activation_trace_hooks(0, self._model) # Calculate static parameter cuda memory self._param_mem_state = SimpleMemState("param_mem") @@ -221,7 +256,7 @@ class SimpleMemoryProfiler: self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups))) # Generate the first memory record - self.point(create=True) + self.point(with_options="params,grads,os_params", create=True) def point(self, with_options: str = "", create: bool = False) -> None: """ @@ -272,7 +307,7 @@ class SimpleMemoryProfiler: if "os_state" in options: layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump() if "activation_base" in options: - layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump() + layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump() # Write memory state information to log file file_mode = "w" if create else "a" @@ -315,14 +350,14 @@ class SimpleMemoryProfiler: [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()], "os_memory_sunburst", ) - self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst") + self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst") # Generate summary sunburst chart summary_sunburst_data = [ {"name": "params", "value": self._param_mem_state.total_mem // mb}, {"name": "grads", "value": self._grad_mem_state.total_mem // mb}, {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb}, {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb}, - {"name": "activation", "value": self._activation_base_mem.total_mem // mb}, + {"name": "activation", "value": self._activation_mem_max // mb}, ] self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst") @@ -337,12 +372,13 @@ class SimpleMemoryProfiler: {}, { "r0": "10%", - "r": "40%", + "r": "35%", "itemStyle": {"borderWidth": 3}, "label": {"align": "left"}, }, - {"r0": "40%", "r": "65%", "label": {"align": "left"}}, - {"r0": "65%", "r": "80%", "label": {"align": "left"}}, + {"r0": "35%", "r": "55%", "label": {"align": "left"}}, + {"r0": "55%", "r": "70%", "label": {"align": "left"}}, + {"r0": "70%", "r": "80%", "label": {"align": "left"}}, {"r0": "80%", "r": "90%", "label": {"align": "left"}}, { "r0": "90%", @@ -357,7 +393,14 @@ class SimpleMemoryProfiler: f"{self._log_folder}/{name}.html" ) - def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None: + def _inner_activation_trace_hook( + self, + chunk_id: int, + layer_name: str, + model: Any, + inputs: Any, + output: torch.Tensor, + ) -> None: """ Hook function to trace the activation memory usage for a inner layer. @@ -373,13 +416,15 @@ class SimpleMemoryProfiler: del model, inputs assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}" - if self._stoped or self._activation_mem_inited: + if self._stoped or self._activation_base_mems.inited[chunk_id]: return # Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook. - self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False) + self._activation_base_mems.states[chunk_id].add( + layer_name, output.element_size() * output.nelement(), flush=False + ) - def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None: + def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None: """ Hook function to trace the activation memory usage for a forward pass. @@ -398,23 +443,24 @@ class SimpleMemoryProfiler: return # Check if the activation memory has been initialized - if self._activation_mem_inited is False: + if self._activation_base_mems.inited[chunk_id] is False: + self._activation_base_mems.inited[chunk_id] = True # Update the total memory of the activation base memory state - self._activation_base_mem.update_total_memory() + self._activation_base_mems.states[chunk_id].update_total_memory() # Set with_options to "activation_base" to include activation_base_layout in the memory dump - self._activation_mem_inited = True + with_options = "activation_base" + else: + with_options = "" # Accumulate activation memory usage for each forward pass - self._activation_mem += self._activation_base_mem.total_mem - - # Update activation max count - if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count: - self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem + self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem + if self._activation_mem > self._activation_mem_max: + self._activation_mem_max = self._activation_mem # Trigger a memory record - self.point() + self.point(with_options) - def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None: + def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None: """ Hook function to trace the activation memory usage for a backward pass. @@ -432,37 +478,28 @@ class SimpleMemoryProfiler: return # Release activation memory usage for each backward pass - self._activation_mem -= self._activation_base_mem.total_mem + self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem # Trigger a memory record self.point() - def _register_activation_trace_hooks(self) -> None: + def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None: """ Register activation trace hooks for the model and each submodule in the model. """ # Register inner activation trace hooks for each submodule in the model - for layer_name in self._activation_config: - # Register a hook for every activation - model = self._model - sub_models = layer_name.split(".") - # Get the target sub-model - for sub_model_name in sub_models: - try: - model = model.get_submodule(sub_model_name) - except AttributeError: - model = None - break - + for layer_name, sub_model in model_chunk.named_modules(): # Register the hook - if model is not None: - model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name)) + if len(sub_model._modules) != 0: + continue # TODO: in some special cases, we may need some additional configuration to correct + + sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name)) # Register a forward hook for the main model to track activation memory usage - self._model.register_forward_hook(self._activation_trace_hook_forward) + model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id)) # Register a backward hook for the main model to release activation memory usage - self._model.register_full_backward_hook(self._activation_tarce_hook_backward) + model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id)) def _calc_tensor_memory( self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False @@ -554,48 +591,6 @@ class SimpleMemoryProfiler: self._calc_tensor_memory(root_stat, named_tensors) -def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]: - # TODO: support interleaved pipeline scheduling. - assert num_chunks == 1, "Only support num_chunks == 1" - - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - start, end = parts[0] - num_blocks = end - start - - block_conf_tmpl = [ - "mixer.rotary_emb", - "mixer.Wqkv", - "mixer.inner_attn", - "mixer.inner_cross_attn", - "mixer.out_proj", - # "dropout1", # skip when dropout_selective_checkpoint is True - # "dropout2", # skip when dropout_selective_checkpoint is True - "norm1", - "norm2", - "mlp.w1", - "mlp.w2", - "mlp.w3", - ] - - block_conf = [] - for block_id in range(num_blocks): - block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl] - - # We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning. - # If they don't exist, they will be automatically ignored when registering activation trace hooks. - activation_conf = ["embedding", "norm", "head"] + block_conf - - return activation_conf - - if __name__ == "__main__": class SimpleModel(torch.nn.Module): @@ -635,32 +630,39 @@ if __name__ == "__main__": return output + def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor: + if _num_chunks > 1: + _output = _input + for _model_chunk in _model_chunks: + _output = _model_chunk(_output) + else: + _output = _model_chunks(_input) + + return _output + + # num_chunks config + _num_chunks = 1 + # init model and optimizer - _model: torch.nn.Module = SimpleModel() + if _num_chunks > 1: + _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)] + _model = torch.nn.ModuleList(_chunks).cuda() + else: + _model: torch.nn.Module = SimpleModel().cuda() _optimizer = torch.optim.Adam(_model.parameters()) - # create activation config for simple model layer by layer. - activation_configs = [ - # model level 0 - "layer1", - "layer2", - "layer3", - # model level 1 - "layer2.layer1", - "layer2.layer3", - ] - - _model.modules() - # init profiler - profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs) + profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1) _optimizer.zero_grad() - x1 = torch.randn((128, 5120)) - x2 = torch.randn((128, 5120)) - out1 = _model(x1) - out2 = _model(x2) + # inputs + x1 = torch.randn((128, 5120)).cuda() + x2 = torch.randn((128, 5120)).cuda() + # forward + out1 = _simple_schedule(_num_chunks, _model, x1) + out2 = _simple_schedule(_num_chunks, _model, x2) + # backward out1.mean().backward() out2.mean().backward() diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index c9b42ea..c7b71f4 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -15,8 +15,6 @@ from asyncio.tasks import ALL_COMPLETED from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Union -import boto3 -import botocore import torch import torch.distributed as dist @@ -24,6 +22,13 @@ from internlm.core.context import global_context as gpc from internlm.utils.common import SingletonMeta from internlm.utils.logger import get_logger +try: + import boto3 + import botocore +except ImportError: + pass + + logger = get_logger(__file__) boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)") @@ -234,13 +239,13 @@ class Boto3Client(StorageClient): """ paginator = handler.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] for page in pages: - for obj in page["Contents"]: - fp: str = obj["Key"] - folder_name_list.append(fp.rsplit("/", maxsplit=1)[1]) - return folder_name_list + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -391,6 +396,11 @@ class StorageManager(metaclass=SingletonMeta): self.tmp_local_folder = tmp_local_folder self.async_mode = async_mode self.has_warning = False + self._async_loop = None + self._thread_pool = None + self.latest_save_folder = None + self.latest_save_step = 0 + self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() @@ -485,6 +495,7 @@ class StorageManager(metaclass=SingletonMeta): torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + self.async_task_peeding = True else: meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) self.upload_count += 1 @@ -523,23 +534,22 @@ class StorageManager(metaclass=SingletonMeta): pass async def _sync_tasks(self) -> Awaitable[None]: - if not self._async_stack: - return - - await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED) - - for task in self._async_stack: - try: - task.exception() - except InvalidStateError: - continue - except Exception as e: - file_id = len(self._exception_list) - self._exception_list.append((e, file_id)) - - logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}") - - self._async_stack.clear() + if self._async_stack: + await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED) + count = 0 + while self._async_stack: + t = self._async_stack[0] + try: + e = t.exception() + if e: + self._exception_list.append((e, count)) + logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}") + # raise e + count += 1 + self._async_stack.pop(0) + except InvalidStateError: + # Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception + pass def async_executor(self, fn: Callable, *args, **kwargs) -> None: """ @@ -559,11 +569,14 @@ class StorageManager(metaclass=SingletonMeta): if not self.async_mode: return + if not self.async_task_peeding: + return + if self._async_loop: self._async_loop.run_until_complete(self._sync_tasks()) if self._exception_list: - for file_id, error_msg in self._exception_list: + for error_msg, file_id in self._exception_list: logger.error( f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} " f"failed on step {self.upload_count}: {error_msg}" @@ -577,10 +590,16 @@ class StorageManager(metaclass=SingletonMeta): self._del_tmp_folder() self._exception_list.clear() self._to_be_del_files.clear() + self.async_task_peeding = False if gpc.is_rank_for_log(): - logger.info("all async uploads succeeded!") self.upload_count += 1 + if self.async_mode: + self.save( + os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), + saved_obj=dict({"step": self.latest_save_step}), + async_upload=False, + ) storage_manager: StorageManager = None diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py index 311c6b3..5ea0680 100644 --- a/internlm/utils/writer.py +++ b/internlm/utils/writer.py @@ -11,10 +11,6 @@ from torch.utils.tensorboard import SummaryWriter from internlm.core.context import global_context as gpc -def copy_ignore_folder(source_path, target_path): - os.system(f"cp -r {source_path}/* {target_path}/") - - def tb_save_run_info(writer, config_lines, global_step=0): writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step) lines = [] @@ -44,7 +40,8 @@ def init_tb_writer( if gpc.get_global_rank() == 0: if resume_tb_folder is not None: logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...") - copy_ignore_folder(resume_tb_folder, tb_folder) + os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/") + os.system(f"chmod -R +w {tb_folder}/") else: logger.info(f"Login tensorboard logs to: {tb_folder}") diff --git a/train.py b/train.py index bdfd8db..bad0fb2 100644 --- a/train.py +++ b/train.py @@ -5,99 +5,48 @@ import socket import time import traceback from functools import partial -from typing import Iterable import numpy as np import torch import torch.distributed as dist -from torch import nn -from torch.utils.data import DataLoader import internlm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel from internlm.core.scheduler import SchedulerMetricHook from internlm.core.trainer import TrainState -from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader -from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn -from internlm.data.dataset import get_dataset_dict -from internlm.data.dummy_dataset import RandomDataset -from internlm.data.packed_dataset import ( - PackedDataset, - PackedDatasetWithoutCuSeqlen, - get_packed_dataset_without_short_length, -) -from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data from internlm.model.loss import FlashGPTLMLoss from internlm.model.metrics import AccPerplex -from internlm.model.moe import create_moe_param_groups, has_moe_layers -from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var +from internlm.monitor import initialize_monitor_manager, send_alert_message from internlm.monitor.monitor import monitor_manager as mm -from internlm.solver.beta2_scheduler import Beta2Scheduler -from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR -from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.train import ( + get_train_data_loader, + get_validation_data_loader, + initialize_distributed_env, + initialize_llm_profile, + initialize_model, + initialize_optimizer, + load_new_batch, + record_current_batch_training_metrics, +) from internlm.utils.common import ( BatchSkipper, - get_master_node, get_megatron_flops, launch_time, parse_args, ) -from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode +from internlm.utils.evaluation import evaluate_on_val_dls from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.model_checkpoint import ( - CheckpointSaveManager, - load_context, - load_model_checkpoint, - load_optimizer_checkpoint, - load_sampler, - load_scheduler, -) -from internlm.utils.parallel import ( - get_parallel_log_file_name, - is_no_pp_or_last_stage, - sync_model_param_with_ep, - sync_model_param_within_tp, -) -from internlm.utils.registry import MODEL_INITIALIZER -from internlm.utils.simple_memory_profiler import ( - SimpleMemoryProfiler, - build_activation_config, -) +from internlm.utils.model_checkpoint import CheckpointManager +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler from internlm.utils.writer import Writer # global llm logger logger = get_logger(__file__) -def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024): - """ - Initialize distributed environment for distributed training. - - Args: - config (str): Config file path. - launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default. - master_port (str): The master port for distributed training. 8888 by default. - seed (int, optional): Specified random seed for every process. 1024 by default. - """ - - torch.cuda.empty_cache() - - if launcher == "torch": - internlm.launch_from_torch(config=config, seed=seed) - elif launcher == "slurm": - internlm.launch_from_slurm( - config=config, - host=get_master_node(), - port=master_port, - seed=seed, - ) - else: - assert launcher in ["slurm", "torch"], "launcher only support slurm or torch" - - def initialize_llm_logger(start_time: str): """ Initialize customed uniscale logger. @@ -118,338 +67,14 @@ def initialize_llm_logger(start_time: str): return uniscale_logger -def initialize_model(): - """ - Initialize model. - - Returns: The neural network model to be trained or evaluated. - """ - - model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) - if isinstance(model, nn.ModuleList): - model = nn.ModuleList( - [ - NaiveAMPModel( - model=_m, - output_to_fp32=False, # manually controlled by interleaved pipleline scheduler - dtype=gpc.config.model.get("dtype", torch.half), - sync_buffer=False, - ) - for _m in model - ] - ) - else: - model = NaiveAMPModel( - model=model, - output_to_fp32=is_no_pp_or_last_stage(), - dtype=gpc.config.model.get("dtype", torch.half), - sync_buffer=False, - ) - - # This sync is very important, cause the model weights kept in optimizer are copied - # from the origin parameters in the memory, so we should make sure the dp sync - # does not influence the model weights in optimizer be different with the origin parameters. - sync_model_param_with_ep(model) - - # This function is needed to make sure parameters that are not splitted by tensor parallelism are - # the same across tensor parallelism. - sync_model_param_within_tp(model) - - return model - - -def get_train_data_loader(num_worker: int = 0): - """ - Generate and return the training data loader. - - Returns: A tuple of (train_dl, dataset_types). - """ - - # Get the dataset types - dataset_types = None - dataset_types = list(DATASET_TYPE_IDS_MAP.keys()) - data_cfg = gpc.config.data - - # Get the sample weight dictionary - train_folder = data_cfg.train_folder - - if not train_folder: - train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len) - if data_cfg.pack_sample_into_one: - train_ds = PackedDatasetWithoutCuSeqlen( - train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length - ) - else: - train_ds = PackedDataset( - train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length - ) - else: - train_ds = get_packed_dataset_without_short_length( - folder=data_cfg.train_folder, - packed_length=data_cfg.packed_length, - max_length_per_sample=data_cfg.seq_len, - show_progress=dist.get_rank() == 0, - min_length=data_cfg.min_length, - min_length_dict=data_cfg.get("min_length_dict", {}), - pack_into_one_sample=data_cfg.pack_sample_into_one, - ) - - # partition already completed - # assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)) - if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)): - datasets = [train_ds] - else: - datasets = train_ds.datasets - - # Create the training dataset sampler - train_sampler = StaticBatchSampler( - datasets, - batch_size=data_cfg.micro_num, - rampup_batch_size=data_cfg.rampup_batch_size, - micro_bsz=data_cfg.micro_bsz, - seed=1024, - drop_last=True, - data_rank=gpc.get_local_rank(ParallelMode.DATA), - data_world_size=gpc.get_world_size(ParallelMode.DATA), - ) - - train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length) - - # Create the training data loader - train_dl = DataLoader( - dataset=train_ds, - batch_sampler=train_sampler, - num_workers=num_worker, - pin_memory=True, - collate_fn=train_collate_fn, - persistent_workers=True, - ) - - return train_dl, dataset_types - - -def get_validation_data_loader(num_worker: int = 0): - """Generate and return the validation data loader.""" - - data_cfg = gpc.config.data - - if not data_cfg.valid_folder: - val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len) - else: - val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="") - - if not isinstance(val_ds, dict): - val_ds = {"val": val_ds} - - val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len) - - val_dls = {} - for val_name, ds in val_ds.items(): - # making the batch_size of validate larger can speed up the evaluation, but it should not be too large, - # otherwise too much data may be dropped - batch_size = min( - data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA) - ) - batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz - - if batch_size == 0 and gpc.is_rank_for_log(): - logger.info(f"skip validate {val_name}.") # pylint: disable=W1203 - continue - - val_dls[val_name] = get_dpsampler_dataloader( - ds, shuffle=False, num_workers=num_worker, batch_size=batch_size, collate_fn=val_collate_fn, drop_last=True - ) # drop_last=True, otherwise it may cause problems in the last batch - - if gpc.is_rank_for_log(): - logger.info( # pylint: disable=W1203 - f"load validation dataset {val_name} with valid batch size {str(batch_size)} and " - f"samples {str(len(val_dls[val_name]))}." - ) - - return val_dls - - -def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): - """ - Load and return the new batch data based on training data loader. - - Args: - train_dl (torch.utils.data.DataLoader): Dataloader for training. - train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). - train_state (TrainState): Current training state. - - Returns: A batch data and the updated train_iter. - """ - - timer("batch-gen").start() - try: - batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor) - next(train_state.batch_sampler_iter) - except StopIteration: - train_iter = iter(train_dl) - batch = next(train_iter) - train_state.batch_sampler_iter = iter(train_state.batch_sampler) - next(train_state.batch_sampler_iter) - train_state.num_consumed_samples_in_epoch = 0 - timer("batch-gen").stop() - - return batch, train_iter - - -def initialize_optimizer(model: nn.Module): - """ - Initialize optimizer. - - Args: - model (torch.nn.Module): Your model instance to be trained or evaluated. - - Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler). - """ - - adam_cfg = gpc.config.adam - if gpc.config.model.num_experts > 1: - params = create_moe_param_groups(model, adam_cfg.weight_decay) - else: - params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}] - naive_optimizer = torch.optim.AdamW( - params=params, - lr=adam_cfg.lr, - betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), - eps=adam_cfg.adam_eps, - ) - - has_moe = has_moe_layers(model) - optimizer = HybridZeroOptimizer( - naive_optimizer, - grad_scal_cfg=gpc.config.grad_scaler, - zero_cfg=gpc.config.hybrid_zero_optimizer, - has_moe=has_moe, - ) - - beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) - - lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) - - return optimizer, beta2_scheduler, lr_scheduler - - -def record_current_batch_training_metrics( - get_tflops_func, - logger, - writer, - success_update, - batch_count, - batch, - train_state, - optimizer, - beta2_scheduler, - trainer, - start_time, - loss, - grad_norm, - metric, - update_panel, -): - """ - Print some training metrics of current batch. - """ - - set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time())) - - if success_update in (0, True): - train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) - if is_no_pp_or_last_stage(): - acc_perplex = metric.get_metric() - - if success_update and gpc.is_rank_for_log(): - lr = optimizer.param_groups[0]["lr"] - if hasattr(trainer.engine.optimizer, "grad_scaler"): - scaler = trainer.engine.optimizer.grad_scaler._scale.item() - elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"): - scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item() - - num_tokens_in_batch = batch[1].nelement() - num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) - max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - - tk_per_gpu = 0 - tk_per_gpu = round( - num_tokens_in_batch - * gpc.get_world_size(ParallelMode.DATA) - / gpc.get_world_size(ParallelMode.GLOBAL) - / (time.time() - start_time), - 2, - ) - - tflops = get_tflops_func((time.time() - start_time)) - - infos = { - "tflops": tflops, - "step": batch_count, - "loss": loss.item(), - "tgs (tokens/gpu/second)": tk_per_gpu, - "lr": lr, - "loss_scale": scaler, - "grad_norm": grad_norm, - } - - infos["micro_num"] = len(batch[1]) - infos["num_consumed_tokens"] = train_state.num_consumed_tokens - infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches - infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples - infos["largest_length"] = max_length_in_batch # the longest input - infos["largest_batch"] = max_samples_in_batch # the batch with the most samples - infos["smallest_batch"] = min_samples_in_batch - infos["adam_beta2"] = beta2_scheduler.get_beta2() - - fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) - infos["fwd_bwd_time"] = fwd_bwd_time - - for key, value in acc_perplex.items(): - infos[key] = value - - line = "" - for key, value in infos.items(): - line += f"{key}={value} " - writer.add_scalar(key=key, value=value, step=train_state.step_count) - - if update_panel: - logger.info( - line, - extra={ - "step": batch_count, - "lr": lr, - "num_consumed_tokens": train_state.num_consumed_tokens, - "grad_norm": grad_norm, - "loss": loss.item(), - "flops": tflops, - "tgs": tk_per_gpu, - "acc": acc_perplex["acc"], - "perplexity": acc_perplex["perplexity"], - "fwd_bwd_time": fwd_bwd_time, - }, - ) - else: - logger.info(line) - - # if loss spike occurs, send alert info to feishu - mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item()) - - def main(args): # init setting skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every - load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr - load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) - load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) - get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, @@ -485,32 +110,19 @@ def main(args): enable_tb=gpc.config.enable_tb, ) - model_load_path = None - if load_resume_ckpt_folder is not None: - logger.info( # pylint: disable=W1203 - f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_resume_ckpt_folder - elif load_model_only_folder is not None: - logger.info( # pylint: disable=W1203 - f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_model_only_folder - else: - logger.info( # pylint: disable=W1203 - f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," - f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," - f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" - ) - # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + model_config=gpc.config.model, + feishu_address=gpc.config.alert_address, + ) + # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -520,30 +132,12 @@ def main(args): train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=model) + ckpt_manager.try_load_model(current_time) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. - if load_resume_ckpt_folder is not None: - # load lr scheduler states. - load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(load_resume_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. - load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) - # load optimzier states. - if load_optimizer: - load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) - - ckpt_save_manager = CheckpointSaveManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model_config=gpc.config.model, - ) + ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl) # initialize metric for calculating accuracy and perplexity metric = AccPerplex( @@ -579,12 +173,11 @@ def main(args): # initialize simple memory profiler if args.profiling: memory_profiler = SimpleMemoryProfiler( - model.model, + model, optimizer.optim, log_folder=f"memory_trace/rank{gpc.get_global_rank()}_" + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", - activation_config=build_activation_config(gpc.config.model.num_layers), ) else: memory_profiler = None @@ -597,86 +190,85 @@ def main(args): # transfer the train data loader into train data iterator train_iter = iter(train_dl) - # start iterating the train data and begin training - for batch_count in range(train_state.batch_count, total_steps): - if batch_count % 50 == 0: - torch.cuda.empty_cache() + with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: + # start iterating the train data and begin training + for batch_count in range(train_state.batch_count, total_steps): + if batch_count % 50 == 0: + torch.cuda.empty_cache() - start_time = time.time() - timer("one-batch").start() + start_time = time.time() + timer("one-batch").start() - # load batch data - batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) + # load batch data + batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) + + # record the consumed samples in training + train_state.batch_count = batch_count + train_state.num_consumed_samples_in_epoch += len(batch[1]) + if batch_skipper(batch_count): # skip this batch + if gpc.is_rank_for_log(): + logger.info(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + continue + + # zero the grads of parameters + trainer.zero_grad() + # process data + if batch[0].get("type_ids", None) is not None: + metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + + # do forward and backward + timer("fwd-bwd").start() + + _, _, loss, moe_loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + moe_loss_coeff=gpc.config.loss.moe_loss_coeff, + ) + timer("fwd-bwd").stop() + + # update parameters, and returns (success_update, grad_norm) + trainer_result = trainer.step() + assert trainer_result is not None + + success_update, grad_norm_groups = trainer_result + if success_update: # update parameters successfully + train_state.step_count += 1 + else: + train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. + if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case + logger.warning(f"Warning: skip parameter update at step {batch_count}.") + send_alert_message( + address=gpc.config.alert_address, + message=f"Warning: skip parameter update at step {batch_count}.", + ) + + # calculate and record the training metrics, eg. loss, accuracy and so on. + record_current_batch_training_metrics( + get_tflops_func=get_tflops_func, + logger=logger, + writer=writer, + success_update=success_update, + batch_count=batch_count, + batch=batch, + train_state=train_state, + optimizer=optimizer, + beta2_scheduler=beta2_scheduler, + trainer=trainer, + start_time=start_time, + loss=loss, + moe_loss=moe_loss, + grad_norm=np.array(grad_norm_groups), + metric=metric, + update_panel=uniscale_logger is not None, + ) - # record the consumed samples in training - train_state.batch_count = batch_count - train_state.num_consumed_samples_in_epoch += len(batch[1]) - if batch_skipper(batch_count): # skip this batch - if gpc.is_rank_for_log(): - logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203 timer("one-batch").stop() - continue - # zero the grads of parameters - trainer.zero_grad() - type_ids = batch[0].pop("type_ids", None) - # process data - # if use_flash_attn is False, we need to unpack type_ids - if not gpc.config.model.use_flash_attn: - type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"]) - if type_ids is not None: - metric.set_current_type_ids(type_ids=type_ids) - - # do forward and backward - timer("fwd-bwd").start() - _, _, loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - moe_loss_coeff=gpc.config.loss.moe_loss_coeff, - ) - timer("fwd-bwd").stop() - - # update parameters, and returns (success_update, grad_norm) - trainer_result = trainer.step() - assert trainer_result is not None - - success_update, grad_norm_groups = trainer_result - if success_update: # update parameters successfully - train_state.step_count += 1 - else: - train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. - if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case - logger.warning(f"Warning: skip parameter update at step {batch_count}.") - send_alert_message( - address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}." - ) - - # calculate and record the training metrics, eg. loss, accuracy and so on. - record_current_batch_training_metrics( - get_tflops_func=get_tflops_func, - logger=logger, - writer=writer, - success_update=success_update, - batch_count=batch_count, - batch=batch, - train_state=train_state, - optimizer=optimizer, - beta2_scheduler=beta2_scheduler, - trainer=trainer, - start_time=start_time, - loss=loss, - grad_norm=np.array(grad_norm_groups), - metric=metric, - update_panel=uniscale_logger is not None, - ) - - timer("one-batch").stop() - - # evaluate on validation data loaders - if valid_every > 0 and train_state.step_count % valid_every == 0: - with switch_sequence_parallel_mode(): + # evaluate on validation data loaders + if valid_every > 0 and train_state.step_count % valid_every == 0: evaluate_on_val_dls( trainer=trainer, val_dls=val_dls, @@ -686,14 +278,19 @@ def main(args): update_panel=uniscale_logger is not None, ) - if memory_profiler is not None: - memory_profiler.step() + # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" + # # save batch sampler that tracks the true consumed samples + now_break = ckpt_manager.try_save_checkpoint(train_state) + if now_break: + break - # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" - # # save batch sampler that tracks the true consumed samples - ckpt_save_manager.try_save_checkpoint(train_state) + if memory_profiler is not None: + memory_profiler.step() - ckpt_save_manager.wait_async_upload_finish() + if batch_count % 2 == 0: + prof.step() + + ckpt_manager.wait_async_upload_finish() if __name__ == "__main__":