mirror of https://github.com/InternLM/InternLM
add logger for moe_loss
parent
8cdd1abb35
commit
2983076d89
|
@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
if not return_loss:
|
||||
loss = None
|
||||
|
||||
return output, loss
|
||||
return output, loss, moe_loss
|
||||
|
||||
def forward_backward_step(
|
||||
self,
|
||||
|
@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
data, label = batch_data
|
||||
|
||||
loss = 0 if return_loss else None
|
||||
moe_loss = 0 if return_loss else None
|
||||
outputs = []
|
||||
labels = []
|
||||
|
||||
|
@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
|
||||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss = self._train_one_batch(
|
||||
_output, _loss, _moe_loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
loss += _loss
|
||||
moe_loss += _moe_loss
|
||||
|
||||
if return_output_label:
|
||||
outputs.append(_output)
|
||||
labels.append(_label)
|
||||
|
@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
if not return_output_label:
|
||||
outputs, labels = None, None
|
||||
|
||||
return outputs, labels, loss
|
||||
return outputs, labels, loss, moe_loss
|
||||
|
|
|
@ -7,6 +7,7 @@ from contextlib import contextmanager
|
|||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch.cuda
|
||||
import torch.distributed as dist
|
||||
|
||||
import internlm.core.communication as comm
|
||||
from internlm.core.context import ParallelMode
|
||||
|
@ -239,7 +240,8 @@ class PipelineScheduler(BaseScheduler):
|
|||
"""
|
||||
return step_id
|
||||
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff:float=1.0):
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True,
|
||||
accum_loss=None, accum_moe_loss=None, moe_loss_coeff:float=1.0):
|
||||
"""
|
||||
Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
|
@ -251,6 +253,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
|
||||
return_output_label (bool, optional): Whether to return the output and labels.
|
||||
accum_loss (optional): Where the accumulated loss is stored.
|
||||
accum_moe_loss (optional): Where the accumulated moe loss is stored.
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
||||
pipeline stage.
|
||||
|
@ -277,6 +280,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss /= self.num_microbatches
|
||||
accum_moe_loss.add_(moe_loss.detach())
|
||||
|
||||
return output_obj, moe_loss
|
||||
|
||||
|
@ -366,6 +370,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
|
||||
else None
|
||||
)
|
||||
accum_moe_loss = torch.zeros(1, device=get_current_device())
|
||||
|
||||
# Used for tensor meta information communication
|
||||
forward_recv_shapes = self.tensor_shape
|
||||
|
@ -392,6 +397,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
|
@ -403,8 +409,12 @@ class PipelineScheduler(BaseScheduler):
|
|||
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
|
||||
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
|
||||
return output, label, accum_loss
|
||||
if accum_loss is not None:
|
||||
accum_loss += accum_moe_loss
|
||||
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
||||
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
"""
|
||||
|
@ -459,6 +469,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
|
||||
else None
|
||||
)
|
||||
accum_moe_loss = torch.zeros(1, device=get_current_device())
|
||||
|
||||
# Used for tensor meta information communication
|
||||
forward_recv_shapes = self.tensor_shape
|
||||
|
@ -486,6 +497,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
|
@ -532,6 +544,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
accum_moe_loss=accum_moe_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
|
@ -598,11 +611,18 @@ class PipelineScheduler(BaseScheduler):
|
|||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
|
||||
|
||||
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
|
||||
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
|
||||
return output, label, accum_loss
|
||||
if accum_loss is not None:
|
||||
accum_loss += accum_moe_loss
|
||||
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
|
||||
return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
||||
|
@ -614,7 +634,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
return_loss (bool, optional): Whether to return the loss value. Defaults to True.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
|
||||
"""
|
||||
|
||||
assert (
|
||||
|
@ -694,6 +714,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
self._accum_loss = None
|
||||
self._accum_moe_loss = None
|
||||
self._return_tensors = None
|
||||
self._input_objs = [[] for _ in range(num_chunks)]
|
||||
self._output_objs = [[] for _ in range(num_chunks)]
|
||||
|
@ -706,6 +727,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
|
||||
def _clear_state(self) -> None:
|
||||
self._accum_loss = None
|
||||
self._accum_moe_loss = None
|
||||
self._return_tensors = None
|
||||
self._input_objs = [[] for _ in range(self._num_chunks)]
|
||||
self._output_objs = [[] for _ in range(self._num_chunks)]
|
||||
|
@ -777,6 +799,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss /= self.num_microbatches
|
||||
|
||||
if self._accum_moe_loss is not None:
|
||||
self._accum_moe_loss.add_(moe_loss.detach())
|
||||
|
||||
self._output_objs[chunk_id].append(output_obj)
|
||||
self._moe_losses[chunk_id].append(moe_loss)
|
||||
|
||||
|
@ -1287,6 +1312,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
self._accum_loss = torch.zeros(1, device=get_current_device())
|
||||
if return_loss:
|
||||
self._accum_moe_loss = torch.zeros(1, device=get_current_device())
|
||||
if return_output_label:
|
||||
self._return_tensors = []
|
||||
|
||||
|
@ -1301,6 +1328,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
output, label = (None, None)
|
||||
accum_loss = self._accum_loss
|
||||
|
||||
logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
|
||||
|
||||
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
accum_moe_loss = self._accum_moe_loss
|
||||
|
||||
self._clear_state()
|
||||
|
||||
return output, label, accum_loss
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
|
|
@ -155,5 +155,5 @@ class Trainer:
|
|||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
|
||||
"""
|
||||
output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
||||
return output, label, loss
|
||||
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
||||
return output, label, loss, moe_loss
|
||||
|
|
|
@ -100,7 +100,7 @@ def evaluate_on_val_dls(
|
|||
tensor_shape=tensor_shape,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
_, _, loss, _ = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
else:
|
||||
|
@ -114,7 +114,7 @@ def evaluate_on_val_dls(
|
|||
grad_accum_batch_size=grad_accum_batch_size,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
_, _, loss, _ = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
if verbose:
|
||||
|
|
6
train.py
6
train.py
|
@ -341,6 +341,7 @@ def record_current_batch_training_metrics(
|
|||
trainer,
|
||||
start_time,
|
||||
loss,
|
||||
moe_loss,
|
||||
grad_norm,
|
||||
metric,
|
||||
update_panel,
|
||||
|
@ -384,6 +385,7 @@ def record_current_batch_training_metrics(
|
|||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item(),
|
||||
"moe_loss": moe_loss.item(),
|
||||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
|
@ -419,6 +421,7 @@ def record_current_batch_training_metrics(
|
|||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"grad_norm": grad_norm,
|
||||
"loss": loss.item(),
|
||||
"moe_loss": moe_loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
|
@ -606,7 +609,7 @@ def main(args):
|
|||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
|
@ -644,6 +647,7 @@ def main(args):
|
|||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
moe_loss=moe_loss,
|
||||
grad_norm=grad_norm,
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
|
|
Loading…
Reference in New Issue