From 1a229045af97767a21223ee1b3694c9aedac154e Mon Sep 17 00:00:00 2001 From: YH <100389977+yhna940@users.noreply.github.com> Date: Mon, 27 Mar 2023 10:42:21 +0900 Subject: [PATCH] Add interface for colo tesnor dp size (#3227) --- colossalai/gemini/chunk/manager.py | 2 +- colossalai/tensor/colo_tensor.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 30ac4d354..2fa65c970 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -72,7 +72,7 @@ class ChunkManager: if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = tensor.process_group.dp_world_size() + dp_size = tensor.get_dp_world_size() chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index bbed8847a..40eefc3ec 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor): def get_tp_world_size(self) -> int: return self.process_group.tp_world_size() + def get_dp_world_size(self) -> int: + """get_dp_world_size + get the dp world size of the tensor. + + Returns: + int: dp world size + """ + return self.process_group.dp_world_size() + def set_dist_spec(self, dist_spec: _DistSpec): """set_dist_spec set dist spec and change the payloads.