From 0453776defc334e6ed88cba3eae02cdf1b001da2 Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 8 Jul 2022 11:18:00 +0800 Subject: [PATCH] [tensor] fix a assertion in colo_tensor cross_entropy (#1232) --- colossalai/nn/_ops/loss.py | 2 +- tests/test_tensor/test_loss_func.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py index c91c0412d..1e54f6628 100644 --- a/colossalai/nn/_ops/loss.py +++ b/colossalai/nn/_ops/loss.py @@ -23,7 +23,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor, input_tensor = convert_to_colo_tensor(input_tensor, pg) if input_tensor.is_replicate(): # Input is gathered - assert target.is_replicate() and weight.is_replicate(), \ + assert target.is_replicate() and (weight is None or weight.is_replicate()), \ "Target tensor and weight tensor both should be complete" output = F.cross_entropy(input_tensor, target, diff --git a/tests/test_tensor/test_loss_func.py b/tests/test_tensor/test_loss_func.py index 703dcb68a..21a0e281a 100644 --- a/tests/test_tensor/test_loss_func.py +++ b/tests/test_tensor/test_loss_func.py @@ -41,7 +41,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_loss_func(world_size): run_func = partial(run_dist, world_size=world_size, port=free_port()) @@ -49,4 +49,4 @@ def test_loss_func(world_size): if __name__ == '__main__': - test_loss_func(2) + test_loss_func(1)