diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index b4d6f734e..eb045f2b4 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor, if input_.size(dim) <= 1: return input_ - output = torch.chunk(input_, weight_world_size, - dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() - output = torch.chunk(output, input_world_size, - dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() return output