From f660152c73511ea5d7d63be2b382a34cdc3ae56f Mon Sep 17 00:00:00 2001 From: superhao1995 <804673818@qq.com> Date: Tue, 12 Jul 2022 17:14:15 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258) Co-authored-by: Research --- colossalai/nn/layer/parallel_3d/_operation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index b4d6f734e..eb045f2b4 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor, if input_.size(dim) <= 1: return input_ - output = torch.chunk(input_, weight_world_size, - dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() - output = torch.chunk(output, input_world_size, - dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() return output