mirror of https://github.com/hpcaitech/ColossalAI
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import colossalai
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.shardformer.layer import cross_entropy_1d
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
CONFIG = dict(
|
|
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
|
|
)
|
|
|
|
|
|
def check_dist_crossentropy(rank, world_size, port, ignore_index):
|
|
disable_existing_loggers()
|
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
|
|
|
# prepare data
|
|
pred = torch.randn(2, 4, 8, requires_grad=True)
|
|
labels = torch.randint(8, (2, 4))
|
|
# set some label to -100 to test the ignore index
|
|
labels[0, -1] = ignore_index
|
|
|
|
org_pred = pred.view(-1, 8)
|
|
org_labels = labels.view(-1)
|
|
org_loss = F.cross_entropy(org_pred, org_labels)
|
|
|
|
dist_pred = pred.chunk(world_size, -1)[rank]
|
|
dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index)
|
|
|
|
assert torch.allclose(
|
|
org_loss, dist_loss, atol=1e-5
|
|
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
|
|
|
|
|
|
@pytest.mark.dist
|
|
@rerun_if_address_is_in_use()
|
|
def test_dist_crossentropy():
|
|
ignore_index = -100
|
|
spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_dist_crossentropy()
|