|
|
@ -770,11 +770,13 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: |
|
|
|
""" |
|
|
|
""" |
|
|
|
dim_size = input_.size(dim) |
|
|
|
dim_size = input_.size(dim) |
|
|
|
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) |
|
|
|
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if world_size <= 1: |
|
|
|
|
|
|
|
return input_ |
|
|
|
|
|
|
|
|
|
|
|
assert dim_size % world_size == 0, \ |
|
|
|
assert dim_size % world_size == 0, \ |
|
|
|
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' |
|
|
|
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' |
|
|
|
|
|
|
|
|
|
|
|
if input_.size(dim) <= 1: |
|
|
|
|
|
|
|
return input_ |
|
|
|
|
|
|
|
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), |
|
|
|
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), |
|
|
|
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() |
|
|
|
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() |
|
|
|
|
|
|
|
|
|
|
|