From 4b01da24cd6bc424190cf1c7663f8740461f4bd7 Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Sat, 16 Apr 2022 21:29:57 +0800 Subject: [PATCH] [TP] change the check assert in split batch 2d (#772) --- colossalai/nn/layer/parallel_2d/_operation.py | 6 ++++-- colossalai/nn/layer/parallel_2p5d/_operation.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index 592924bd4..306577dbd 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -739,11 +739,13 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: """ dim_size = input_.size(dim) world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + + if world_size <= 1: + return input_ + assert dim_size % world_size == 0, \ f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' - if input_.size(dim) <= 1: - return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py index 0bcc8ecee..5a0f537cd 100644 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -770,11 +770,13 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: """ dim_size = input_.size(dim) world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + + if world_size <= 1: + return input_ + assert dim_size % world_size == 0, \ f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' - if input_.size(dim) <= 1: - return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()