import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
                       target: GeneralTensor,
                       weight: Optional[GeneralTensor] = None,
                       size_average: Optional[bool] = None,
                       ignore_index: int = -100,
                       reduce: Optional[bool] = None,
                       reduction: str = "mean",
                       label_smoothing: float = 0.0):
    """ColoTensor-aware implementation registered for ``F.cross_entropy``.

    At least one of ``input_tensor``, ``target`` and ``weight`` must be a
    ``ColoTensor``. The process group is taken from the first of them (in
    that order) that is a ``ColoTensor``, and all three are then passed
    through ``convert_to_colo_tensor`` with that process group.

    Dispatch, based on ``input_tensor`` after conversion:

    * replicated input: delegates to ``F.cross_entropy`` and forwards
      ``weight``, ``size_average``, ``ignore_index``, ``reduce``,
      ``reduction`` and ``label_smoothing``. ``target`` (and ``weight`` when
      given) must be replicated as well.
    * input with a compute spec that is sharded 1D-column: delegates to
      ``VocabParallelCrossEntropyLoss1D``. ``weight`` must be ``None`` and
      ``target`` must be replicated. ``size_average``, ``ignore_index``,
      ``reduce``, ``reduction`` and ``label_smoothing`` are NOT forwarded on
      this path.

    Returns:
        The loss wrapped via ``ColoTensor.from_torch_tensor`` with a
        ``ColoTensorSpec`` built from the selected process group.

    Raises:
        AssertionError: if none of the tensor arguments is a ``ColoTensor``,
            or if the replication requirements of the chosen path are not met.
        NotImplementedError: for any other layout of ``input_tensor``.
    """
    assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor), \
        "At least one of input_tensor, target and weight must be a ColoTensor"

    # Pick the process group from the first argument that actually is a
    # ColoTensor. Previously a non-ColoTensor input made ``pg`` a bool
    # (the result of an isinstance() call) instead of a process group.
    pg = None
    for candidate in (input_tensor, target, weight):
        if isinstance(candidate, ColoTensor):
            pg = candidate.get_process_group()
            break

    # NOTE(review): weight may be None here; presumably convert_to_colo_tensor
    # passes None through (the branches below test ``weight is None``) - confirm.
    weight = convert_to_colo_tensor(weight, pg)
    target = convert_to_colo_tensor(target, pg)
    input_tensor = convert_to_colo_tensor(input_tensor, pg)

    if input_tensor.is_replicate():    # Input is gathered
        assert target.is_replicate() and (weight is None or weight.is_replicate()), \
            "Target tensor and weight tensor both should be complete"
        output = F.cross_entropy(input_tensor,
                                 target,
                                 weight=weight,
                                 size_average=size_average,
                                 ignore_index=ignore_index,
                                 reduce=reduce,
                                 reduction=reduction,
                                 label_smoothing=label_smoothing)
        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
    elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
        if input_tensor.is_shard_1dcol():
            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
                                                       target,
                                                       process_group=input_tensor.process_group.tp_process_group())
            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError