# ColossalAI/colossalai/tensor/chunk.py

import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Deque, Set, List
from collections import deque
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
                                                                            TensorState.HOLD))
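
# Editor's note (not part of the original source): STATE_TRANS is the set of
# allowed (current state, next state) pairs. A parameter typically cycles
#   HOLD -> COMPUTE                                  (used in forward/backward)
#   COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE    (gradient produced)
#   READY_FOR_REDUCE -> HOLD                         (after the gradient reduce)
# and HOLD -> FREE when the chunk is released. Checks are plain tuple lookups:
#   (TensorState.HOLD, TensorState.COMPUTE) in STATE_TRANS               # True
#   (TensorState.READY_FOR_REDUCE, TensorState.COMPUTE) in STATE_TRANS   # False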


@dataclass
class TensorInfo:
    state: TensorState
    offset: int
    end: int


class ChunkFullError(Exception):
    pass


class Chunk:
    """A contiguous buffer that packs several tensors of the same dtype.

    Ranks other than ``src_rank`` of the DATA parallel group keep an empty
    storage until ``access()`` broadcasts the payload from the source rank.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
        self.dtype = dtype
        self.device = init_device or get_current_device()
        self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
        # Non-source ranks drop the payload right away; it is re-allocated in access().
        if not self.is_src_rank:
            self.data.storage().resize_(0)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """Append a tensor to the chunk; raises ChunkFullError if it does not fit."""
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()
        if new_utilized_size > self.size:
            raise ChunkFullError
        tensor_state = TensorState.FREE
        if self.is_src_rank:
            self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
            tensor_state = TensorState.HOLD
            # The tensor now aliases its slice of the chunk buffer.
            tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """On non-source ranks, drop the payload and mark all tensors FREE."""
        if not self.is_src_rank:
            self.data.storage().resize_(0)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """Materialize the chunk on the current device, broadcasting from the source rank."""
        if not self.is_src_rank:
            self.data.storage().resize_(self.size)
        self.data.data = self.data.to(get_current_device())
        # Every rank takes part in the broadcast so that all replicas hold the payload.
        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device) -> None:
        self.data.data = self.data.to(device)
        self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """Reduce the chunk across the DATA group: all-reduce, or reduce onto the source rank."""
        self.data.data = self.data.to(get_current_device())
        if is_all_reduce:
            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
        else:
            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
        # As the gradient hook can be triggered either before or after post-backward,
        # a tensor's state can go COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE
        # or COMPUTE -> READY_FOR_REDUCE -> HOLD_AFTER_BWD.
        # The second path is invalid, so READY_FOR_REDUCE -> HOLD_AFTER_BWD is simply ignored.
        # This method only applies valid state transitions; invalid calls are
        # ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # print(
            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: '
            #     f'{self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        tensor_info = self.tensors_info[tensor]
        self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
        tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.HOLD:
                return False
        return True

    @property
    def can_move_device(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
                return False
        return True

    @property
    def can_reduce(self) -> bool:
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.READY_FOR_REDUCE:
                return False
        return True

    @property
    def is_free(self) -> bool:
        return self.data.storage().size() == 0

    def __repr__(self) -> str:
        return (f'Chunk: src rank={self.src_rank}, size={self.size}, '
                f'utilization={self.utilized_size / self.size * 100:.2f}%, freed={self.is_free}, '
                f'tensor states={[info.state.name for info in self.tensors_info.values()]}')

    @property
    def has_inf_or_nan(self) -> bool:
        return torch.isinf(self.data[:self.utilized_size]).any().item() or \
            torch.isnan(self.data[:self.utilized_size]).any().item()

    def copy_(self, src_chunk: 'Chunk'):
        """Copy the payload of ``src_chunk`` into this chunk; both must be materialized."""
        assert not self.is_free
        assert not src_chunk.is_free
        assert self.size == src_chunk.size
        assert self.utilized_size == src_chunk.utilized_size
        self.data.copy_(src_chunk.data)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        return self.data.device.type


class ChunkManager:
    """Tracks chunks of tensors, grouped by name, along with their aggregate memory usage."""

    def __init__(self,
                 chunk_size: Optional[int],
                 enable_distributed_storage: bool = False,
                 init_device: Optional[torch.device] = None) -> None:
        assert chunk_size is None or chunk_size > 0
        self.chunk_size = chunk_size
        self.enable_distributed_storage = enable_distributed_storage
        self.device = init_device or get_current_device()
        self.chunk_groups: Dict[str, Deque[Chunk]] = {}
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
        self.accessed_chunks: Set[Chunk] = set()
        self.lazy_release_tensors: List[torch.Tensor] = []
        if enable_distributed_storage and chunk_size is None:
            # Per-group element counts per rank, used to balance variable-sized chunks.
            self.rank_load: Dict[str, torch.Tensor] = {}
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
        assert tensor not in self.tensor_chunk_map
        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
            raise ValueError(
                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        try:
            # Try to pack the tensor into the group's last chunk.
            self.chunk_groups[group_name][-1].append(tensor)
        except (IndexError, ChunkFullError):
            # The group is empty or its last chunk is full: open a new chunk.
            chunk_size = self.chunk_size or tensor.numel()
            src_rank = self._get_next_src_rank(group_name)
            chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
            if self.enable_distributed_storage and self.chunk_size is None:
                self.rank_load[group_name][src_rank] += chunk_size
            self.chunk_groups[group_name].append(chunk)
            chunk.append(tensor)
            if not chunk.is_free:
                self.total_mem[chunk.device_type] += chunk.mem
        self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
        if not self.enable_distributed_storage:
            self.accessed_chunks.add(self.chunk_groups[group_name][-1])

    def _get_next_src_rank(self, group_name: str) -> int:
        if not self.enable_distributed_storage:
            # Every rank keeps a full copy, so it acts as its own source.
            return gpc.get_local_rank(ParallelMode.DATA)
        if self.chunk_size is None:
            # Variable-sized chunks: assign to the rank with the smallest load so far.
            if group_name not in self.rank_load:
                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
            src_rank = torch.argmin(self.rank_load[group_name]).item()
        else:
            # Fixed-size chunks: assign source ranks round-robin.
            chunk_idx = len(self.chunk_groups[group_name])
            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
        return src_rank

    def access_chunk(self, tensor: torch.Tensor) -> None:
        chunk = self.tensor_chunk_map[tensor]
        if chunk in self.accessed_chunks:
            return
        if not chunk.is_free:
            self.total_mem[chunk.device_type] -= chunk.mem
        chunk.access()
        self.accessed_chunks.add(chunk)
        self.total_mem[chunk.device_type] += chunk.mem

    def release_chunk(self, tensor: torch.Tensor) -> None:
        if not self.enable_distributed_storage:
            return
        chunk = self.tensor_chunk_map[tensor]
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            chunk.release()
            self.accessed_chunks.remove(chunk)
            if chunk.is_free:
                self.total_mem[chunk.device_type] -= chunk.mem

    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
        chunk = self.tensor_chunk_map[tensor]
        if chunk.data.device == device:
            return
        if chunk.can_move_device and not chunk.is_free:
            self.total_mem[chunk.device_type] -= chunk.mem
            chunk.move_device(device)
            self.total_mem[chunk.device_type] += chunk.mem

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
        """Reduce the chunk holding ``tensor``; returns False if the chunk is not ready."""
        chunk = self.tensor_chunk_map[tensor]
        if not chunk.can_reduce:
            return False
        self.total_mem[chunk.device_type] -= chunk.mem
        chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
        self.total_mem[chunk.device_type] += chunk.mem
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
        chunk = self.tensor_chunk_map[tensor]
        return chunk.is_free

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        for tensor in self.lazy_release_tensors:
            self.release_chunk(tensor)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
        msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
        for group_name, group in self.chunk_groups.items():
            msg += f'Group {group_name}:\n'
            for i, chunk in enumerate(group):
                msg += f'[{i}] {chunk}\n'
        return msg

    @staticmethod
    def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
        """Fraction of allocated chunk memory actually occupied by parameters of the given sizes."""
        assert len(params_numel) > 0
        total_size = 0
        total_utilized_size = 0
        cur_chunk_utilized_size = 0
        for size in params_numel:
            assert chunk_size >= size
            total_utilized_size += size
            if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
                # Open a new chunk when the current one cannot hold this parameter.
                total_size += chunk_size
                cur_chunk_utilized_size = 0
            cur_chunk_utilized_size += size
        return total_utilized_size / total_size

    @staticmethod
    def search_chunk_size(module: torch.nn.Module,
                          search_range: int,
                          n_grids: int,
                          min_chunk_size: Optional[int] = None) -> int:
        """Grid-search a chunk size in [min_chunk_size, min_chunk_size + search_range] maximizing utilization."""
        assert search_range % n_grids == 0
        # TODO(ver217): sort params and filter unused ones
        params_numel = [p.numel() for p in module.parameters()]
        max_param_numel = max(params_numel)
        if min_chunk_size is not None:
            assert min_chunk_size >= max_param_numel
        else:
            min_chunk_size = max_param_numel
        step_size = search_range // n_grids
        max_chunk_util = -1
        best_chunk_size = -1
        for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
            chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
            if chunk_util > max_chunk_util:
                max_chunk_util = chunk_util
                best_chunk_size = chunk_size
        return best_chunk_size

    def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
        """Copy each chunk of ``src_group_name`` into the corresponding chunk of ``dest_group_name``."""
        for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
            if not dest_chunk.is_free:
                dest_chunk.copy_(src_chunk)
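

# ---------------------------------------------------------------------------
# Usage sketch (editor-added, not part of the original module). It shows one
# plausible way to drive ChunkManager by hand: search for a chunk size,
# register a toy model's parameters, and walk them through the tensor state
# machine. It assumes the script is launched with torchrun and that
# `colossalai.launch_from_torch` sets up the DATA parallel context; the model
# and group name below are illustrative placeholders.
if __name__ == '__main__':
    import colossalai
    import torch.nn as nn

    colossalai.launch_from_torch(config={})
    model = nn.Linear(32, 32).to(get_current_device())

    # Pick a chunk size with high utilization for this model's parameters.
    chunk_size = ChunkManager.search_chunk_size(model, search_range=1024, n_grids=8)
    manager = ChunkManager(chunk_size, enable_distributed_storage=False)

    for param in model.parameters():
        manager.append_tensor(param, group_name='fp32_param')

    for param in model.parameters():
        manager.access_chunk(param)                             # make sure the chunk is materialized
        manager.trans_tensor_state(param, TensorState.COMPUTE)  # mark the tensor as in use
        # ... forward / backward would run here ...
        manager.trans_tensor_state(param, TensorState.HOLD)     # done computing with it

    print(manager)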