[chunk] add PG check for tensor appending (#1383)

pull/1386/head
Jiarui Fang 2022-07-29 13:27:05 +08:00 committed by GitHub
parent 8dced41ad0
commit f792507ff3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -3,7 +3,7 @@ from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
from .chunk import Chunk, ChunkFullError, TensorState
@ -13,6 +13,7 @@ class ChunkManager:
Args:
chunk_size (int): the size of a chunk.
process_group (ColoProcessGroup): process group of the chunk.
enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
@ -57,6 +58,9 @@ class ChunkManager:
group_name (str): the name of the chunk group.
"""
assert tensor not in self.tensor_chunk_map
if isinstance(tensor, ColoTensor):
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
raise ValueError(
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')