mirror of https://github.com/InternLM/InternLM

add logger for moe_loss

parent 8cdd1abb35
commit 2983076d89
@@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_loss:
             loss = None
 
-        return output, loss
+        return output, loss, moe_loss
 
     def forward_backward_step(
         self,
@@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler):
         data, label = batch_data
 
         loss = 0 if return_loss else None
+        moe_loss = 0 if return_loss else None
         outputs = []
         labels = []
 
@@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler):
 
             _data, _label = self._load_accum_batch(data, label)
 
-            _output, _loss = self._train_one_batch(
+            _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
             )
 
            if return_loss:
                 loss += _loss
+                moe_loss += _moe_loss
 
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)
@@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None
 
-        return outputs, labels, loss
+        return outputs, labels, loss, moe_loss
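The no-pipeline scheduler changes above thread a second accumulator, moe_loss, through the gradient-accumulation loop and widen both return tuples. Below is a minimal, self-contained sketch of that pattern, not the InternLM code itself; run_batch and its toy losses are invented stand-ins for NonPipelineScheduler._train_one_batch.

# Sketch: carry a main loss and an MoE auxiliary loss through gradient accumulation.
import torch
import torch.nn.functional as F

def run_batch(data, label, moe_loss_coeff=1.0):
    # Hypothetical model step returning (output, main_loss, moe_loss).
    output = data * 2.0
    main_loss = F.mse_loss(output, label)
    moe_loss = 0.01 * output.abs().mean() * moe_loss_coeff  # stand-in auxiliary loss
    return output, main_loss, moe_loss

def forward_backward_step(batches, return_loss=True, moe_loss_coeff=1.0):
    loss = 0 if return_loss else None
    moe_loss = 0 if return_loss else None
    outputs, labels = [], []
    for data, label in batches:
        _output, _loss, _moe_loss = run_batch(data, label, moe_loss_coeff)
        if return_loss:
            loss += _loss
            moe_loss += _moe_loss
        outputs.append(_output)
        labels.append(label)
    # Mirrors the widened 4-tuple return added in the diff above.
    return outputs, labels, loss, moe_loss

if __name__ == "__main__":
    batches = [(torch.randn(4), torch.randn(4)) for _ in range(2)]
    _, _, loss, moe_loss = forward_backward_step(batches)
    print(loss.item(), moe_loss.item())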
@@ -7,6 +7,7 @@ from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch.cuda
+import torch.distributed as dist
 
 import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
@@ -239,7 +240,8 @@ class PipelineScheduler(BaseScheduler):
         """
         return step_id
 
-    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff:float=1.0):
+    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True,
+                      accum_loss=None, accum_moe_loss=None, moe_loss_coeff:float=1.0):
         """
         Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
@@ -251,6 +253,7 @@ class PipelineScheduler(BaseScheduler):
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
             accum_loss (optional): Where accumulated loss stores.
+            accum_moe_loss (optional): Where accumulated moe loss stores.
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
                 pipeline stage.
@@ -277,6 +280,7 @@ class PipelineScheduler(BaseScheduler):
 
         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
+        accum_moe_loss.add_(moe_loss.detach())
 
         return output_obj, moe_loss
 
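For reference, here is a standalone sketch of what the new accum_moe_loss bookkeeping does per microbatch: scale the model's per-layer MoE losses by the coefficient, average over microbatches, and accumulate a detached copy for reporting. The helper name and toy inputs are illustrative only, not the InternLM API.

# Sketch: per-microbatch MoE-loss scaling and detached accumulation.
import torch

def accumulate_moe_loss(moe_losses, accum_moe_loss, num_microbatches, moe_loss_coeff=1.0):
    moe_loss = sum(moe_losses) * moe_loss_coeff
    moe_loss = moe_loss / num_microbatches      # average over microbatches
    accum_moe_loss.add_(moe_loss.detach())      # reporting-only accumulation, no graph kept
    return moe_loss                             # still differentiable for backward

accum_moe_loss = torch.zeros(1)
per_layer_losses = [torch.tensor(0.2, requires_grad=True), torch.tensor(0.1, requires_grad=True)]
mb_moe_loss = accumulate_moe_loss(per_layer_losses, accum_moe_loss, num_microbatches=4)
print(mb_moe_loss.item(), accum_moe_loss.item())

Detaching before add_ keeps the accumulator out of the autograd graph, so the logging path never changes gradients, while the returned moe_loss can still be backpropagated.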
@@ -366,6 +370,7 @@ class PipelineScheduler(BaseScheduler):
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -392,6 +397,7 @@ class PipelineScheduler(BaseScheduler):
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
                 moe_loss_coeff=moe_loss_coeff,
             )
 
@@ -403,8 +409,12 @@ class PipelineScheduler(BaseScheduler):
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
+
+        return output, label, accum_loss, accum_moe_loss
 
     def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
         """
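The forward-only path now all-reduces the locally accumulated MoE loss across the pipeline group and folds it into the reported loss. Below is a runnable single-process sketch of that reduction; the "gloo" backend, the port, and the placeholder values are assumptions made only so the snippet can run on a single CPU process, and are not taken from the InternLM setup.

# Sketch: reduce the accumulated MoE loss across pipeline ranks, then fold it into the loss.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # single-process stand-in for the pipeline group

accum_loss = torch.zeros(1)
accum_moe_loss = torch.full((1,), 0.25)   # this rank's accumulated MoE loss (placeholder value)

dist.all_reduce(accum_moe_loss)           # defaults to SUM across the group
if accum_loss is not None:
    accum_loss += accum_moe_loss          # fold the MoE loss into the reported loss

print(accum_loss.item(), accum_moe_loss.item())
dist.destroy_process_group()

Summing across the pipeline group means every stage reports the full MoE loss rather than only its own share.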
@@ -459,6 +469,7 @@ class PipelineScheduler(BaseScheduler):
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -486,6 +497,7 @@ class PipelineScheduler(BaseScheduler):
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
                 moe_loss_coeff=moe_loss_coeff,
             )
 
@@ -532,6 +544,7 @@ class PipelineScheduler(BaseScheduler):
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
                 moe_loss_coeff=moe_loss_coeff,
             )
 
@@ -598,11 +611,18 @@ class PipelineScheduler(BaseScheduler):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
+        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
+
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
+
+        return output, label, accum_loss, accum_moe_loss
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
+                              return_output_label=True, moe_loss_coeff:float=1.0):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
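The logger.info call added above is the per-rank debug log the commit title refers to: each pipeline rank reports its locally accumulated MoE loss before the all-reduce. A minimal sketch of that logging pattern follows; the plain integer rank and the logger name are illustrative stand-ins for gpc.get_local_rank(ParallelMode.PIPELINE) and InternLM's logger.

# Sketch: per-rank logging of the locally accumulated MoE loss.
import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pipeline_scheduler")

pipeline_rank = 0                          # assumed rank, for illustration only
accum_moe_loss = torch.full((1,), 0.125)   # this rank's locally accumulated MoE loss

logger.info(f"{pipeline_rank}, moe_loss: {accum_moe_loss.item()}")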
@@ -614,7 +634,7 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
         """
 
         assert (
@@ -694,6 +714,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
 
         self._accum_loss = None
+        self._accum_moe_loss = None
         self._return_tensors = None
         self._input_objs = [[] for _ in range(num_chunks)]
         self._output_objs = [[] for _ in range(num_chunks)]
@@ -706,6 +727,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 
     def _clear_state(self) -> None:
         self._accum_loss = None
+        self._accum_moe_loss = None
         self._return_tensors = None
         self._input_objs = [[] for _ in range(self._num_chunks)]
         self._output_objs = [[] for _ in range(self._num_chunks)]
@@ -777,6 +799,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
 
+        if self._accum_moe_loss is not None:
+            self._accum_moe_loss.add_(moe_loss.detach())
+
         self._output_objs[chunk_id].append(output_obj)
         self._moe_losses[chunk_id].append(moe_loss)
 
@@ -1287,6 +1312,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             self._accum_loss = torch.zeros(1, device=get_current_device())
+        if return_loss:
+            self._accum_moe_loss = torch.zeros(1, device=get_current_device())
         if return_output_label:
             self._return_tensors = []
 
@@ -1301,6 +1328,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output, label = (None, None)
         accum_loss = self._accum_loss
 
+        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
+
+        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        accum_moe_loss = self._accum_moe_loss
+
         self._clear_state()
 
-        return output, label, accum_loss
+        return output, label, accum_loss, accum_moe_loss
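Taken together, the interleaved-scheduler changes give self._accum_moe_loss the same lifecycle as self._accum_loss: allocated when a run starts, updated per microbatch with a detached copy, read out and reduced at the end, then cleared. A small illustrative sketch of that lifecycle follows; the class and method names here are invented for the example and are not InternLM's.

# Sketch: allocate-per-run / add-per-microbatch / read-then-clear accumulator lifecycle.
import torch

class MoELossTracker:
    def __init__(self):
        self._accum_moe_loss = None

    def start_run(self, device="cpu"):
        self._accum_moe_loss = torch.zeros(1, device=device)

    def on_microbatch(self, moe_loss):
        if self._accum_moe_loss is not None:
            self._accum_moe_loss.add_(moe_loss.detach())

    def finish_run(self):
        accum = self._accum_moe_loss
        self._accum_moe_loss = None   # analogous to _clear_state()
        return accum

tracker = MoELossTracker()
tracker.start_run()
for step_loss in (torch.tensor(0.1), torch.tensor(0.2)):
    tracker.on_microbatch(step_loss)
print(tracker.finish_run().item())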
@@ -155,5 +155,5 @@ class Trainer:
         Returns:
             Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
         """
-        output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
-        return output, label, loss
+        output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
+        return output, label, loss, moe_loss
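With Trainer.execute_schedule now forwarding the scheduler's 4-tuple, every caller has to unpack four values. A trivial sketch of the two call patterns used in the evaluation and training code below; the stub execute_schedule and its placeholder numbers are hypothetical, not the real Trainer method.

# Sketch: callers of the widened 4-tuple return.
def execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=True):
    output, label, loss, moe_loss = None, None, 0.5, 0.01  # placeholder values
    return output, label, loss, moe_loss

# Training caller keeps moe_loss for logging.
_, _, loss, moe_loss = execute_schedule(batch=None, forward_only=False)

# Evaluation callers that do not need it simply discard it with "_".
_, _, val_loss, _ = execute_schedule(batch=None, forward_only=True, return_output_label=False)

print(loss, moe_loss, val_loss)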
@@ -100,7 +100,7 @@ def evaluate_on_val_dls(
                 tensor_shape=tensor_shape,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                _, _, loss = trainer.execute_schedule(
+                _, _, loss, _ = trainer.execute_schedule(
                     batch, forward_only=True, return_loss=True, return_output_label=False
                 )
         else:
@@ -114,7 +114,7 @@ def evaluate_on_val_dls(
                 grad_accum_batch_size=grad_accum_batch_size,
                 metric_hook_list=[val_sche_metric_hook],
             ):
-                _, _, loss = trainer.execute_schedule(
+                _, _, loss, _ = trainer.execute_schedule(
                     batch, forward_only=True, return_loss=True, return_output_label=False
                 )
         if verbose:
train.py (6 changed lines)
@@ -341,6 +341,7 @@ def record_current_batch_training_metrics(
     trainer,
     start_time,
     loss,
+    moe_loss,
     grad_norm,
     metric,
     update_panel,
@@ -384,6 +385,7 @@ def record_current_batch_training_metrics(
         "tflops": tflops,
         "step": batch_count,
         "loss": loss.item(),
+        "moe_loss": moe_loss.item(),
         "tgs (tokens/gpu/second)": tk_per_gpu,
         "lr": lr,
         "loss_scale": scaler,
@@ -419,6 +421,7 @@ def record_current_batch_training_metrics(
                 "num_consumed_tokens": train_state.num_consumed_tokens,
                 "grad_norm": grad_norm,
                 "loss": loss.item(),
+                "moe_loss": moe_loss.item(),
                 "flops": tflops,
                 "tgs": tk_per_gpu,
                 "acc": acc_perplex["acc"],
@@ -606,7 +609,7 @@ def main(args):
 
         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(
+        _, _, loss, moe_loss = trainer.execute_schedule(
            batch,
            forward_only=False,
            return_loss=True,
@@ -644,6 +647,7 @@ def main(args):
            trainer=trainer,
            start_time=start_time,
            loss=loss,
+           moe_loss=moe_loss,
            grad_norm=grad_norm,
            metric=metric,
            update_panel=uniscale_logger is not None,
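Finally, train.py threads moe_loss into the per-step metrics record so it is reported next to the main loss. A minimal sketch of the resulting record; the logger setup and the other field values are placeholders rather than the real training metrics.

# Sketch: logging a "moe_loss" field alongside "loss" in the per-step metrics record.
import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

loss = torch.tensor(2.3)       # placeholder main loss
moe_loss = torch.tensor(0.04)  # placeholder MoE auxiliary loss

infos = {
    "step": 100,
    "loss": loss.item(),
    "moe_loss": moe_loss.item(),   # newly recorded field
    "lr": 3e-4,
}
logger.info(", ".join(f"{k}={v}" for k, v in infos.items()))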