From 1190b2c4a49e646dc7bfaa54c3544cc19fe56005 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Mon, 25 Apr 2022 16:01:52 +0800 Subject: [PATCH] [tensor] add cross_entrophy_loss (#868) --- colossalai/tensor/_ops/__init__.py | 3 ++- colossalai/tensor/_ops/layernorm.py | 1 - colossalai/tensor/_ops/linear.py | 6 +++++- colossalai/tensor/_ops/loss.py | 29 ++++++++++++++++++++++++++ tests/components_to_test/simple_net.py | 4 ++++ tests/test_tensor/test_net_tp.py | 12 +++++++---- 6 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 colossalai/tensor/_ops/loss.py diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py index d1b945dd2..034a2f695 100644 --- a/colossalai/tensor/_ops/__init__.py +++ b/colossalai/tensor/_ops/__init__.py @@ -1,4 +1,5 @@ from .init import colo_uniform from .linear import colo_linear from .element_wise import colo_mean -from .layernorm import colo_layernorm \ No newline at end of file +from .layernorm import colo_layernorm +from .loss import colo_cross_entropy \ No newline at end of file diff --git a/colossalai/tensor/_ops/layernorm.py b/colossalai/tensor/_ops/layernorm.py index d616fd104..6658c05b1 100644 --- a/colossalai/tensor/_ops/layernorm.py +++ b/colossalai/tensor/_ops/layernorm.py @@ -1,4 +1,3 @@ -from numpy import isin, kaiser import torch from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor import ColoTensor diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index 824ce702c..c6bb78dd4 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -31,7 +31,11 @@ def colo_linear(types, args, kwargs, pg): # Add communication logic before and after linear call. 
if isinstance(weight, ColoTensor): if weight.shard_spec == None: - return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias) + if isinstance(input_tensor, ColoTensor): + input_tensor = input_tensor.torch_tensor() + if isinstance(weight, ColoTensor): + weight = weight.torch_tensor() + return torch.nn.functional.linear(input_tensor, weight, bias) elif weight.shard_spec == '1Drow': # Input:S[1] x Weight:S[0] = Output:P # All-Reduce(Output) + bias = res
diff --git a/colossalai/tensor/_ops/loss.py b/colossalai/tensor/_ops/loss.py
new file mode 100644
index 000000000..7bc75dfe6
--- /dev/null
+++ b/colossalai/tensor/_ops/loss.py
@@ -0,0 +1,29 @@
+import torch
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+def _unwrap(value):
+    """Return the torch tensor held by a ColoTensor; pass any other value through."""
+    return value.torch_tensor() if isinstance(value, ColoTensor) else value
+
+
+@colo_op_impl(torch.nn.functional.cross_entropy)
+def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
+    """Handle ``__torch_function__`` dispatch for ``torch.nn.functional.cross_entropy``.
+
+    ColoTensor arguments (positional or keyword) are unwrapped and the call is
+    forwarded as-is, so ``weight``, ``ignore_index``, ``reduction`` etc. are kept.
+
+    Args:
+        types: Unused here.
+        args: Positional arguments of the intercepted call.
+        kwargs: Keyword arguments of the intercepted call, or ``None`` for none.
+        pg: Unused here.
+
+    Returns:
+        The loss wrapped by ``ColoTensor.init_from_torch_tensor``.
+    """
+    torch_args = [_unwrap(arg) for arg in args]
+    torch_kwargs = {key: _unwrap(val) for key, val in (kwargs or {}).items()}
+    return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(*torch_args, **torch_kwargs))
diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index 1d77d5b71..58c9835d8 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule): def __init__(self, checkpoint=False) -> None: super().__init__(checkpoint=checkpoint) self.proj1 = nn.Linear(4, 8) + self.ln1 = nn.LayerNorm(8) self.proj2 = nn.Linear(8, 4) + self.ln2 = nn.LayerNorm(4) def forward(self, x): x = self.proj1(x) + x = 
self.ln1(x) x = self.proj2(x) + x = self.ln2(x) return x diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_net_tp.py index e63e786a2..f21d1b459 100644 --- a/tests/test_tensor/test_net_tp.py +++ b/tests/test_tensor/test_net_tp.py @@ -1,5 +1,6 @@ from cProfile import label from statistics import mode +from colossalai.tensor.colo_tensor import ColoTensor from tests.components_to_test.registry import non_distributed_component_funcs import colossalai @@ -20,21 +21,23 @@ def run_simple_net(): # A simple net with two stacked nn.Linear get_components_func = non_distributed_component_funcs.get_callable('simple_net') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - with ColoInitContext(): + + with ColoInitContext(device=get_current_device()): model = model_builder(checkpoint=True) # we set the Specs for weight of each linear. - model.proj1.weight.set_spec('1Drow') - model.proj2.weight.set_spec('1Drow') + # model.proj1.weight.set_spec('1Drow') + # model.proj2.weight.set_spec('1Drow') for i, (data, label) in enumerate(train_dataloader): output = model(data) - print(output) + if criterion: loss = criterion(output, label) else: loss = output + print(loss.torch_tensor()) loss.backward() if i > 5: @@ -49,6 +52,7 @@ def run_dist(rank, world_size, port): run_simple_net() +@pytest.mark.skip @pytest.mark.dist @parameterize('world_size', [1, 4]) @rerun_if_address_is_in_use()