ColossalAI/colossalai/tensor/_ops/loss.py

42 lines
1.5 KiB
Python
Raw Normal View History

from colossalai.tensor.dist_spec import DistPlacementPattern
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
    """ColoTensor implementation registered for ``torch.nn.functional.cross_entropy``.

    Args:
        types: Supplied by the op-dispatch mechanism; not used here.
        args: Positional arguments of the original call, i.e.
            ``(input, target, weight, *rest)``. Anything after ``weight`` is
            forwarded unchanged to ``torch.nn.functional.cross_entropy`` on the
            gathered path.
        kwargs: Keyword arguments of the original call, or ``None``.
            ``input``, ``target`` and ``weight`` take precedence over their
            positional counterparts. All remaining keywords (e.g.
            ``reduction``) are forwarded on the gathered path. The caller's
            dict is not modified.
        pg: Not used here.

    Returns:
        ColoTensor: The loss, wrapped via ``ColoTensor.init_from_torch_tensor``.

    Raises:
        TypeError: If ``input`` or ``target`` is not supplied.
        NotImplementedError: If the input's spec is neither gathered nor a
            single 1D column-parallel action.
    """
    # Work on a copy so popping the recognised keys never mutates the
    # caller's kwargs; also tolerate kwargs=None (the declared default).
    kwargs = {} if kwargs is None else dict(kwargs)

    # Keyword arguments override positional ones, as in the torch signature.
    input_tensor = kwargs.pop('input', args[0] if len(args) > 0 else None)
    target = kwargs.pop('target', args[1] if len(args) > 1 else None)
    weight = kwargs.pop('weight', args[2] if len(args) > 2 else None)
    # Remaining positional options (size_average, ignore_index, ...) in the
    # order torch.nn.functional.cross_entropy expects them after `weight`.
    extra_args = tuple(args[3:])

    if input_tensor is None or target is None:
        raise TypeError("cross_entropy() requires both 'input' and 'target'")

    if not isinstance(input_tensor, ColoTensor):
        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
    # Unwrap ColoTensor targets/weights so the plain torch kernels receive
    # ordinary torch.Tensors.
    if isinstance(target, ColoTensor):
        target = target.torch_tensor()
    if isinstance(weight, ColoTensor):
        weight = weight.torch_tensor()

    if input_tensor.spec.is_gathered():  # Input is gathered
        loss = torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight, *extra_args, **kwargs)
        return ColoTensor.init_from_torch_tensor(loss)

    if input_tensor.has_spec() and input_tensor.spec.num_action == 1:  # Single Model Parallel Applied
        if input_tensor.spec.is_1D_col():
            # NOTE(review): `weight`, `extra_args` and `kwargs` are not passed to
            # VocabParallelCrossEntropyLoss1D, so they are silently ignored on
            # this path — confirm whether that is intended.
            loss = VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target)
            return ColoTensor.init_from_torch_tensor(loss)
        raise NotImplementedError("cross_entropy: only 1D column-parallel input is supported "
                                  "for single model-parallel specs")

    raise NotImplementedError("cross_entropy: unsupported input tensor spec")