mirror of https://github.com/hpcaitech/ColossalAI
[TP] change the check assert in split batch 2d (#772)
parent 846406a07a
commit 4b01da24cd
|
@ -739,11 +739,13 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||||
"""
|
"""
|
||||||
dim_size = input_.size(dim)
|
dim_size = input_.size(dim)
|
||||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
|
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
|
||||||
|
|
||||||
|
if world_size <= 1:
|
||||||
|
return input_
|
||||||
|
|
||||||
assert dim_size % world_size == 0, \
|
assert dim_size % world_size == 0, \
|
||||||
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
|
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
|
||||||
|
|
||||||
if input_.size(dim) <= 1:
|
|
||||||
return input_
|
|
||||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
|
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
|
||||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
|
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
|
||||||
|
|
||||||
|
|
|
@ -770,11 +770,13 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||||
"""
|
"""
|
||||||
dim_size = input_.size(dim)
|
dim_size = input_.size(dim)
|
||||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
|
||||||
|
if world_size <= 1:
|
||||||
|
return input_
|
||||||
|
|
||||||
assert dim_size % world_size == 0, \
|
assert dim_size % world_size == 0, \
|
||||||
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
|
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
|
||||||
|
|
||||||
if input_.size(dim) <= 1:
|
|
||||||
return input_
|
|
||||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue