From 42ab36b7620f8a8b286f237327344239fb3b25a3 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Thu, 7 Jul 2022 19:17:23 +0800
Subject: [PATCH] [tensor] add unitest for colo_tensor 1DTP cross_entropy
 (#1230)

---
 colossalai/nn/_ops/loss.py          | 12 +++++--
 colossalai/nn/loss/loss_1d.py       | 28 +++++++---------
 tests/test_tensor/test_loss_func.py | 52 +++++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 19 deletions(-)
 create mode 100644 tests/test_tensor/test_loss_func.py

diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py
index c17406c18..c91c0412d 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/nn/_ops/loss.py
@@ -23,6 +23,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
 
     if input_tensor.is_replicate():    # Input is gathered
+        assert target.is_replicate() and (weight is None or weight.is_replicate()), \
+            "Both the target tensor and the weight tensor should be replicated (complete)"
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -31,11 +33,15 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduce=reduce,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
-            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+            assert weight is None, "The TP cross entropy loss does not support a weight tensor yet"
+            assert target.is_replicate(), "The target tensor should be replicated (complete) for the TP cross entropy loss"
+            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
+                                                       target,
+                                                       process_group=input_tensor.process_group.tp_process_group())
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
         else:
             raise NotImplementedError
     else:
diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py
index 677d954cc..2fabd954f 100644
--- a/colossalai/nn/loss/loss_1d.py
+++ b/colossalai/nn/loss/loss_1d.py
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import LOSSES
@@ -10,19 +11,19 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, vocab_parallel_logits, targets):
+    def forward(ctx, vocab_parallel_logits, targets, process_group):
+        if process_group is None:
+            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
 
         # Get the partition's vocab indices
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size
 
@@ -42,17 +43,12 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = vocab_parallel_logits
-        torch.exp(vocab_parallel_logits, out=exp_logits)
+        exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
@@ -81,7 +77,7 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))
 
-        return grad_input, None
+        return grad_input, None, None
 
 
 @LOSSES.register_module
@@ -96,14 +92,14 @@ class VocabParallelCrossEntropyLoss1D(_Loss):
         super().__init__()
         self.reduction_mean = reduction
 
-    def forward(self, logits, targets):
+    def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.
 
         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
""" - loss = _VocabParallelCrossEntropy1D.apply(logits, targets) + loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) if self.reduction_mean: loss = loss.mean() return loss diff --git a/tests/test_tensor/test_loss_func.py b/tests/test_tensor/test_loss_func.py new file mode 100644 index 000000000..703dcb68a --- /dev/null +++ b/tests/test_tensor/test_loss_func.py @@ -0,0 +1,52 @@ +import torch +import pytest +import colossalai +import torch.nn.functional as F +import torch.multiprocessing as mp +from functools import partial +from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec +from colossalai.utils import get_current_device +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.tensor import distspec, ComputeSpec, ComputePattern + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_loss_func(2)