
[tensor] add cross_entropy_loss (#868)

Jiarui Fang, 3 years ago, committed by GitHub
commit 1190b2c4a4
Changed files (lines changed):
  1. colossalai/tensor/_ops/__init__.py (3)
  2. colossalai/tensor/_ops/layernorm.py (1)
  3. colossalai/tensor/_ops/linear.py (6)
  4. colossalai/tensor/_ops/loss.py (29)
  5. tests/components_to_test/simple_net.py (4)
  6. tests/test_tensor/test_net_tp.py (12)

colossalai/tensor/_ops/__init__.py (3 lines changed)

@@ -1,4 +1,5 @@
 from .init import colo_uniform
 from .linear import colo_linear
 from .element_wise import colo_mean
 from .layernorm import colo_layernorm
+from .loss import colo_cross_entropy
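Worth noting: handlers decorated with colo_op_impl are only registered when their module is imported, so this added import line is what actually turns on the cross-entropy interception. Below is a minimal sketch of how a __torch_function__-based registry of this shape can work; every name in it (_OP_REGISTRY, my_op_impl, MyTensor) is an illustrative assumption, not the actual colossalai internals.

    # Sketch of a __torch_function__ dispatch registry; names are illustrative.
    import torch

    _OP_REGISTRY = {}  # maps an intercepted torch function to a custom handler

    def my_op_impl(torch_func):
        # Decorator: register fn as the handler for torch_func.
        def decorator(fn):
            _OP_REGISTRY[torch_func] = fn
            return fn
        return decorator

    class MyTensor:
        # Toy wrapper tensor standing in for ColoTensor.
        def __init__(self, t):
            self._t = t

        def torch_tensor(self):
            return self._t

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func in _OP_REGISTRY:
                # Route to the registered handler, as colo_op_impl does.
                return _OP_REGISTRY[func](types, args, kwargs)
            # Fallback: unwrap the wrappers and call the plain torch function.
            args = tuple(a._t if isinstance(a, MyTensor) else a for a in args)
            return func(*args, **kwargs)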

colossalai/tensor/_ops/layernorm.py (1 line changed)

@@ -1,4 +1,3 @@
-from numpy import isin, kaiser
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor

colossalai/tensor/_ops/linear.py (6 lines changed)

@@ -31,7 +31,11 @@ def colo_linear(types, args, kwargs, pg):
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
         if weight.shard_spec == None:
-            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+            if isinstance(input_tensor, ColoTensor):
+                input_tensor = input_tensor.torch_tensor()
+            if isinstance(weight, ColoTensor):
+                weight = weight.torch_tensor()
+            return torch.nn.functional.linear(input_tensor, weight, bias)
         elif weight.shard_spec == '1Drow':
             # Input:S[1] x Weight:S[0] = Output:P
             # All-Reduce(Output) + bias = res
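The '1Drow' comments above describe the row-parallel case: each rank holds a slice of the input along its feature dimension (S[1]) and the matching slice of the weight (S[0]), so every rank computes a partial output (P) that must be summed with an all-reduce before the bias is added. A hedged sketch of that math, not the actual colossalai kernel:

    import torch.nn.functional as F
    import torch.distributed as dist

    def linear_1drow(input_shard, weight_shard, bias=None):
        # input_shard:  (batch, in_features // world_size), the S[1] slice
        # weight_shard: (out_features, in_features // world_size), the
        #               matching slice of the (out, in) weight F.linear expects
        partial = F.linear(input_shard, weight_shard)  # partial sums: Output:P
        dist.all_reduce(partial)                       # sum partials across ranks
        if bias is not None:
            partial = partial + bias                   # add bias once, after reduction
        return partial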

colossalai/tensor/_ops/loss.py (29 lines changed, new file)

@@ -0,0 +1,29 @@
+import torch
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+@colo_op_impl(torch.nn.functional.cross_entropy)
+def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
+    # Unpack positional arguments: cross_entropy(input, target, weight, ...)
+    input_tensor = target = weight = None
+    arg_num = len(args)
+    if arg_num > 0:
+        input_tensor = args[0]
+    if arg_num > 1:
+        target = args[1]
+    if arg_num > 2:
+        weight = args[2]
+    # Keyword arguments override the positional ones.
+    kwargs = kwargs or {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    if 'target' in kwargs:
+        target = kwargs['target']
+    if 'weight' in kwargs:
+        weight = kwargs['weight']
+    # Unwrap ColoTensor arguments before calling the plain torch op.
+    if isinstance(input_tensor, ColoTensor):
+        input_tensor = input_tensor.torch_tensor()
+    if isinstance(target, ColoTensor):
+        target = target.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(
+        torch.nn.functional.cross_entropy(input_tensor, target, weight))
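For context, a sketch of how the new op is meant to be reached from user code: calling the plain torch API on a ColoTensor should route through colo_cross_entropy and hand back a ColoTensor. This is inferred from the diff above, so treat it as a sketch rather than a tested snippet.

    import torch
    from colossalai.tensor import ColoTensor

    logits = ColoTensor.init_from_torch_tensor(torch.randn(4, 10))
    target = torch.randint(0, 10, (4,))

    # Dispatched to colo_cross_entropy because logits is a ColoTensor.
    loss = torch.nn.functional.cross_entropy(logits, target)
    print(loss.torch_tensor())  # the handler wraps the result back into a ColoTensor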

tests/components_to_test/simple_net.py (4 lines changed)

@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
     def __init__(self, checkpoint=False) -> None:
         super().__init__(checkpoint=checkpoint)
         self.proj1 = nn.Linear(4, 8)
+        self.ln1 = nn.LayerNorm(8)
         self.proj2 = nn.Linear(8, 4)
+        self.ln2 = nn.LayerNorm(4)

     def forward(self, x):
         x = self.proj1(x)
+        x = self.ln1(x)
         x = self.proj2(x)
+        x = self.ln2(x)
         return x
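A quick standalone shape check of the layer stack added above, recreated outside the CheckpointModule wrapper (the batch size here is an arbitrary choice):

    import torch
    import torch.nn as nn

    proj1, ln1 = nn.Linear(4, 8), nn.LayerNorm(8)
    proj2, ln2 = nn.Linear(8, 4), nn.LayerNorm(4)

    x = torch.randn(2, 4)   # arbitrary batch of 2
    x = ln1(proj1(x))       # (2, 4) -> (2, 8)
    x = ln2(proj2(x))       # (2, 8) -> (2, 4)
    print(x.shape)          # torch.Size([2, 4])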

tests/test_tensor/test_net_tp.py (12 lines changed)

@@ -1,5 +1,6 @@
+from cProfile import label
 from statistics import mode
 from colossalai.tensor.colo_tensor import ColoTensor
 from tests.components_to_test.registry import non_distributed_component_funcs
 import colossalai
@@ -20,21 +21,23 @@ def run_simple_net():
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)

     # we set the Specs for weight of each linear.
-    model.proj1.weight.set_spec('1Drow')
-    model.proj2.weight.set_spec('1Drow')
+    # model.proj1.weight.set_spec('1Drow')
+    # model.proj2.weight.set_spec('1Drow')

     for i, (data, label) in enumerate(train_dataloader):
         output = model(data)
+        print(output)
         if criterion:
             loss = criterion(output, label)
         else:
             loss = output
+        print(loss.torch_tensor())
         loss.backward()

         if i > 5:
@@ -49,6 +52,7 @@ def run_dist(rank, world_size, port):
     run_simple_net()


+@pytest.mark.skip
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
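For reference, tests shaped like this are typically driven by spawning one process per rank and calling run_dist in each. A minimal sketch, assuming the run_dist(rank, world_size, port) signature shown above and an arbitrary free port:

    from functools import partial
    import torch.multiprocessing as mp

    def test_simple_net(world_size=4, port=29500):
        # mp.spawn passes the rank as the first positional argument.
        run_func = partial(run_dist, world_size=world_size, port=port)
        mp.spawn(run_func, nprocs=world_size)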
