mirror of https://github.com/hpcaitech/ColossalAI
[tensor] add cross_entrophy_loss (#868)
parent 3107817172
commit 1190b2c4a4

@@ -1,4 +1,5 @@
from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy

@@ -1,4 +1,3 @@
from numpy import isin, kaiser
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor

@@ -31,7 +31,11 @@ def colo_linear(types, args, kwargs, pg):
    # Add communication logic before and after linear call.
    if isinstance(weight, ColoTensor):
        if weight.shard_spec == None:
            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
            if isinstance(input_tensor, ColoTensor):
                input_tensor = input_tensor.torch_tensor()
            if isinstance(weight, ColoTensor):
                weight = weight.torch_tensor()
            return torch.nn.functional.linear(input_tensor, weight, bias)
        elif weight.shard_spec == '1Drow':
            # Input:S[1] x Weight:S[0] = Output:P
            # All-Reduce(Output) + bias = res
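
The two comments above describe the row-sharded case that this hunk only stubs out. As a hedged illustration of that '1Drow' path (my own sketch, not the commit's code; the helper name and the use of plain torch.distributed are assumptions): each rank holds a shard of the input along its last dimension and the matching shard of the weight, the local matmul yields a partial result, and an all-reduce completes it before the bias is added.

import torch
import torch.distributed as dist

def linear_1drow_sketch(local_input, local_weight, bias=None, group=None):
    # Input: S[1] -- local_input holds a shard of the last (feature) dimension.
    # Weight: S[0] -- local_weight holds the matching shard of in_features,
    # shaped (out_features, in_features // world_size) for F.linear.
    partial = torch.nn.functional.linear(local_input, local_weight)   # Output: P (partial sum)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)       # All-Reduce(Output)
    if bias is not None:
        partial = partial + bias                                      # + bias = res
    return partial
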
@@ -0,0 +1,29 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor


@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
    arg_num = len(args)
    weight = None

    if arg_num > 0:
        input_tensor = args[0]
    if arg_num > 1:
        target = args[1]
    if arg_num > 2:
        weight = args[2]

    if kwargs and 'input' in kwargs:
        input_tensor = kwargs['input']
    if kwargs and 'target' in kwargs:
        target = kwargs['target']
    if kwargs and 'weight' in kwargs:
        weight = kwargs['weight']

    if isinstance(input_tensor, ColoTensor):
        input_tensor = input_tensor.torch_tensor()
    if isinstance(target, ColoTensor):
        target = target.torch_tensor()

    return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(input_tensor, target, weight))
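
For orientation, a usage sketch of the op registered above (assumed dispatch behavior, not part of the commit): once @colo_op_impl hooks torch.nn.functional.cross_entropy, calling it with ColoTensor operands should be routed to colo_cross_entropy and return a ColoTensor wrapping the scalar loss.

import torch
from colossalai.tensor import ColoTensor

logits = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))
target = ColoTensor.init_from_torch_tensor(torch.randint(0, 8, (4,)))
loss = torch.nn.functional.cross_entropy(logits, target)  # expected to hit colo_cross_entropy
print(loss.torch_tensor())                                 # the underlying scalar torch.Tensor
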
@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.proj1 = nn.Linear(4, 8)
        self.ln1 = nn.LayerNorm(8)
        self.proj2 = nn.Linear(8, 4)
        self.ln2 = nn.LayerNorm(4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.ln1(x)
        x = self.proj2(x)
        x = self.ln2(x)
        return x
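
As a quick sanity check (a sketch, not part of the commit), the updated SimpleNet now interleaves the two projections with layer norms, mapping a 4-feature input to 8 features and back to 4:

import torch

model = SimpleNet(checkpoint=False)
out = model(torch.randn(2, 4))
print(out.shape)  # expected: torch.Size([2, 4])
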
@@ -1,5 +1,6 @@
from cProfile import label
from statistics import mode
from colossalai.tensor.colo_tensor import ColoTensor
from tests.components_to_test.registry import non_distributed_component_funcs

import colossalai

@@ -20,21 +21,23 @@ def run_simple_net():
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    with ColoInitContext():

    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    # we set the Specs for weight of each linear.
    model.proj1.weight.set_spec('1Drow')
    model.proj2.weight.set_spec('1Drow')
    # model.proj1.weight.set_spec('1Drow')
    # model.proj2.weight.set_spec('1Drow')

    for i, (data, label) in enumerate(train_dataloader):
        output = model(data)
        print(output)

        if criterion:
            loss = criterion(output, label)
        else:
            loss = output

        print(loss.torch_tensor())
        loss.backward()

        if i > 5:
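
A hedged note on what this test leans on (assumed ColoInitContext semantics, inferred from its use here rather than stated by the commit): parameters built inside the context come out as ColoTensor instances on the chosen device, which is what would let the commented-out set_spec('1Drow') calls, once re-enabled, send proj1 and proj2 through the '1Drow' branch of colo_linear.

with ColoInitContext(device=get_current_device()):
    model = model_builder(checkpoint=True)

# Assumption: the weights are now ColoTensor, so a shard spec can be attached.
assert isinstance(model.proj1.weight, ColoTensor)
model.proj1.weight.set_spec('1Drow')  # would route proj1 through colo_linear's '1Drow' path
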
@@ -49,6 +52,7 @@ def run_dist(rank, world_size, port):
    run_simple_net()


@pytest.mark.skip
@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()