[NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258)

Co-authored-by: Research <research@soccf-snr3-017.comp.nus.edu.sg>
pull/1298/head
superhao1995 2 years ago committed by Frank Lee
parent 9738fb0f78
commit f660152c73

@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor,
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size,
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size,
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output

Loading…
Cancel
Save