From 75d221918af047e7323494907da3139a811791cd Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Sat, 7 May 2022 15:49:14 +0800 Subject: [PATCH] [Tensor] add 1d vocab loss (#918) * add 1d vocab loss * polish --- colossalai/tensor/_ops/loss.py | 28 ++++++++++++++++++-------- tests/components_to_test/simple_net.py | 6 ++++-- tests/test_tensor/test_model.py | 16 ++++++++++++--- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/colossalai/tensor/_ops/loss.py b/colossalai/tensor/_ops/loss.py index 7bc75dfe6..89683d3aa 100644 --- a/colossalai/tensor/_ops/loss.py +++ b/colossalai/tensor/_ops/loss.py @@ -1,7 +1,8 @@ +from colossalai.tensor.spec import ShardPattern import torch from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor import ColoTensor - +from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D @colo_op_impl(torch.nn.functional.cross_entropy) def colo_cross_entropy(types, args=(), kwargs=None, pg=None): @@ -12,18 +13,29 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None): if arg_num > 1: target = args[1] if arg_num > 2: - weight = args[3] + weight = args[2] if 'input' in kwargs: - input_tensor = kwargs['input'] + input_tensor = kwargs.pop('input') if 'target' in kwargs: - target = kwargs['target'] + target = kwargs.pop('target') if 'weight' in kwargs: - weight = kwargs['weight'] + weight = kwargs.pop('weight') - if isinstance(input_tensor, ColoTensor): - input_tensor = input_tensor.torch_tensor() + if not isinstance(input_tensor, ColoTensor): + input_tensor = ColoTensor.init_from_torch_tensor(input_tensor) if isinstance(target, ColoTensor): target = target.torch_tensor() - return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(input_tensor, target, weight)) + if input_tensor.is_gathered(): # Input is gathered + # TODO(jzy) Shall we make the result of loss function a ColoTensor? + return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy( + input_tensor.torch_tensor(), target, weight)) + elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1: # Single Model Parallel Applied + if input_tensor.shard_pattern == ShardPattern.Col: + return ColoTensor.init_from_torch_tensor( + VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target)) + else: + raise NotImplementedError + else: + raise NotImplementedError diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index 8c7ba2863..fd4988d9e 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -17,6 +17,7 @@ class SimpleNet(CheckpointModule): self.ln1 = nn.LayerNorm(8) self.proj2 = nn.Linear(8, 4) self.ln2 = nn.LayerNorm(4) + self.classifier = nn.Linear(4, 4) def forward(self, x): x = self.embed(x) @@ -24,6 +25,7 @@ class SimpleNet(CheckpointModule): x = self.ln1(x) x = self.proj2(x) x = self.ln2(x) + x = self.classifier(x) return x @@ -31,8 +33,8 @@ class SimpleNet(CheckpointModule): class DummyDataLoader(DummyDataGenerator): def generate(self): - data = torch.randint(low=0, high=20, size=(16,20), device=get_current_device()) - label = torch.randint(low=0, high=2, size=(16,4), device=get_current_device()) + data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) + label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) return data, label diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 4d2b1a4aa..f8366516e 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -144,10 +144,18 @@ def run_1d_hybrid_tp(model_name): parallel_action_list_col = [ ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, - parallel_mode=ParallelMode.PARALLEL_1D) + parallel_mode=ParallelMode.PARALLEL_1D), ] spec_col = TensorSpec(parallel_action_list_col) + parallel_action_list_classifier_col = [ + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1DCol_Linear, + parallel_mode=ParallelMode.PARALLEL_1D, + gather_out=False), + ] + spec_classifier_col = TensorSpec(parallel_action_list_classifier_col) + parallel_action_list_embedding_col = [ ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, @@ -158,12 +166,14 @@ def run_1d_hybrid_tp(model_name): for name, p in model.colo_named_parameters(): if not isinstance(p, ColoTensor): continue + if 'embed' in name and 'weight' in name: + p.set_spec(spec_embedding_col) if 'proj1' in name and ('weight' in name or 'bias' in name): p.set_spec(spec_col) if 'proj2' in name and 'weight' in name: p.set_spec(spec_row) - if 'embed' in name and 'weight' in name: - p.set_spec(spec_embedding_col) + if 'classifier' in name and ('weight' in name or 'bias' in name): + p.set_spec(spec_classifier_col) set_seed(1) if rank == 0: