mirror of https://github.com/hpcaitech/ColossalAI

[tensor] add unitest for colo_tensor 1DTP cross_entropy (#1230)

parent 04537bf83e
commit 42ab36b762
@@ -23,6 +23,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
 
     if input_tensor.is_replicate():    # Input is gathered
+        assert target.is_replicate() and weight.is_replicate(), \
+            "Target tensor and weight tensor both should be complete"
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -31,11 +33,15 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduce=reduce,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
-            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
+            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
+            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
+                                                       target,
+                                                       process_group=input_tensor.process_group.tp_process_group())
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
         else:
             raise NotImplementedError
     else:
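The 1D-column-sharded branch above is exercised end-to-end by the new unit test at the bottom of this commit. As a sketch of the call path (mirroring that test rather than documenting a fixed API; it assumes colossalai.launch has already initialized a distributed job with CUDA available):

# Sketch: how F.cross_entropy reaches the is_shard_1dcol() branch above.
import torch
import torch.nn.functional as F
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                               ComputeSpec, ProcessGroup, distspec)

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
x = ColoTensor.from_torch_tensor(torch.randn(4, 4).cuda(), ColoTensorSpec(pg))
x = x.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))    # 1D column shard
x.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
target = torch.randint(4, (4,), device='cuda')    # replicated, as the new assert requires
loss = F.cross_entropy(x, target)                 # weight must be None in the TP branch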
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import LOSSES
@@ -10,19 +11,19 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, vocab_parallel_logits, targets):
+    def forward(ctx, vocab_parallel_logits, targets, process_group):
+        if process_group is None:
+            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
+
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
 
         # Get the partition's vocab indecies
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size
 
@@ -42,17 +43,12 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = vocab_parallel_logits
-        torch.exp(vocab_parallel_logits, out=exp_logits)
+        exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
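For intuition, here is a minimal single-process sketch of the arithmetic this forward pass performs, with the vocab dimension split by hand and each all_reduce replaced by an explicit max or sum over the shards (names like sharded_cross_entropy_sketch and n_shards are illustrative, not ColossalAI API):

import torch
import torch.nn.functional as F

def sharded_cross_entropy_sketch(logits, targets, n_shards):
    shards = logits.chunk(n_shards, dim=-1)    # each "rank" owns one vocab slice
    part = shards[0].size(-1)                  # per-shard vocab size

    # all_reduce(MAX): global max over the full vocab, for numerical stability
    logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]

    predicted = torch.zeros_like(logits_max)   # will hold logit[target] - max
    sum_exp = torch.zeros_like(logits_max)     # will hold sum(exp(logit - max))
    for rank, s in enumerate(shards):
        s = s - logits_max.unsqueeze(-1)       # subtract the maximum value
        lo = rank * part
        mask = (targets >= lo) & (targets < lo + part)    # targets this shard owns
        local = (targets - lo).clamp(0, part - 1)         # masked-out rows are dummies
        pred = s.gather(-1, local.unsqueeze(-1)).squeeze(-1)
        predicted += torch.where(mask, pred, torch.zeros_like(pred))    # all_reduce(SUM)
        sum_exp += s.exp().sum(dim=-1)                                  # all_reduce(SUM)

    # Loss = log(sum(exp(logits))) - predicted logit, then mean reduction
    return (torch.log(sum_exp) - predicted).mean()

logits, targets = torch.randn(4, 8), torch.randint(8, (4,))
assert torch.allclose(sharded_cross_entropy_sketch(logits, targets, 2),
                      F.cross_entropy(logits, targets), atol=1e-6)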
@@ -81,7 +77,7 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))
 
-        return grad_input, None
+        return grad_input, None, None
 
 
 @LOSSES.register_module
@@ -96,14 +92,14 @@ class VocabParallelCrossEntropyLoss1D(_Loss):
         super().__init__()
         self.reduction_mean = reduction
 
-    def forward(self, logits, targets):
+    def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.
 
         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        loss = _VocabParallelCrossEntropy1D.apply(logits, targets)
+        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss
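The public module now forwards the optional process group straight into the autograd function; with process_group=None (the default) the kernel still falls back to gpc.get_group(ParallelMode.PARALLEL_1D). A hedged usage sketch, where the import path and the two-rank group are assumptions since the diff does not show file names:

import torch
import torch.distributed as dist
from colossalai.nn.loss import VocabParallelCrossEntropyLoss1D    # assumed import path

# assumes a distributed job is already running (e.g. via colossalai.launch)
tp_group = dist.new_group(ranks=[0, 1])               # illustrative 2-rank TP group
sharded_logits = torch.randn(4, 4, device='cuda')     # this rank's vocab slice
targets = torch.randint(8, (4,), device='cuda')       # indices over the full vocab
criterion = VocabParallelCrossEntropyLoss1D()
loss = criterion(sharded_logits, targets, process_group=tp_group)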
@@ -0,0 +1,52 @@
+import torch
+import pytest
+import colossalai
+import torch.nn.functional as F
+import torch.multiprocessing as mp
+from functools import partial
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.utils import get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+
+
+def check_cross_entropy():
+    input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
+    input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
+    with torch.no_grad():
+        input_ct.copy_(input_t)
+
+    target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
+
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
+    input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
+
+    output = F.cross_entropy(input_t, target)
+    output_colo = F.cross_entropy(input_shard, target)
+    assert torch.allclose(output_colo, output)
+
+    output.backward()
+    output_colo.backward()
+
+    assert torch.allclose(input_t.grad, input_ct.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_cross_entropy()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_loss_func(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_loss_func(2)