diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index daab2bb..dd9b49a 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_loss:
             loss = None
 
-        return output, loss
+        return output, loss, moe_loss
 
     def forward_backward_step(
         self,
@@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler):
         data, label = batch_data
 
         loss = 0 if return_loss else None
+        moe_loss = 0 if return_loss else None
         outputs = []
         labels = []
 
@@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler):
 
             _data, _label = self._load_accum_batch(data, label)
 
-            _output, _loss = self._train_one_batch(
+            _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
             )
 
             if return_loss:
                 loss += _loss
+                moe_loss += _moe_loss
+
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)
@@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None
 
-        return outputs, labels, loss
+        return outputs, labels, loss, moe_loss
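Note on the NonPipelineScheduler change: the gradient-accumulation loop now threads a fourth value, the MoE auxiliary loss, through the micro-steps alongside the main loss. The sketch below shows that accumulation pattern in isolation; `accumulate_micro_steps` and its arguments are illustrative stand-ins, not the InternLM implementation.

```python
from typing import Any, Callable, List, Optional, Tuple

def accumulate_micro_steps(
    micro_batches: List[Tuple[Any, Any]],
    train_one_batch: Callable[[Any, Any], Tuple[Any, float, float]],
    return_loss: bool = True,
) -> Tuple[List[Any], Optional[float], Optional[float]]:
    # Mirrors the loop in NonPipelineScheduler.forward_backward_step:
    # each micro-step now yields (output, loss, moe_loss), and both
    # losses are summed over the accumulation window.
    loss = 0.0 if return_loss else None
    moe_loss = 0.0 if return_loss else None
    outputs = []
    for data, label in micro_batches:
        _output, _loss, _moe_loss = train_one_batch(data, label)
        if return_loss:
            loss += _loss
            moe_loss += _moe_loss
        outputs.append(_output)
    return outputs, loss, moe_loss

# Tiny smoke test with a fake per-batch step.
outs, total, total_moe = accumulate_micro_steps(
    [(0, 0), (1, 1)], lambda d, l: (d, 1.0, 0.1)
)
print(outs, total, total_moe)  # [0, 1] 2.0 0.2
```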
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 8396169..19975b6 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -7,6 +7,7 @@ from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch.cuda
+import torch.distributed as dist
 
 import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
@@ -239,7 +240,8 @@ class PipelineScheduler(BaseScheduler):
         """
         return step_id
 
-    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff:float=1.0):
+    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True,
+                      accum_loss=None, accum_moe_loss=None, moe_loss_coeff: float = 1.0):
         """
         Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator,
         otherwise the passed-in input_obj is used.
 
         Args:
             engine (colossalai.engine.Engine): Colossalai engine for training and inference.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
             accum_loss (optional): Where accumulated loss stores.
+            accum_moe_loss (optional): Where accumulated moe loss stores.
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
@@ -277,6 +280,7 @@ class PipelineScheduler(BaseScheduler):
 
                 moe_loss = sum(moe_losses) * moe_loss_coeff
                 moe_loss /= self.num_microbatches
+                accum_moe_loss.add_(moe_loss.detach())
 
         return output_obj, moe_loss
@@ -366,6 +370,7 @@ class PipelineScheduler(BaseScheduler):
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -392,6 +397,7 @@ class PipelineScheduler(BaseScheduler):
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
                 moe_loss_coeff=moe_loss_coeff,
             )
@@ -403,8 +409,12 @@ class PipelineScheduler(BaseScheduler):
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
+
+        return output, label, accum_loss, accum_moe_loss
 
     def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
         """
@@ -459,6 +469,7 @@ class PipelineScheduler(BaseScheduler):
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -486,6 +497,7 @@ class PipelineScheduler(BaseScheduler):
                     return_tensors,
                     return_output_label=return_output_label,
                     accum_loss=accum_loss,
+                    accum_moe_loss=accum_moe_loss,
                     moe_loss_coeff=moe_loss_coeff,
                 )
@@ -532,6 +544,7 @@ class PipelineScheduler(BaseScheduler):
                     return_tensors,
                     return_output_label=return_output_label,
                     accum_loss=accum_loss,
+                    accum_moe_loss=accum_moe_loss,
                     moe_loss_coeff=moe_loss_coeff,
                 )
@@ -598,11 +611,18 @@ class PipelineScheduler(BaseScheduler):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
+        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
+
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
+        return output, label, accum_loss, accum_moe_loss
+
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
+                              return_output_label=True, moe_loss_coeff: float = 1.0):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
 
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -614,7 +634,7 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
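Note on the cross-stage reduction used by the PipelineScheduler hunks above (and by the InterleavedPipelineScheduler hunks below): each stage only accumulates the gate losses of the MoE layers it hosts, so the per-stage `accum_moe_loss` values are summed over the pipeline process group before being returned, and on ranks that hold a loss the total is folded into `accum_loss`. A minimal sketch of that reduction, assuming an already-initialized `torch.distributed` process group rather than the InternLM `gpc` wrappers:

```python
import torch
import torch.distributed as dist

def reduce_moe_loss_across_stages(
    local_moe_loss: torch.Tensor, pipeline_group: dist.ProcessGroup
) -> torch.Tensor:
    # all_reduce defaults to SUM and operates in place: afterwards every
    # rank in the group holds the total auxiliary loss over all pipeline
    # stages, which is what the patch relies on before folding the result
    # into accum_loss on the loss-holding rank.
    dist.all_reduce(local_moe_loss, group=pipeline_group)
    return local_moe_loss
```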
""" assert ( @@ -694,6 +714,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(num_chunks)] self._output_objs = [[] for _ in range(num_chunks)] @@ -706,6 +727,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): def _clear_state(self) -> None: self._accum_loss = None + self._accum_moe_loss = None self._return_tensors = None self._input_objs = [[] for _ in range(self._num_chunks)] self._output_objs = [[] for _ in range(self._num_chunks)] @@ -777,6 +799,9 @@ class InterleavedPipelineScheduler(PipelineScheduler): moe_loss = sum(moe_losses) * moe_loss_coeff moe_loss /= self.num_microbatches + if self._accum_moe_loss is not None: + self._accum_moe_loss.add_(moe_loss.detach()) + self._output_objs[chunk_id].append(output_obj) self._moe_losses[chunk_id].append(moe_loss) @@ -1287,6 +1312,8 @@ class InterleavedPipelineScheduler(PipelineScheduler): if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): self._accum_loss = torch.zeros(1, device=get_current_device()) + if return_loss: + self._accum_moe_loss = torch.zeros(1, device=get_current_device()) if return_output_label: self._return_tensors = [] @@ -1301,6 +1328,11 @@ class InterleavedPipelineScheduler(PipelineScheduler): output, label = (None, None) accum_loss = self._accum_loss + logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}") + + dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + accum_moe_loss = self._accum_moe_loss + self._clear_state() - return output, label, accum_loss + return output, label, accum_loss, accum_moe_loss diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 536c1aa..fd899ba 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -155,5 +155,5 @@ class Trainer: Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). 
""" - output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) - return output, label, loss + output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) + return output, label, loss, moe_loss diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index d10f0c1..f6c86b6 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -100,7 +100,7 @@ def evaluate_on_val_dls( tensor_shape=tensor_shape, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss = trainer.execute_schedule( + _, _, loss, _ = trainer.execute_schedule( batch, forward_only=True, return_loss=True, return_output_label=False ) else: @@ -114,7 +114,7 @@ def evaluate_on_val_dls( grad_accum_batch_size=grad_accum_batch_size, metric_hook_list=[val_sche_metric_hook], ): - _, _, loss = trainer.execute_schedule( + _, _, loss, _ = trainer.execute_schedule( batch, forward_only=True, return_loss=True, return_output_label=False ) if verbose: diff --git a/train.py b/train.py index 39fa942..7184c70 100644 --- a/train.py +++ b/train.py @@ -341,6 +341,7 @@ def record_current_batch_training_metrics( trainer, start_time, loss, + moe_loss, grad_norm, metric, update_panel, @@ -384,6 +385,7 @@ def record_current_batch_training_metrics( "tflops": tflops, "step": batch_count, "loss": loss.item(), + "moe_loss": moe_loss.item(), "tgs (tokens/gpu/second)": tk_per_gpu, "lr": lr, "loss_scale": scaler, @@ -419,6 +421,7 @@ def record_current_batch_training_metrics( "num_consumed_tokens": train_state.num_consumed_tokens, "grad_norm": grad_norm, "loss": loss.item(), + "moe_loss": moe_loss.item(), "flops": tflops, "tgs": tk_per_gpu, "acc": acc_perplex["acc"], @@ -606,7 +609,7 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - _, _, loss = trainer.execute_schedule( + _, _, loss, moe_loss = trainer.execute_schedule( batch, forward_only=False, return_loss=True, @@ -644,6 +647,7 @@ def main(args): trainer=trainer, start_time=start_time, loss=loss, + moe_loss=moe_loss, grad_norm=grad_norm, metric=metric, update_panel=uniscale_logger is not None,