Browse Source

Add interface for colo tesnor dp size (#3227)

pull/3253/head
YH 2 years ago committed by GitHub
parent
commit
1a229045af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/gemini/chunk/manager.py
  2. 9
      colossalai/tensor/colo_tensor.py

2
colossalai/gemini/chunk/manager.py

@ -72,7 +72,7 @@ class ChunkManager:
if tensor.numel() > chunk_size: if tensor.numel() > chunk_size:
chunk_size = tensor.numel() chunk_size = tensor.numel()
dp_size = tensor.process_group.dp_world_size() dp_size = tensor.get_dp_world_size()
chunk_size = chunk_size + (-chunk_size % dp_size) chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk( chunk = Chunk(

9
colossalai/tensor/colo_tensor.py

@ -138,6 +138,15 @@ class ColoTensor(torch.Tensor):
def get_tp_world_size(self) -> int: def get_tp_world_size(self) -> int:
return self.process_group.tp_world_size() return self.process_group.tp_world_size()
def get_dp_world_size(self) -> int:
"""get_dp_world_size
get the dp world size of the tensor.
Returns:
int: dp world size
"""
return self.process_group.dp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec): def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec """set_dist_spec
set dist spec and change the payloads. set dist spec and change the payloads.

Loading…
Cancel
Save