mirror of https://github.com/hpcaitech/ColossalAI
Jiarui Fang
3 years ago
committed by
GitHub
6 changed files with 48 additions and 7 deletions
@ -1,4 +1,5 @@
|
||||
from .init import colo_uniform |
||||
from .linear import colo_linear |
||||
from .element_wise import colo_mean |
||||
from .layernorm import colo_layernorm |
||||
from .layernorm import colo_layernorm |
||||
from .loss import colo_cross_entropy |
@ -0,0 +1,29 @@
|
||||
import torch |
||||
from colossalai.tensor.op_wrapper import colo_op_impl |
||||
from colossalai.tensor import ColoTensor |
||||
|
||||
|
||||
@colo_op_impl(torch.nn.functional.cross_entropy) |
||||
def colo_cross_entropy(types, args=(), kwargs=None, pg=None): |
||||
arg_num = len(args) |
||||
|
||||
if arg_num > 0: |
||||
input_tensor = args[0] |
||||
if arg_num > 1: |
||||
target = args[1] |
||||
if arg_num > 2: |
||||
weight = args[3] |
||||
|
||||
if 'input' in kwargs: |
||||
input_tensor = kwargs['input'] |
||||
if 'target' in kwargs: |
||||
target = kwargs['target'] |
||||
if 'weight' in kwargs: |
||||
weight = kwargs['weight'] |
||||
|
||||
if isinstance(input_tensor, ColoTensor): |
||||
input_tensor = input_tensor.torch_tensor() |
||||
if isinstance(target, ColoTensor): |
||||
target = target.torch_tensor() |
||||
|
||||
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(input_tensor, target, weight)) |
Loading…
Reference in new issue