from typing import Dict, List, Optional, Set

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .chunk import Chunk
from .memory_pool import MemoryPool, TensorBlock
from .states import TensorState


class ChunkGroup(object):
"""ChunkGroup manages a group of chunks and their memory pool.
|
|
Commonly, one model has one chunk group.
|
|
It supports chunk allocation, chunk access, and chunk release.
|
|
ChunkGroup is responsible for the memory management before its APIs.
|
|
|
|
args:
|
|
rcache: A memory pool to instantiate chunks.
|
|
"""
|
|
|
|
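    # A minimal usage sketch (illustrative only: the MemoryPool constructor arguments,
    # the tensor list `params`, and the chunk size below are assumptions, not part of
    # this module):
    #
    #     pool = MemoryPool(...)                        # hypothetical rcache setup
    #     group = ChunkGroup(rcache=pool)
    #     chunk = group.allocate_chunk(params,
    #                                  chunk_size=1024 * 1024,
    #                                  chunk_dtype=torch.float16,
    #                                  process_group=dist.group.WORLD)
    #     if group.rcache_enough_check(chunk):
    #         group.access_chunk(chunk)                 # gather the chunk into rCache
    #         ...
    #         group.release_chunk(chunk)                # scatter it back and free its block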
    def __init__(self, rcache: MemoryPool) -> None:
        super().__init__()
        self.rcache = rcache
        self.fused_chunks: Set[Chunk] = set()
        self.float_chunks: Set[Chunk] = set()
        self.ten_to_chunk: Dict[torch.Tensor, Chunk] = dict()

        self.accessed_fused_chunks: Set[Chunk] = set()
        self.accessed_float_chunks: Set[Chunk] = set()

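    # Note: rcache-fused chunks are not backed by the shared public blocks, while
    # floating chunks borrow a public block from the rcache each time they are
    # accessed (see access_chunk below).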
    def __add_to_accset(self, chunk: Chunk):
        if chunk.rcache_fused:
            self.accessed_fused_chunks.add(chunk)
        else:
            self.accessed_float_chunks.add(chunk)

    def __remove_from_accset(self, chunk: Chunk):
        if chunk.rcache_fused:
            self.accessed_fused_chunks.remove(chunk)
        else:
            self.accessed_float_chunks.remove(chunk)

    def __check_new_float_chunk(self, size: int, dtype: torch.dtype):
        # if the public space is 0, there are no access operations
        if self.rcache.public_space == 0:
            return
        # otherwise, check its size and dtype
        assert size == self.rcache.public_block_size
        assert dtype == self.rcache.public_dtype

    def inside_check(self, chunk: Chunk) -> None:
        """Check whether the chunk is in this ChunkGroup."""
        if chunk.rcache_fused:
            assert chunk in self.fused_chunks
        else:
            assert chunk in self.float_chunks

    def is_accessed(self, chunk: Chunk) -> bool:
        """Check whether the chunk is accessed."""
        # sanity check
        self.inside_check(chunk)

        if chunk.rcache_fused:
            return (chunk in self.accessed_fused_chunks)
        else:
            return (chunk in self.accessed_float_chunks)

    def open_chunk(self,
                   chunk_size: int,
                   chunk_dtype: torch.dtype,
                   process_group: ProcessGroup,
                   chunk_config: Optional[Dict] = None) -> Chunk:
        """Open a chunk to store parameters."""
        if chunk_config is None:
            chunk_config = {}

        # extra keyword arguments in chunk_config are forwarded to Chunk
        chunk = Chunk(rcache=self.rcache,
                      chunk_size=chunk_size,
                      chunk_dtype=chunk_dtype,
                      process_group=process_group,
                      **chunk_config)
        # sanity check
        if not chunk.rcache_fused:
            self.__check_new_float_chunk(chunk_size, chunk_dtype)

        return chunk

    def close_chunk(self, chunk: Chunk) -> bool:
        """Close the chunk during the allocation."""
        chunk.close_chunk()
        # add the new chunk to the set of allocated chunks
        if chunk.rcache_fused:
            self.fused_chunks.add(chunk)
        else:
            self.float_chunks.add(chunk)
        # add the new chunk to the mapping
        for t in chunk.get_tensors():
            assert t not in self.ten_to_chunk
            self.ten_to_chunk[t] = chunk
        return True

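    # The allocation protocol is open_chunk() -> append tensors -> close_chunk();
    # allocate_chunk() below bundles these steps. A sketch, where `params` is a
    # hypothetical list of parameter tensors:
    #
    #     chunk = group.open_chunk(chunk_size, chunk_dtype, process_group)
    #     for p in params:
    #         chunk.append_tensor(p)
    #     group.close_chunk(chunk)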
    def allocate_chunk(self,
                       tensor_list: List[torch.Tensor],
                       chunk_size: int,
                       chunk_dtype: torch.dtype,
                       process_group: ProcessGroup,
                       chunk_config: Optional[Dict] = None) -> Chunk:
        """Allocate a chunk for a list of parameters."""
        chunk = self.open_chunk(chunk_size, chunk_dtype, process_group, chunk_config)
        # append tensors
        for t in tensor_list:
            chunk.append_tensor(t)
        self.close_chunk(chunk)

        return chunk

    def tensors_to_chunks(self, tensor_list: List[torch.Tensor]) -> List[Chunk]:
        """Get the chunks of a given list of tensors."""
        chunk_list = list()
        for tensor in tensor_list:
            chunk = self.ten_to_chunk.get(tensor)
            if chunk not in chunk_list:
                chunk_list.append(chunk)
        chunk_list.sort(key=lambda c: c.chunk_id)
        return chunk_list

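    # Fused chunks do not take a public block when accessed (access_chunk passes
    # block=None), so only floating chunks need a free public block here.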
    def rcache_enough_check(self, chunk: Chunk) -> bool:
        """Check whether the rcache has enough blocks to store the gathered chunk."""
        if chunk.rcache_fused:
            return True
        return self.rcache.public_free_count > 0

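    # access_chunk returns False if the chunk is already gathered; otherwise it
    # gathers the chunk into rCache (floating chunks take a public block) and
    # returns True.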
    def access_chunk(self, chunk: Chunk) -> bool:
        """Access a chunk into rCache."""
        self.inside_check(chunk)
        # if this chunk is accessed already, return False
        if self.is_accessed(chunk):
            return False

        if chunk.rcache_fused:
            block = None
        else:
            block = self.rcache.pop_public_block()
        chunk.access_chunk(block)
        self.__add_to_accset(chunk)
        return True

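    # release_chunk is the inverse of access_chunk: the chunk must currently be
    # accessed and safe to scatter (chunk.scatter_check); a floating chunk returns
    # its public block to the rcache.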
    def release_chunk(self, chunk: Chunk) -> bool:
        """Release a chunk from rCache."""
        self.inside_check(chunk)
        assert self.is_accessed(chunk)
        assert chunk.scatter_check

        block = chunk.release_chunk()
        if block:
            self.rcache.free_public_block(block)
        self.__remove_from_accset(chunk)
        return True

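    # reduce_chunk reduce-scatters a gradient chunk. With sync=True the freed public
    # block is returned to the rcache immediately and None is returned; with
    # sync=False the block is handed back to the caller, presumably to be freed once
    # the asynchronous reduction has completed.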
    def reduce_chunk(self, chunk: Chunk, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
        """Reduce and scatter a gradient chunk from rCache."""
        self.inside_check(chunk)
        assert self.is_accessed(chunk)
        assert chunk.reduce_check

        block = chunk.reduce_chunk(always_fp32=always_fp32, sync=sync)
        if block and sync:
            # if synchronized, free the block into rcache
            self.rcache.free_public_block(block)
            block = None

        self.__remove_from_accset(chunk)

        return block

    def tensor_trans_state(self, tensor: torch.Tensor, state: TensorState):
        """Transform the state of a tensor."""
        chunk = self.ten_to_chunk.get(tensor)
        chunk.tensor_trans_state(tensor, state)