diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 368fa8ae6..1c6417d45 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             return None
         grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
         dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
-        return grad_flat[: working_param.numel()].reshape_as(working_param)
+        return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
 
     def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
         working_grads = []
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 16513b2f5..c301777f2 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -179,9 +179,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_deepseek(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=8)
+    test_deepseek(world_size=4)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 2b8623e13..419679797 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -180,9 +180,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_mistral(world_size):
+def test_mixtral(world_size):
     spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_mistral(world_size=8)
+    test_mixtral(world_size=4)
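
Note (not part of the patch): a minimal standalone sketch, with made-up shapes, of why the gathered buffer must be flattened before slicing. grad_flat has shape (world_size, *grad.shape), so the old [: working_param.numel()] slice runs along the rank dimension instead of dropping the padding elements. The sizes and variable names below are illustrative only.

import torch

# Hypothetical sizes: 2 ranks each hold a padded 6-element gradient shard
# for a 2x5 parameter (10 real elements, 2 elements of padding).
world_size, shard_numel = 2, 6
working_param = torch.zeros(2, 5)                 # working_param.numel() == 10
grad_flat = torch.randn(world_size, shard_numel)  # stands in for the all_gather_into_tensor buffer

# Old code: the slice acts on dim 0, so both rows (12 elements) survive and
# reshape_as on a 10-element parameter raises a RuntimeError.
try:
    grad_flat[: working_param.numel()].reshape_as(working_param)
except RuntimeError as err:
    print("old slicing fails:", err)

# Fixed code: flatten to 1-D first, drop the padding, then restore the shape.
grad = grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
print(grad.shape)  # torch.Size([2, 5])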