mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/parallel_3d/_operation.py code style (#1258)
Co-authored-by: Research <research@soccf-snr3-017.comp.nus.edu.sg>pull/1298/head
parent
9738fb0f78
commit
f660152c73
|
@ -326,10 +326,8 @@ def split_batch_3d(input_: Tensor,
|
|||
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
output = torch.chunk(input_, weight_world_size,
|
||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, input_world_size,
|
||||
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
||||
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
|
||||
return output
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue