import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger

from .cross_entropy import vocab_cross_entropy


class BertLoss(nn.Module):
    """Combines the masked language modeling (MLM) loss with the sentence-order
    prediction (SOP) loss, reducing the MLM loss across the sequence-parallel group."""

    def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
        # Average the per-token LM loss over the non-masked positions.
        lm_loss_ = lm_loss.float()
        loss_mask = loss_mask.float()
        loss_mask_sum = loss_mask.sum()
        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1))
        lm_loss /= loss_mask_sum

        # Sum the partial LM losses held by each rank of the sequence-parallel group.
        dist.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE))

        if sop_logits is not None:
            # Binary sentence-order prediction loss; targets labelled -1 are ignored.
            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                       sentence_order.view(-1),
                                       ignore_index=-1)
            sop_loss = sop_loss.float()
            # Scale the SOP loss so it is on the same footing as the
            # all-reduced (summed) LM loss.
            loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
        else:
            sop_loss = None
            loss = lm_loss

        return loss
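
# Usage sketch (illustrative only, not part of this module): the names and
# shapes below are assumptions about how a sequence-parallel BERT training
# loop might call this criterion, assuming `lm_loss` is a per-token loss of
# shape [batch, seq_len] produced by a vocab-parallel cross entropy, and
# `loss_mask` / `sentence_order` come from the data loader.
#
#     criterion = BertLoss()
#     lm_loss, sop_logits = model(tokens, padding_mask, token_types, lm_labels)
#     loss = criterion(lm_loss, sop_logits, loss_mask, sentence_order)
#     loss.backward()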