mirror of https://github.com/hpcaitech/ColossalAI
[chunk] add PG check for tensor appending (#1383)
parent 8dced41ad0
commit f792507ff3
@@ -3,7 +3,7 @@ from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 
 from colossalai.utils import get_current_device
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
 from .chunk import Chunk, ChunkFullError, TensorState
 
 
@@ -13,6 +13,7 @@ class ChunkManager:
 
     Args:
         chunk_size (int): the size of a chunk.
+        process_group (ColoProcessGroup): process group of the chunk.
         enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """
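
For context, the arguments documented above suggest a construction along the following lines. This is only a hedged sketch: the import path and the use of all four arguments as keywords are assumptions drawn from the docstring and the file layout, not a verified ColossalAI API, and it presumes the distributed environment was already initialized (e.g. via colossalai.launch).

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
# Import path is an assumption based on this module living alongside chunk.py.
from colossalai.gemini.chunk_mgr import ChunkManager

pg = ColoProcessGroup()                      # default data-parallel process group
chunk_manager = ChunkManager(
    chunk_size=64 * 1024 * 1024,             # number of elements per chunk
    process_group=pg,                        # DP process group the chunks belong to
    enable_distributed_storage=True,         # shard chunk storage across DP ranks
    init_device=torch.device('cpu'),         # device on which chunks are first allocated
)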
@@ -57,6 +58,9 @@ class ChunkManager:
             group_name (str): the name of the chunk group.
         """
         assert tensor not in self.tensor_chunk_map
+        if isinstance(tensor, ColoTensor):
+            assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
+            ), f"Chunk Manager can only manage ColoTensor with the same DP process group"
         if self.chunk_size is not None and tensor.numel() > self.chunk_size:
             raise ValueError(
                 f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
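
The effect of the new guard can be illustrated with a small self-contained sketch. None of the classes below are ColossalAI's: FakeProcessGroup, FakeColoTensor and ChunkManagerSketch are hypothetical stand-ins that only mimic the methods the check relies on (dp_process_group, get_process_group, numel), so the example runs without torch.distributed. In ColossalAI, dp_process_group() returns the underlying data-parallel communication group; here a rank tuple stands in so the equality comparison is still meaningful.

# Hypothetical stand-ins; not ColossalAI classes. They only model the methods
# the new check touches so the guard's behaviour can be exercised in isolation.

class FakeProcessGroup:
    def __init__(self, dp_ranks):
        self._dp_ranks = tuple(dp_ranks)

    def dp_process_group(self):
        # Stand-in for ColoProcessGroup.dp_process_group(): returns something
        # comparable so equality between manager and tensor groups is meaningful.
        return self._dp_ranks


class FakeColoTensor:
    def __init__(self, numel, process_group):
        self._numel = numel
        self._pg = process_group

    def numel(self):
        return self._numel

    def get_process_group(self):
        return self._pg


class ChunkManagerSketch:
    def __init__(self, chunk_size, process_group):
        self.chunk_size = chunk_size
        self.process_group = process_group
        self.tensor_chunk_map = {}

    def append_tensor(self, tensor, group_name):
        assert tensor not in self.tensor_chunk_map
        # The check added by this commit: only tensors sharing the manager's
        # data-parallel process group may be appended to a chunk.
        if isinstance(tensor, FakeColoTensor):
            assert tensor.get_process_group().dp_process_group() == \
                self.process_group.dp_process_group(), \
                "Chunk Manager can only manage ColoTensor with the same DP process group"
        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
            raise ValueError(
                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
        # ... chunk allocation would follow here in the real ChunkManager ...


pg_small = FakeProcessGroup(dp_ranks=[0, 1])
pg_large = FakeProcessGroup(dp_ranks=[0, 1, 2, 3])
manager = ChunkManagerSketch(chunk_size=1024, process_group=pg_small)

manager.append_tensor(FakeColoTensor(256, pg_small), 'fp16_param')   # accepted
try:
    manager.append_tensor(FakeColoTensor(256, pg_large), 'fp16_param')
except AssertionError as err:
    print(err)   # rejected: DP process groups differ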