|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.logging import disable_existing_loggers
|
|
|
|
from colossalai.shardformer.layer import cross_entropy_1d
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
|
|
|
|
CONFIG = dict(
|
|
|
|
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def check_dist_crossentropy(rank, world_size, port, ignore_index):
|
|
|
|
disable_existing_loggers()
|
|
|
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
|
|
|
|
|
|
|
# prepare data
|
|
|
|
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
|
|
|
|
labels = torch.randint(8, (2, 4)).cuda()
|
|
|
|
# set some label to -100 to test the ignore index
|
|
|
|
labels[0, -1] = ignore_index
|
|
|
|
|
|
|
|
org_pred = pred.view(-1, 8)
|
|
|
|
org_labels = labels.view(-1)
|
|
|
|
org_loss = F.cross_entropy(org_pred, org_labels)
|
|
|
|
pred.retain_grad()
|
|
|
|
org_loss.backward()
|
|
|
|
|
|
|
|
dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
|
|
|
|
dist_pred.requires_grad = True
|
|
|
|
dist_loss = cross_entropy_1d(dist_pred, labels, ignore_index=ignore_index)
|
|
|
|
dist_pred.retain_grad()
|
|
|
|
dist_loss.backward()
|
|
|
|
|
|
|
|
assert torch.allclose(
|
|
|
|
org_loss, dist_loss, atol=1e-5
|
|
|
|
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
|
|
|
|
|
|
|
|
|
|
|
|
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
|
|
|
|
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
@rerun_if_address_is_in_use()
|
|
|
|
def test_dist_crossentropy():
|
|
|
|
ignore_index = -100
|
|
|
|
spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
test_dist_crossentropy()
|