import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from .chunk import Chunk, ChunkFullError, TensorState


class ChunkManager:
    """
    A manager class to manipulate the tensors in chunks.

    Tensors are appended to named chunk groups (each group is a deque of chunks).
    The manager records which chunk owns each tensor, which chunks have been
    accessed (see ``access_chunk``), and the total memory in bytes consumed per
    device type (``total_mem``).

    Args:
        chunk_size (int): the size of a chunk. If None, a newly created chunk is sized to
            the numel of the tensor that triggered its creation.
        enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None,
            which means the current device.
    """

    def __init__(self,
                 chunk_size: Optional[int],
                 enable_distributed_storage: bool = False,
                 init_device: Optional[torch.device] = None) -> None:
        assert chunk_size is None or chunk_size > 0
        self.chunk_size = chunk_size
        self.enable_distributed_storage = enable_distributed_storage
        self.device = init_device or get_current_device()
        # group name -> ordered chunks belonging to that group
        self.chunk_groups: Dict[str, Deque[Chunk]] = {}
        # group name -> force_data_on_cuda flag passed to every chunk of that group
        self.groups_force_data_on_cuda: Dict[str, bool] = {}
        # tensor -> chunk that owns it
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = []
        if enable_distributed_storage and chunk_size is None:
            # group name -> per-rank element count, used to pick the least-loaded owner rank
            self.rank_load: Dict[str, torch.Tensor] = {}
        # memory usage in bytes, keyed by device type
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
        """Create a chunk group.

        Args:
            group_name (str): group name
            force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.
                Defaults to False.
        """
        assert group_name not in self.chunk_groups
        self.chunk_groups[group_name] = deque()
        self.groups_force_data_on_cuda[group_name] = force_data_on_cuda

    def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
        """
        Append a tensor to a chunk.

        Args:
            tensor (torch.Tensor): a tensor to append to the chunk.
            group_name (str): the name of the chunk group.

        Raises:
            ValueError: if ``chunk_size`` is set and the tensor has more elements than a chunk can hold.
        """
        assert tensor not in self.tensor_chunk_map
        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
            raise ValueError(
                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
        try:
            # append the tensor to the last chunk
            self.chunk_groups[group_name][-1].append(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            chunk_size = self.chunk_size or tensor.numel()
            src_rank = self._get_next_src_rank(group_name)
            chunk = Chunk(chunk_size,
                          src_rank,
                          tensor.dtype,
                          self.device,
                          force_data_on_cuda=self.groups_force_data_on_cuda[group_name])

            if self.enable_distributed_storage and self.chunk_size is None:
                self.rank_load[group_name][src_rank] += chunk_size

            self.chunk_groups[group_name].append(chunk)
            chunk.append(tensor)
            # only count memory for chunks that actually hold storage on this rank
            if not chunk.is_empty:
                self.total_mem[chunk.device_type] += chunk.mem
        self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
        if not self.enable_distributed_storage:
            # as distributed storage is not enabled, there is no need to broadcast
            # chunks, thus we set these chunks as accessed
            self.accessed_chunks.add(self.chunk_groups[group_name][-1])

    def _get_next_src_rank(self, group_name: str) -> int:
        """Pick the data-parallel rank that will own the next chunk of ``group_name``."""
        if not self.enable_distributed_storage:
            # the chunk is owned by the current rank if no distributed storage is enabled
            return gpc.get_local_rank(ParallelMode.DATA)
        if self.chunk_size is None:
            if group_name not in self.rank_load:
                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)

            # the process owning the tensor will be the process with the smallest number of elements
            src_rank = torch.argmin(self.rank_load[group_name]).item()
        else:
            # chunk is owned by processes in a round-robin fashion
            chunk_idx = len(self.chunk_groups[group_name])
            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
        return src_rank

    def access_chunk(self, chunk: Chunk) -> None:
        """
        Synchronize the chunks via broadcast.

        An already-accessed chunk that is not on CUDA is moved to the current device instead.

        Args:
            chunk (Chunk): the chunk to synchronize.
        """
        if chunk in self.accessed_chunks:
            if chunk.device_type != 'cuda':
                self.total_mem[chunk.device_type] -= chunk.mem
                chunk.move_device(get_current_device())
                self.total_mem[chunk.device_type] += chunk.mem
            return
        if not chunk.is_empty:
            # as tensor is moved to the target device
            # the memory consumption of the original device is reduced
            self.total_mem[chunk.device_type] -= chunk.mem
        chunk.access()
        self.accessed_chunks.add(chunk)
        self.total_mem[chunk.device_type] += chunk.mem

    def release_chunk(self, chunk: Chunk) -> None:
        """
        Release the memory space of a chunk.

        This is a no-op when distributed storage is disabled or the chunk has not been accessed.

        Args:
            chunk (Chunk): the chunk to release memory space
        """
        if not self.enable_distributed_storage:
            return
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            chunk.release()
            self.accessed_chunks.remove(chunk)
            if chunk.is_empty:
                # update the memory consumption after releasing
                self.total_mem[chunk.device_type] -= chunk.mem

    def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to the target device.

        Args:
            chunk (Chunk): the chunk to move to target device
            device (torch.device): target device
            update_ptr (bool): optional, forwarded to ``Chunk.move_device``. The default is True.
        """
        if chunk.device_type == device.type:
            return
        if chunk.can_move_device and not chunk.is_empty:
            self.total_mem[chunk.device_type] -= chunk.mem
            chunk.move_device(device, update_ptr=update_ptr)
            self.total_mem[chunk.device_type] += chunk.mem

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """
        Transit tensor state according to pre-defined state machine.

        Args:
            tensor (torch.Tensor): the tensor for state transition
            state (TensorState): next tensor state for transition
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """
        Reduce or all reduce the chunk.

        If distributed storage is disabled, all-reduce is used. Otherwise, this method uses reduce.

        Args:
            chunk (Chunk): the chunk for reduction.

        Returns:
            bool: True if the chunk was reduced, False if ``chunk.can_reduce`` is False.
        """
        if not chunk.can_reduce:
            return False
        self.total_mem[chunk.device_type] -= chunk.mem
        chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
        self.total_mem[chunk.device_type] += chunk.mem
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object
        """
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        """
        Add tensors to the buffer for lazy release.

        Args:
            tensors (List[torch.Tensor]): the tensors to be released lazily
        """
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        """
        Execute release for tensors added to the lazy release buffer, then empty the buffer.
        """
        for chunk in self.get_chunks(self.lazy_release_tensors):
            self.release_chunk(chunk)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
        for group_name, group in self.chunk_groups.items():
            msg += f'Group {group_name}:\n'
            for i, chunk in enumerate(group):
                msg += f'[{i}] {chunk}\n'
        return msg

    @staticmethod
    def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
        """
        Calculate the utilization rate of a chunk.

        Parameters are packed greedily in order; a new chunk is opened whenever the next
        parameter does not fit into the current one.

        Args:
            chunk_size (int): the size of a chunk
            params_numel (List[int]): the list of integers representing the number of elements of parameters

        Returns:
            float: total parameter elements divided by total allocated chunk elements.
        """
        assert len(params_numel) > 0
        total_size = 0
        total_utilized_size = 0
        cur_chunk_utilized_size = 0
        for size in params_numel:
            assert chunk_size >= size
            total_utilized_size += size
            if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
                total_size += chunk_size
                cur_chunk_utilized_size = 0
            cur_chunk_utilized_size += size
        return total_utilized_size / total_size

    @staticmethod
    def search_chunk_size(module: torch.nn.Module,
                          search_range: int,
                          n_grids: int,
                          min_chunk_size: Optional[int] = None) -> int:
        """
        Search for the chunk size for optimal chunk utilization.

        Args:
            module (torch.nn.Module): a torch module object
            search_range (int): the range of chunk size to search. The actual search range will be from
                 max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
            n_grids (int): the number of intervals in the search range
            min_chunk_size (int): optional, the minimum size for a chunk. The default is None.

        Returns:
            int: the candidate chunk size with the highest utilization (the smallest one on ties).

        Raises:
            ValueError: if the module has no parameters.
        """
        assert search_range % n_grids == 0
        # TODO(ver217): sort params and filter unused ones
        params_numel = [p.numel() for p in module.parameters()]
        if not params_numel:
            raise ValueError('Cannot search chunk size for a module without parameters')
        max_param_numel = max(params_numel)
        # Every chunk must be able to hold the largest parameter, so the search starts at
        # max(min_chunk_size, max_param_numel), as documented above.
        if min_chunk_size is None:
            min_chunk_size = max_param_numel
        else:
            min_chunk_size = max(min_chunk_size, max_param_numel)
        # search_range == 0 would give a zero step, which range() rejects; clamp to 1 so the
        # single candidate min_chunk_size is still evaluated.
        step_size = max(search_range // n_grids, 1)
        max_chunk_util = -1
        best_chunk_size = -1
        for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
            chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
            if chunk_util > max_chunk_util:
                max_chunk_util = chunk_util
                best_chunk_size = chunk_size
        return best_chunk_size

    def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
        """
        Copy chunk data from one group to another group.

        Chunks are paired positionally; destination chunks that are empty on this rank are skipped.

        Args:
            dest_group_name (str): the destination group which receives the copied data
            src_group_name (str): the source group which provides the data to copy
        """
        for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
            if not dest_chunk.is_empty:
                dest_chunk.copy_(src_chunk)

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Each chunk appears once, in the order its first tensor occurs in ``tensors``.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        # dict.fromkeys de-duplicates in O(n) while preserving first-seen order,
        # replacing the quadratic list-membership scan.
        return tuple(dict.fromkeys(self.get_chunk(tensor) for tensor in tensors))

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add extern static tensor to chunk manager.
        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
        They are "static", which means their shape, dtype, device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()