mirror of https://github.com/hpcaitech/ColossalAI
[chunk] add PG check for tensor appending (#1383)
parent 8dced41ad0
commit f792507ff3
@@ -3,7 +3,7 @@ from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 
 from colossalai.utils import get_current_device
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
 from .chunk import Chunk, ChunkFullError, TensorState
 
 
@@ -13,6 +13,7 @@ class ChunkManager:
 
     Args:
         chunk_size (int): the size of a chunk.
+        process_group (ColoProcessGroup): process group of the chunk.
         enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """
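Note (not part of the commit): based only on the arguments documented in the docstring above, constructing a manager with the newly documented process_group argument might look like the sketch below. The import path for ChunkManager and the no-argument ColoProcessGroup() default are assumptions, and a ColossalAI distributed context is expected to have been launched already.

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini.chunk_mgr import ChunkManager  # assumed import path, not shown in the diff

# A ColoProcessGroup over all launched ranks (assumed default constructor behaviour).
pg = ColoProcessGroup()

# Keyword names follow the docstring: chunk size in elements, the chunk's process
# group, whether chunk storage is distributed, and the device chunks start on.
chunk_manager = ChunkManager(chunk_size=1024 * 1024,
                             process_group=pg,
                             enable_distributed_storage=False,
                             init_device=torch.device('cuda'))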
@@ -57,6 +58,9 @@ class ChunkManager:
             group_name (str): the name of the chunk group.
         """
         assert tensor not in self.tensor_chunk_map
+        if isinstance(tensor, ColoTensor):
+            assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
+            ), f"Chunk Manager can only manage ColoTensor with the same DP process group"
         if self.chunk_size is not None and tensor.numel() > self.chunk_size:
             raise ValueError(
                 f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')