add logger for moe_loss

branch pull/182/head
author zhanglei 2023-08-17 16:52:11 +08:00
parent 8cdd1abb35
commit 2983076d89
5 changed files with 53 additions and 14 deletions

View File

@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
if not return_loss:
loss = None
- return output, loss
+ return output, loss, moe_loss
def forward_backward_step(
self,
@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler):
data, label = batch_data
loss = 0 if return_loss else None
+ moe_loss = 0 if return_loss else None
outputs = []
labels = []
@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler):
_data, _label = self._load_accum_batch(data, label)
- _output, _loss = self._train_one_batch(
+ _output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
)
if return_loss:
loss += _loss
+ moe_loss += _moe_loss
if return_output_label:
outputs.append(_output)
labels.append(_label)
@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler):
if not return_output_label:
outputs, labels = None, None
- return outputs, labels, loss
+ return outputs, labels, loss, moe_loss

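The non-pipeline change amounts to carrying a second accumulator through the gradient-accumulation loop. A minimal sketch of that pattern, using a stand-in train_one_batch callable rather than InternLM's _train_one_batch:

def forward_backward_step_sketch(train_one_batch, micro_batches, return_loss=True, moe_loss_coeff=1.0):
    # loss and moe_loss are summed over micro-steps; each micro-step is assumed to
    # return values already divided by the gradient-accumulation size
    loss = 0 if return_loss else None
    moe_loss = 0 if return_loss else None
    outputs, labels = [], []
    for data, label in micro_batches:
        output, _loss, _moe_loss = train_one_batch(data, label, moe_loss_coeff)
        if return_loss:
            loss += _loss
            moe_loss += _moe_loss
        outputs.append(output)
        labels.append(label)
    return outputs, labels, loss, moe_loss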
View File

@ -7,6 +7,7 @@ from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union
import torch.cuda
+ import torch.distributed as dist
import internlm.core.communication as comm
from internlm.core.context import ParallelMode
@ -239,7 +240,8 @@ class PipelineScheduler(BaseScheduler):
"""
return step_id
- def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff:float=1.0):
+ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True,
+ accum_loss=None, accum_moe_loss=None, moe_loss_coeff:float=1.0):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
@ -251,6 +253,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
accum_loss (optional): Where accumulated loss stores.
+ accum_moe_loss (optional): Where accumulated moe loss stores.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
pipeline stage.
@ -277,6 +280,7 @@ class PipelineScheduler(BaseScheduler):
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
+ accum_moe_loss.add_(moe_loss.detach())
return output_obj, moe_loss
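For reference, the accumulation _forward_step now performs can be sketched in isolation (made-up tensors, not the scheduler's objects): per-layer moe losses are summed, scaled by moe_loss_coeff, averaged over microbatches, and only a detached copy goes into the running buffer, so the buffer serves reporting while the returned moe_loss stays differentiable for backward.

import torch

def accumulate_moe_loss(moe_losses, accum_moe_loss, moe_loss_coeff=1.0, num_microbatches=1):
    # sum per-layer auxiliary losses, scale, and average over microbatches
    moe_loss = sum(moe_losses) * moe_loss_coeff / num_microbatches
    # detach before accumulating: the buffer is only read for logging
    accum_moe_loss.add_(moe_loss.detach())
    return moe_loss

accum_moe_loss = torch.zeros(1)
per_layer_losses = [torch.tensor(0.2, requires_grad=True), torch.tensor(0.1, requires_grad=True)]
micro_moe_loss = accumulate_moe_loss(per_layer_losses, accum_moe_loss, moe_loss_coeff=0.01, num_microbatches=4)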
@ -366,6 +370,7 @@ class PipelineScheduler(BaseScheduler):
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
+ accum_moe_loss = torch.zeros(1, device=get_current_device())
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@ -392,6 +397,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
+ accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
@ -403,8 +409,12 @@ class PipelineScheduler(BaseScheduler):
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+ dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
- return output, label, accum_loss
+ if accum_loss is not None:
+ accum_loss += accum_moe_loss
+ return output, label, accum_loss, accum_moe_loss
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
"""
@ -459,6 +469,7 @@ class PipelineScheduler(BaseScheduler):
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
+ accum_moe_loss = torch.zeros(1, device=get_current_device())
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@ -486,6 +497,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
+ accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
@ -532,6 +544,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
+ accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
@ -598,11 +611,18 @@ class PipelineScheduler(BaseScheduler):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
+ logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+ dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
- return output, label, accum_loss
+ if accum_loss is not None:
+ accum_loss += accum_moe_loss
+ return output, label, accum_loss, accum_moe_loss
- def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, moe_loss_coeff:float=1.0):
+ def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
+ return_output_label=True, moe_loss_coeff:float=1.0):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -614,7 +634,7 @@ class PipelineScheduler(BaseScheduler):
return_loss (bool, optional): Whether returns the loss value. Default is true.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
- Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
+ Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
"""
assert (
@ -694,6 +714,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
self._accum_loss = None
+ self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)]
@ -706,6 +727,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
def _clear_state(self) -> None:
self._accum_loss = None
+ self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)]
@ -777,6 +799,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
+ if self._accum_moe_loss is not None:
+ self._accum_moe_loss.add_(moe_loss.detach())
self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)
@ -1287,6 +1312,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
self._accum_loss = torch.zeros(1, device=get_current_device())
+ if return_loss:
+ self._accum_moe_loss = torch.zeros(1, device=get_current_device())
if return_output_label:
self._return_tensors = []
@ -1301,6 +1328,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
output, label = (None, None)
accum_loss = self._accum_loss
+ logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
+ dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+ accum_moe_loss = self._accum_moe_loss
self._clear_state()
- return output, label, accum_loss
+ return output, label, accum_loss, accum_moe_loss

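Both the 1F1B and interleaved paths finish with the same reporting step. A simplified sketch of that step, assuming torch.distributed is the only dependency and using a generic process group rather than gpc's pipeline group:

import torch
import torch.distributed as dist

def reduce_and_fold_moe_loss(accum_loss, accum_moe_loss, pipeline_group=None):
    # make accum_moe_loss identical on every pipeline rank (all_reduce defaults to SUM)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(accum_moe_loss, group=pipeline_group)
    # only ranks that track the task loss (the last pipeline stage) fold the moe term in
    if accum_loss is not None:
        accum_loss += accum_moe_loss
    return accum_loss, accum_moe_loss

# single-process usage: no process group is initialized, so the all_reduce is skipped
loss, moe = reduce_and_fold_moe_loss(torch.zeros(1), torch.tensor([0.05]))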
View File

@ -155,5 +155,5 @@ class Trainer:
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
- output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
- return output, label, loss
+ output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
+ return output, label, loss, moe_loss

View File

@ -100,7 +100,7 @@ def evaluate_on_val_dls(
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
- _, _, loss = trainer.execute_schedule(
+ _, _, loss, _ = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
@ -114,7 +114,7 @@ def evaluate_on_val_dls(
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
- _, _, loss = trainer.execute_schedule(
+ _, _, loss, _ = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:

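Call sites change accordingly: the schedule now hands back four values, and evaluation simply discards the moe term (which, per the accum_loss += accum_moe_loss lines above, is already folded into the reported loss on the pipeline path). A hypothetical stub illustrating the unpacking:

def execute_schedule_stub(batch, forward_only, return_loss=True, return_output_label=True):
    # stand-in that mimics the new return shape: (output, label, loss, moe_loss)
    return None, None, 0.0, 0.0

# evaluation path: keep only the task loss
_, _, loss, _ = execute_schedule_stub(batch=None, forward_only=True, return_output_label=False)
# training path: keep moe_loss as well so it can be logged separately
_, _, loss, moe_loss = execute_schedule_stub(batch=None, forward_only=False)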
View File

@ -341,6 +341,7 @@ def record_current_batch_training_metrics(
trainer,
start_time,
loss,
+ moe_loss,
grad_norm,
metric,
update_panel,
@ -384,6 +385,7 @@
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
+ "moe_loss": moe_loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
@ -419,6 +421,7 @@
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
+ "moe_loss": moe_loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
@ -606,7 +609,7 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
- _, _, loss = trainer.execute_schedule(
+ _, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
@ -644,6 +647,7 @@ def main(args):
trainer=trainer,
start_time=start_time,
loss=loss,
+ moe_loss=moe_loss,
grad_norm=grad_norm,
metric=metric,
update_panel=uniscale_logger is not None,
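The metrics change is the same idea in miniature. A sketch of the logging step, with the standard logging module standing in for the project's logger and monitoring hooks:

import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def record_step(batch_count, loss, moe_loss, grad_norm, lr):
    # both losses arrive as scalar tensors; .item() converts them for the logging backend
    infos = {
        "step": batch_count,
        "loss": loss.item(),
        "moe_loss": moe_loss.item(),
        "grad_norm": grad_norm,
        "lr": lr,
    }
    logger.info(", ".join(f"{k}={v}" for k, v in infos.items()))

record_step(0, torch.tensor(2.31), torch.tensor(0.04), grad_norm=1.0, lr=3e-4)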