import pytest import torch import torch.nn.functional as F import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) def check_dist_crossentropy(rank, world_size, port, ignore_index): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') # prepare data pred = torch.randn(2, 4, 8, requires_grad=True) labels = torch.randint(8, (2, 4)) # set some label to -100 to test the ignore index labels[0, -1] = ignore_index org_pred = pred.view(-1, 8) org_labels = labels.view(-1) org_loss = F.cross_entropy(org_pred, org_labels) dist_pred = pred.chunk(world_size, -1)[rank] dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) assert torch.allclose(org_loss, dist_loss, atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_crossentropy(): ignore_index = -100 spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) if __name__ == '__main__': test_dist_crossentropy()