diff --git a/configs/moe_cfg.py b/configs/moe_cfg.py new file mode 100644 index 0000000..89e1a96 --- /dev/null +++ b/configs/moe_cfg.py @@ -0,0 +1,152 @@ +JOB_NAME = "7b_train" + +SEQ_LEN = 2048 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 16 +VOCAB_SIZE = 103168 + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states). + # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights. + load_optimizer=True, # Wheter to load optimizer states when continuing training. + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki" +VALID_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + packed_length = 2 * SEQ_LEN, + micro_bsz=2, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=50000, + pack_sample_into_one=False, + total_steps=50000, + skip_batches="", + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + zero_overlap_communication=True, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, + moe_loss_coeff=0.1, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +model = dict( + checkpoint=False, + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + sequence_parallel=False, + num_experts=4, + moe_use_residual=False, +) +""" +zero1 parallel: + 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. + 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +tensor parallel: tensor parallel size, usually the number of GPUs per node. +""" +parallel = dict( + # zero1=4, + pipeline=dict(size=4, interleaved_overlap=False), + # tensor=dict(size=4), +) + +cudnn_deterministic = False +cudnn_benchmark = False diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index daab2bb..dd9b49a 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler): if not return_loss: loss = None - return output, loss + return output, loss, moe_loss def forward_backward_step( self, @@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler): data, label = batch_data loss = 0 if return_loss else None + moe_loss = 0 if return_loss else None outputs = [] labels = [] @@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler): _data, _label = self._load_accum_batch(data, label) - _output, _loss = self._train_one_batch( + _output, _loss, _moe_loss = self._train_one_batch( _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff ) if return_loss: loss += _loss + moe_loss += _moe_loss + if return_output_label: outputs.append(_output) labels.append(_label) @@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler): if not return_output_label: outputs, labels = None, None - return outputs, labels, loss + return outputs, labels, loss, moe_loss diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index ebdb374..ba919d7 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Callable, List, Optional, Tuple, Union import torch.cuda +import torch.distributed as dist import internlm.core.communication as comm from internlm.core.context import ParallelMode @@ -239,7 +240,16 @@ class PipelineScheduler(BaseScheduler): """ return step_id - def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): + def _forward_step( + self, + engine, + input_obj, + return_tensors, + return_output_label=True, + accum_loss=None, + accum_moe_loss=None, + moe_loss_coeff=1.0, + ): """ Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. @@ -251,6 +261,7 @@ class PipelineScheduler(BaseScheduler): return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. return_output_label (bool, optional): Whether returns output labels. accum_loss (optional): Where accumulated loss stores. + accum_moe_loss (optional): Where accumulated moe loss stores. Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. @@ -259,7 +270,7 @@ class PipelineScheduler(BaseScheduler): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - output_obj = self._call_engine(engine.model, data) + output_obj, moe_losses = self._call_engine(engine.model, data) self._call_hooks("after_forward", output_obj) if gpc.is_last_rank(ParallelMode.PIPELINE): @@ -275,9 +286,13 @@ class PipelineScheduler(BaseScheduler): accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - return output_obj + moe_loss = sum(moe_losses) * moe_loss_coeff + moe_loss /= self.num_microbatches + accum_moe_loss.add_(moe_loss.detach()) - def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad): + return output_obj, moe_loss + + def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None): """ Backward step through the passed-in output tensor. If it is the last stage, the output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. @@ -311,10 +326,18 @@ class PipelineScheduler(BaseScheduler): self._call_hooks("before_backward", output_obj, output_obj_grad) with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync): - if output_obj_grad is None: - engine.backward(output_obj) + if moe_loss is None: + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) else: - engine.backward_by_grad(output_obj, output_obj_grad) + if output_obj_grad is None: + engine.backward(output_obj + moe_loss) + else: + # scale the latent loss + moe_loss = moe_loss * engine.optimizer.loss_scale + engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None]) # Collect the grad of the input_obj. input_obj_grad = None @@ -329,7 +352,7 @@ class PipelineScheduler(BaseScheduler): return input_obj_grad - def _forward_only_step(self, engine, return_loss=True, return_output_label=True): + def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0): """ This function performs forward only computation process. The scheduling of microbatches is similar to the warmup phase, where each microbatch first receives the forward input from the previous stage, then performs @@ -356,6 +379,7 @@ class PipelineScheduler(BaseScheduler): if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) + accum_moe_loss = torch.zeros(1, device=get_current_device()) # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -376,12 +400,14 @@ class PipelineScheduler(BaseScheduler): input_obj = None # Perform forward computation - output_obj = self._forward_step( + output_obj, _ = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if not gpc.is_last_rank(ParallelMode.PIPELINE): @@ -392,10 +418,14 @@ class PipelineScheduler(BaseScheduler): comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) + dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - return output, label, accum_loss + if accum_loss is not None: + accum_loss += accum_moe_loss - def _forward_backward_step(self, engine, return_loss=True, return_output_label=True): + return output, label, accum_loss, accum_moe_loss + + def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0): """ This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner. It consists of three stages: warmup, 1F1B, and cooldown. @@ -441,12 +471,14 @@ class PipelineScheduler(BaseScheduler): # Input, output tensors only need to be saved when doing backward passes input_objs = [] output_objs = [] + moe_losses = [] return_tensors = [] accum_loss = ( torch.zeros(1, device=get_current_device()) if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) + accum_moe_loss = torch.zeros(1, device=get_current_device()) # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -468,12 +500,14 @@ class PipelineScheduler(BaseScheduler): input_obj = None # Perform forward computation - output_obj = self._forward_step( + output_obj, moe_loss = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if not gpc.is_last_rank(ParallelMode.PIPELINE): @@ -493,6 +527,7 @@ class PipelineScheduler(BaseScheduler): input_objs.append(input_obj) output_objs.append(output_obj) + moe_losses.append(moe_loss) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to @@ -512,12 +547,14 @@ class PipelineScheduler(BaseScheduler): # Run 1F1B in steady state. for i in range(num_1f1b_micropairs): # Perform forward computation - output_obj = self._forward_step( + output_obj, moe_loss = self._forward_step( engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + moe_loss_coeff=moe_loss_coeff, ) if gpc.is_last_rank(ParallelMode.PIPELINE): @@ -533,13 +570,15 @@ class PipelineScheduler(BaseScheduler): # Add input_obj and output_obj to end of list. input_objs.append(input_obj) output_objs.append(output_obj) + moe_losses.append(moe_loss) # Pop output_obj and output_obj from the start of the list for # the backward pass. input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) - input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad) + input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss) if i == (num_1f1b_micropairs - 1): input_obj = None @@ -563,6 +602,7 @@ class PipelineScheduler(BaseScheduler): for i in range(num_warmup_microsteps): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) if not gpc.is_last_rank(ParallelMode.PIPELINE): output_obj_grad = comm.recv_backward( @@ -574,17 +614,25 @@ class PipelineScheduler(BaseScheduler): output_obj_grad = None input_obj_grad = self._backward_step( - engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad + engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss ) if not gpc.is_first_rank(ParallelMode.PIPELINE): comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}") + output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) + dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - return output, label, accum_loss + if accum_loss is not None: + accum_loss += accum_moe_loss - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + return output, label, accum_loss, accum_moe_loss + + def forward_backward_step( + self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0 + ): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -596,7 +644,7 @@ class PipelineScheduler(BaseScheduler): return_loss (bool, optional): Whether returns the loss value. Default is true. return_output_label (bool, optional): If False, the output and label won't be returned. Returns: - Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. + Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, loss), loss and label could be None. """ assert ( @@ -607,9 +655,9 @@ class PipelineScheduler(BaseScheduler): self.load_batch(engine, data_iter) if forward_only: - return self._forward_only_step(engine, return_loss, return_output_label) + return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff) else: - return self._forward_backward_step(engine, return_loss, return_output_label) + return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff) class InterleavedPipelineScheduler(PipelineScheduler): @@ -676,10 +724,12 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(num_chunks)] self._output_objs = [[] for _ in range(num_chunks)] self._output_obj_grads = [[] for _ in range(num_chunks)] + self._moe_losses = [[] for _ in range(num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)] self._output_obj_shapes = [None for _ in range(num_chunks)] @@ -687,10 +737,12 @@ class InterleavedPipelineScheduler(PipelineScheduler): def _clear_state(self) -> None: self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(self._num_chunks)] self._output_objs = [[] for _ in range(self._num_chunks)] self._output_obj_grads = [[] for _ in range(self._num_chunks)] + self._moe_losses = [[] for _ in range(self._num_chunks)] self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)] self._output_obj_shapes = [None for _ in range(self._num_chunks)] @@ -712,7 +764,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): self.microbatch_offset[model_chunk_id] += self.microbatch_size return move_to_device(micro_batch_data) - def _forward_step(self, engine, chunk_id): + def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -734,7 +786,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - output_obj = self._call_engine(engine.model[chunk_id], data) + output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data) # Convert output_obj to fp32 when last model chunk of last stage if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel): output_obj = engine.model[chunk_id].convert_to_fp32(output_obj) @@ -754,7 +806,14 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced + moe_loss = sum(moe_losses) * moe_loss_coeff + moe_loss /= self.num_microbatches + + if self._accum_moe_loss is not None: + self._accum_moe_loss.add_(moe_loss.detach()) + self._output_objs[chunk_id].append(output_obj) + self._moe_losses[chunk_id].append(moe_loss) return output_obj @@ -780,8 +839,9 @@ class InterleavedPipelineScheduler(PipelineScheduler): input_obj = self._input_objs[chunk_id].pop(0) output_obj = self._output_objs[chunk_id].pop(0) output_obj_grad = self._output_obj_grads[chunk_id].pop(0) + moe_loss = self._moe_losses[chunk_id].pop(0) - input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad) + input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss) return input_obj_grad @@ -813,6 +873,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, receive_extra_backward: bool = False, forward_only: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the warm-up loop and prepare data for the 1F1B stage. @@ -850,12 +911,13 @@ class InterleavedPipelineScheduler(PipelineScheduler): for k in range(num_warmup_microsteps): chunk_id = self._get_chunk_by_microbatch(k) - output_obj = self._forward_step(engine, chunk_id) + output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff) if forward_only: # when forward-only, no need to save tensors for a backward pass self._input_objs[chunk_id].pop() self._output_objs[chunk_id].pop() + self._moe_losses[chunk_id].pop() if not gpc.is_pipeline_last_stage(): if isinstance(output_obj, torch.Tensor): @@ -931,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, num_1f1b_micropairs: int, all_warmup_microsteps: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the 1F1B loop with overlap. @@ -960,7 +1023,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True) # 1. Forward pass. - output_obj = self._forward_step(engine, forward_chunk_id) + output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff) # 2. Check if the backward input is ready. if backward_async_communicator is not None: @@ -1045,6 +1108,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps: int, num_1f1b_micropairs: int, all_warmup_microsteps: bool = False, + moe_loss_coeff: float = 1.0, ) -> None: """ Run the 1F1B loop without overlap. @@ -1066,7 +1130,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): # Forward pass. forward_microstep_id = k + num_warmup_microsteps forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id) - output_obj = self._forward_step(engine, forward_chunk_id) + output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff) # Backward pass. backward_microstep_id = k @@ -1171,7 +1235,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): ) ) - def _forward_only_step(self, engine: Engine): + def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0): num_microsteps = self.num_microbatches * self._num_chunks num_warmup_microsteps = num_microsteps @@ -1181,9 +1245,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_microsteps, receive_extra_backward=False, forward_only=True, + moe_loss_coeff=moe_loss_coeff, ) - def _forward_backward_step(self, engine: Engine): + def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0): # Compute number of warmup and remaining microbatches. all_warmup_microsteps = False num_microsteps = self.num_microbatches * self._num_chunks @@ -1217,6 +1282,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_microsteps, num_warmup_steps, receive_extra_backward=receive_extra_backward, + moe_loss_coeff=moe_loss_coeff, ) # 2. 1F1B @@ -1225,12 +1291,15 @@ class InterleavedPipelineScheduler(PipelineScheduler): num_warmup_steps, num_1f1b_micropairs=num_1f1b_micropairs, all_warmup_microsteps=all_warmup_microsteps, + moe_loss_coeff=moe_loss_coeff, ) # 3. Cooldown self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): + def forward_backward_step( + self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0 + ): """Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed. @@ -1254,20 +1323,30 @@ class InterleavedPipelineScheduler(PipelineScheduler): if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): self._accum_loss = torch.zeros(1, device=get_current_device()) + self._accum_moe_loss = torch.zeros(1, device=get_current_device()) + if return_output_label: self._return_tensors = [] if forward_only: - self._forward_only_step(engine) + self._forward_only_step(engine, moe_loss_coeff) else: - self._forward_backward_step(engine) + self._forward_backward_step(engine, moe_loss_coeff) if return_output_label and len(self._return_tensors) > 0: output, label = pack_return_tensors(self._return_tensors) else: output, label = (None, None) + + logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}") + + dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + accum_moe_loss = self._accum_moe_loss + accum_loss = self._accum_loss + if accum_loss is not None: + accum_loss += self._accum_moe_loss self._clear_state() - return output, label, accum_loss + return output, label, accum_loss, accum_moe_loss diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 536c1aa..fd899ba 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -155,5 +155,5 @@ class Trainer: Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). """ - output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) - return output, label, loss + output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) + return output, label, loss, moe_loss diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index d10f0c1..f6c86b6 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -100,7 +100,7 @@ def evaluate_on_val_dls( tensor_shape=tensor_shape, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss = trainer.execute_schedule( + _, _, loss, _ = trainer.execute_schedule( batch, forward_only=True, return_loss=True, return_output_label=False ) else: @@ -114,7 +114,7 @@ def evaluate_on_val_dls( grad_accum_batch_size=grad_accum_batch_size, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss = trainer.execute_schedule( + _, _, loss, _ = trainer.execute_schedule( batch, forward_only=True, return_loss=True, return_output_label=False ) if verbose: diff --git a/train.py b/train.py index 02cce82..4e75343 100644 --- a/train.py +++ b/train.py @@ -346,6 +346,7 @@ def record_current_batch_training_metrics( trainer, start_time, loss, + moe_loss, grad_norm, metric, update_panel, @@ -389,6 +390,7 @@ def record_current_batch_training_metrics( "tflops": tflops, "step": batch_count, "loss": loss.item(), + "moe_loss": moe_loss.item(), "tgs (tokens/gpu/second)": tk_per_gpu, "lr": lr, "loss_scale": scaler, @@ -424,6 +426,7 @@ def record_current_batch_training_metrics( "num_consumed_tokens": train_state.num_consumed_tokens, "grad_norm": grad_norm, "loss": loss.item(), + "moe_loss": moe_loss.item(), "flops": tflops, "tgs": tk_per_gpu, "acc": acc_perplex["acc"], @@ -629,7 +632,7 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - _, _, loss = trainer.execute_schedule( + _, _, loss, moe_loss = trainer.execute_schedule( batch, forward_only=False, return_loss=True, @@ -667,6 +670,7 @@ def main(args): trainer=trainer, start_time=start_time, loss=loss, + moe_loss=moe_loss, grad_norm=np.array(grad_norm_groups), metric=metric, update_panel=uniscale_logger is not None,