[tensor] add unitest for colo_tensor 1DTP cross_entropy (#1230)

HELSON 2022-07-07 19:17:23 +08:00 committed by GitHub
parent 04537bf83e
commit 42ab36b762
3 changed files with 73 additions and 19 deletions

View File

@@ -23,6 +23,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
     if input_tensor.is_replicate():    # Input is gathered
+        assert target.is_replicate() and weight.is_replicate(), \
+            "Target tensor and weight tensor both should be complete"
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -31,11 +33,15 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduce=reduce,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
-            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
+            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
+            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
+                                                       target,
+                                                       process_group=input_tensor.process_group.tp_process_group())
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
         else:
             raise NotImplementedError
     else:

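Summary of this hunk: for replicated inputs the cross-entropy output is already complete on every rank, so the trailing .to_replicate() calls are dropped; for 1D-column-sharded inputs the op now forwards the tensor's own tensor-parallel group via the new process_group argument instead of assuming the global 1D parallel group. Below is a minimal sketch of calling the updated loss directly; the import path and the tp_group setup are assumptions, not shown in this diff:

import torch
import torch.distributed as dist
from colossalai.nn.loss import VocabParallelCrossEntropyLoss1D    # import path assumed

def sharded_ce(full_logits, targets, tp_group):
    # Each rank keeps a contiguous [batch, vocab // world_size] slice of the vocab dim.
    rank = dist.get_rank(tp_group)
    world_size = dist.get_world_size(tp_group)
    shard = full_logits.chunk(world_size, dim=-1)[rank].contiguous()
    # New in this commit: the loss accepts an explicit process group
    # (it falls back to the global 1D parallel group when None).
    return VocabParallelCrossEntropyLoss1D()(shard, targets, process_group=tp_group)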
View File

@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import LOSSES
@@ -10,19 +11,19 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, vocab_parallel_logits, targets):
+    def forward(ctx, vocab_parallel_logits, targets, process_group):
+        if process_group is None:
+            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)

         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indices
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size
@@ -42,17 +43,12 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = vocab_parallel_logits
-        torch.exp(vocab_parallel_logits, out=exp_logits)
+        exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
@@ -81,7 +77,7 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))

-        return grad_input, None
+        return grad_input, None, None


 @LOSSES.register_module
@@ -96,14 +92,14 @@ class VocabParallelCrossEntropyLoss1D(_Loss):
         super().__init__()
         self.reduction_mean = reduction

-    def forward(self, logits, targets):
+    def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.

         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        loss = _VocabParallelCrossEntropy1D.apply(logits, targets)
+        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss

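The forward pass above computes the numerically stable form loss_i = log(sum_j exp(z_ij)) - z_i,t_i: subtract the per-row max, exponentiate, sum, and take the log, with the three all-reduces (MAX for the row max, SUM for the target logits and for the exp sum) stitching the vocab shards back together. A single-process reference of the same math in plain PyTorch, for comparison only (not Colossal-AI code):

import torch
import torch.nn.functional as F

def reference_vocab_ce(logits, targets):
    # Subtract the row max first, mirroring the MAX all-reduce in the kernel.
    logits_max = logits.max(dim=-1, keepdim=True).values
    shifted = logits - logits_max
    # Sum of exponentials over the (here, unsharded) vocab dimension.
    sum_exp = shifted.exp().sum(dim=-1)
    # The shifted logit at each target index, mirroring the masked gather + SUM all-reduce.
    predicted = shifted.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return torch.log(sum_exp) - predicted

logits = torch.randn(4, 10)
targets = torch.randint(10, (4,))
assert torch.allclose(reference_vocab_ce(logits, targets),
                      F.cross_entropy(logits, targets, reduction='none'))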
View File

@@ -0,0 +1,52 @@
+import torch
+import pytest
+import colossalai
+import torch.nn.functional as F
+import torch.multiprocessing as mp
+from functools import partial
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.utils import get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+
+
+def check_cross_entropy():
+    input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
+    input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
+    with torch.no_grad():
+        input_ct.copy_(input_t)
+
+    target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
+
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
+    input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
+
+    output = F.cross_entropy(input_t, target)
+    output_colo = F.cross_entropy(input_shard, target)
+    assert torch.allclose(output_colo, output)
+
+    output.backward()
+    output_colo.backward()
+    assert torch.allclose(input_t.grad, input_ct.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_cross_entropy()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_loss_func(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_loss_func(2)
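One detail worth noting in the new test: input_ct is copied from input_t under torch.no_grad(), so the two tensors carry identical values but remain independent autograd leaves (an in-place copy_ on a leaf that requires grad would raise outside no_grad, and no autograd edge is recorded between them). A standalone illustration in plain PyTorch:

import torch

a = torch.randn(4, 4, requires_grad=True)
b = torch.randn(4, 4, requires_grad=True)
with torch.no_grad():
    b.copy_(a)    # same values, but no autograd edge between a and b
assert b.is_leaf and torch.equal(a, b)
a.sum().backward()
assert a.grad is not None and b.grad is None    # gradients stay independent

This independence is what makes the final allclose(input_t.grad, input_ct.grad) a meaningful comparison between the dense path and the tensor-parallel path.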