diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 30ac4d354..2fa65c970 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -72,7 +72,7 @@ class ChunkManager: if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = tensor.process_group.dp_world_size() + dp_size = tensor.get_dp_world_size() chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index bbed8847a..40eefc3ec 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor): def get_tp_world_size(self) -> int: return self.process_group.tp_world_size() + def get_dp_world_size(self) -> int: + """get_dp_world_size + get the dp world size of the tensor. + + Returns: + int: dp world size + """ + return self.process_group.dp_world_size() + def set_dist_spec(self, dist_spec: _DistSpec): """set_dist_spec set dist spec and change the payloads.