# ColossalAI/colossalai/elixir/chunk/core/chunk.py


from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.tensor import FakeTensor

from .memory_pool import MemoryPool, TensorBlock
from .states import TensorState, validate_tensor_state_update


class ChunkFullError(Exception):
    pass


@dataclass
class TensorInfo:
    state: TensorState
    fake_data: FakeTensor
    offset: int
    end: int


class Chunk:
    """Chunk is a data structure used to store tensors.
    It allows us to store a sequence of tensors in one contiguous memory block.
    Moreover, Chunk manages the storage of tensors in a distributed way.
    Normally, a chunk is scattered across its process group.
    When a tensor in this chunk is needed later, the chunk can be gathered by `access_chunk`.
    When training is done, the chunk can be scattered again by `reduce_chunk`.

    Args:
        rcache: the memory pool that stores replicated chunks
        chunk_size: the number of elements in the chunk
        chunk_dtype: the dtype of the chunk
        process_group: the torch communication group of the chunk
        temp_device: the device that stores the temporary chunk during initialization
        shard_device: the device that stores the shard of the scattered chunk
        rcache_fused: whether this chunk is fused in the rcache without eviction
        cpu_pin_memory: whether this chunk uses CPU pinned memory for its shard
    """
    total_count = 0

    def __init__(
        self,
        rcache: MemoryPool,
        chunk_size: int,
        chunk_dtype: torch.dtype,
        process_group: ProcessGroup,
        temp_device: Optional[torch.device] = None,
        shard_device: Optional[torch.device] = None,
        rcache_fused: bool = False,    # whether this chunk is used in ZeRO2
        cpu_pin_memory: bool = False    # whether this chunk has a permanent copy in cpu
    ) -> None:
        self.chunk_id: int = Chunk.total_count
        Chunk.total_count += 1
        # set the replicated cache pool
        self.rcache: MemoryPool = rcache

        self.chunk_size: int = chunk_size
        self.chunk_dtype: torch.dtype = chunk_dtype
        self.utilized_size: int = 0

        self.torch_pg: ProcessGroup = process_group
        self.pg_size: int = dist.get_world_size(self.torch_pg)
        self.pg_rank: int = dist.get_rank(self.torch_pg)

        # the chunk size should be divisible by the dp degree
        assert chunk_size % self.pg_size == 0

        self.shard_size: int = chunk_size // self.pg_size
        self.shard_begin: int = self.shard_size * self.pg_rank
        self.shard_end: int = self.shard_begin + self.shard_size
        self.valid_end: int = self.shard_size + 1    # set to an illegal number

        # notice: release blocks reserved by PyTorch
        torch.cuda.empty_cache()
        # rcache block, the globally replicated chunk in the rcache
        self.rcb: Optional[TensorBlock] = None
        self.rcache_fused: bool = rcache_fused
        self._my_block = None
        self.is_replica: bool = True
        # allocate a private block for fused chunks
        if self.rcache_fused:
            self._my_block = rcache.get_private_block(chunk_size, chunk_dtype)

        temp_device: torch.device = temp_device or gpu_device()
        # chunk_temp is a global chunk, which only exists while the chunks are being built.
        # all of its elements are kept at zero
        self.chunk_temp: Optional[torch.Tensor] = None
        if rcache_fused:
            self.chunk_temp = self._my_block.payload
            torch.zero_(self.chunk_temp)
        else:
            self.chunk_temp = torch.zeros(chunk_size, dtype=chunk_dtype, device=temp_device)

        # configure the init device of the shard
        # no-offload default: fp16, fp32 -> CUDA
        # offload default: fp16, fp32 -> CPU
        shard_device: torch.device = shard_device or torch.device('cpu')
        pin_flag: bool = cpu_pin_memory and shard_device.type == 'cpu'
        # chunk.shard is a local chunk
        # it is designed to exist permanently
        self.shard: torch.Tensor = torch.empty(self.shard_size,
                                               dtype=chunk_dtype,
                                               device=shard_device,
                                               pin_memory=pin_flag)

        # calculate the memory occupation of the chunk and the shard
        self.chunk_memo: int = self.chunk_size * self.chunk_temp.element_size()
        self.shard_memo: int = self.chunk_memo // self.pg_size

        # each tensor is associated with a TensorInfo to track its meta info
        # (state, shape, offset, end)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of tensors in the chunk
        self.num_tensors: int = 0

        # record the number of tensors in different states
        self.tensor_state_cnter: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensor_state_cnter[state] = 0

        # we introduce the paired chunk here
        # it refers to another chunk that holds the same parameters
        # but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # whether to record the l2 norm for the gradient clipping calculation
        self.l2_norm_flag = False
        self.l2_norm = None
        # whether the chunk overflows after the reduction
        self.overflow = False

    @property
    def prepared_block(self):
        return self._my_block

    @property
    def is_init(self):
        return self.chunk_temp is not None

    @property
    def in_rcache(self):
        return self.rcb is not None

    @property
    def shard_device(self):
        return self.shard.device

    @property
    def memory_usage(self) -> Dict[str, int]:
        cuda_memory = 0
        cpu_memory = 0

        # this chunk is not closed
        if self.is_init:
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_memo
            else:
                cpu_memory += self.chunk_memo

        # this chunk is on the rcache
        if self.in_rcache:
            cuda_memory += self.rcb.memo_occ

        # calculate the occupation of the chunk shard
        if self.shard_device.type == 'cuda':
            cuda_memory += self.shard_memo
        elif self.shard_device.type == 'cpu':
            cpu_memory += self.shard_memo
        else:
            raise NotImplementedError

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def payload(self) -> torch.Tensor:
        if self.is_init:
            return self.chunk_temp
        if self.in_rcache:
            return self.rcb.payload
        else:
            return self.shard

    @property
    def shard_move_check(self) -> bool:
        return not self.in_rcache

    def _not_compute_number(self):
        total = 0
        state_list = [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE]
        for state in state_list:
            total += self.tensor_state_cnter[state]
        return total

    @property
    def scatter_check(self) -> bool:
        if self.rcache_fused:
            return False
        return self._not_compute_number() == self.num_tensors

    @property
    def reduce_check(self):
        return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors

    def enable_l2_norm_flag(self) -> None:
        self.l2_norm_flag = True

    def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None:
        assert not self.overflow
        self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

    def set_l2_norm(self, valid_tensor: torch.Tensor) -> None:
        assert self.l2_norm is None, 'you are calculating the l2 norm twice'
        chunk_l2_norm = valid_tensor.data.float().norm(2)
        self.l2_norm = chunk_l2_norm.item()**2

    def append_tensor(self, tensor: torch.Tensor):
        # sanity check
        assert self.is_init
        assert tensor.dtype == self.chunk_dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise an exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
        fake_data = FakeTensor(tensor.data)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensor_state_cnter[tensor_state] += 1

        self.tensors_info[tensor] = TensorInfo(state=tensor_state,
                                               fake_data=fake_data,
                                               offset=self.utilized_size,
                                               end=new_utilized_size)

        self.utilized_size = new_utilized_size

    def close_chunk(self):
        # sanity check
        assert self.is_init

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        self.__remove_tensors_ptr()
        self.__update_shard(self.chunk_temp, self.shard)
        self.is_replica = False
        self.chunk_temp = None

    def replicate(self):
        assert not self.is_replica

        self.is_replica = True
        this_shard = self.shard if self.optim_sync_flag else self.__paired_shard()
        self.__update_replica(self.rcb.payload, this_shard)
        self.__update_tensors_ptr()

    def scatter(self):
        assert not self.rcache_fused
        assert self.is_replica

        self.__remove_tensors_ptr()
        if not self.optim_sync_flag:
            self.__update_shard(self.rcb.payload, self.shard)
            self.optim_sync_flag = True
        self.is_replica = False

    def reduce(self, always_fp32: bool = False):
        assert self.is_replica

        self.__remove_tensors_ptr()
        if self.pg_size > 1:
            cast_to_fp32 = False
            if always_fp32 and self.chunk_dtype != torch.float:
                cast_to_fp32 = True
                # cast the payload to fp32
                reduce_buffer = self.rcb.payload.to(dtype=torch.float)
            else:
                # otherwise, use the payload directly
                reduce_buffer = self.rcb.payload

            # divide the reduce buffer by the size of the process group
            reduce_buffer /= self.pg_size
            # try to use an in-place reduce-scatter
            # notice: PyTorch does not allow a true in-place reduce-scatter,
            # because it allocates a contiguous memory region for collective communications
            shard_buffer = reduce_buffer[self.shard_begin:self.shard_end]
            dist.reduce_scatter_tensor(shard_buffer, reduce_buffer, group=self.torch_pg)

            # the result should be moved back to the payload for the norm calculation
            if cast_to_fp32:
                calc_buffer = self.rcb.payload[self.shard_begin:self.shard_end]
                calc_buffer.copy_(shard_buffer)
        else:
            # if the process group size equals 1, do not communicate
            reduce_buffer = self.rcb.payload

        self.__update_shard(reduce_buffer, self.shard)
        self.is_replica = False
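    # Worked example of the reduce-scatter above (illustrative; assume pg_size == 2
    # and chunk_size == 4, so shard_size == 2):
    #     rank 0 payload: [a0, a1, a2, a3]    rank 1 payload: [b0, b1, b2, b3]
    # After dividing by pg_size and calling reduce_scatter_tensor, each rank's shard
    # slice holds the average of the corresponding elements:
    #     rank 0 shard: [(a0 + b0) / 2, (a1 + b1) / 2]
    #     rank 1 shard: [(a2 + b2) / 2, (a3 + b3) / 2]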

    def access_chunk(self, block: Optional[TensorBlock] = None):
        # sanity check
        assert not self.is_init
        assert not self.is_replica

        if self.rcache_fused:
            assert block is None
            self.rcb = self._my_block
        else:
            assert block in self.rcache.public_used_blocks
            assert self.rcb is None
            self.rcb = block

        self.replicate()

    def release_chunk(self) -> TensorBlock:
        # sanity check
        assert not self.is_init
        assert self.is_replica

        if self.rcache_fused:
            raise RuntimeError

        self.scatter()
        block = self.rcb
        self.rcb = None
        return block

    def update_extra_reduce_info(self, block: Optional[TensorBlock]):
        if self.rcache_fused:
            assert block is None
            block = self._my_block
        else:
            assert block is not None

        buffer = block.payload[self.shard_begin:self.shard_end]
        valid_tensor = buffer[:self.valid_end]
        self.set_overflow_flag(valid_tensor)
        if self.l2_norm_flag:
            self.set_l2_norm(valid_tensor)

    def reduce_chunk(self, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
        """Reduce scatter all the gradients. It's an operation done in CUDA.
        """
        # sanity check
        assert not self.is_init
        assert self.is_replica

        self.reduce(always_fp32=always_fp32)
        self.__update_tensors_state(TensorState.HOLD)

        # reset the rcb pointer
        block = self.rcb
        self.rcb = None
        if self.rcache_fused:
            block = None

        if sync:
            self.update_extra_reduce_info(block)

        return block

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        prev_state = self.tensors_info[tensor].state
        if prev_state == tensor_state:
            return
        # validate whether the update is legal
        # if illegal, raise an exception
        is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True)
        if is_update_valid:
            self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        # sanity check
        assert self.is_replica

        info = self.tensors_info[tensor]
        payload = self.rcb.payload

        payload[info.offset:info.end].copy_(data_slice.data.flatten())
        tensor.data = payload[info.offset:info.end].view(tensor.shape)

    def init_pair(self, friend_chunk: 'Chunk') -> None:
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
        """
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk: Chunk = self.paired_chunk
        assert not friend_chunk.is_replica
        # gradient and optimizer should be on the same device
        assert self.shard_device.type == friend_chunk.shard_device.type

        if self.shard_device.type == 'cuda':
            self.shard.copy_(friend_chunk.shard)
            self.optim_sync_flag = True
        elif self.shard_device.type == 'cpu':
            # optim_sync_flag is set to False
            # see the shard_move function for more details
            self.optim_sync_flag = False
        else:
            raise NotImplementedError

    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    def get_cpu_copy(self, only_rank_0: bool = False) -> List[torch.Tensor]:
        assert not self.is_init

        if self.is_replica:
            # use the payload directly when this chunk is a replica
            temp_buffer = self.rcb.payload
        else:
            # otherwise, create a temporary buffer
            temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
            # cheat the assertion in __update_replica
            self.is_replica = True
            self.__update_replica(temp_buffer, self.shard)
            self.is_replica = False

        cpu_copys = [None] * self.num_tensors
        if not only_rank_0 or self.pg_rank == 0:
            for i, (t, info) in enumerate(self.tensors_info.items()):
                t_copy = temp_buffer[info.offset:info.end].view(t.shape).cpu()
                cpu_copys[i] = t_copy
        # synchronize
        dist.barrier()
        return cpu_copys

    def load_tensors(self, tensor_list: List[Optional[torch.Tensor]], only_rank_0: bool = False) -> bool:
        assert not self.is_replica
        assert not self.is_init

        temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
        # cheat the assertion in __update_replica
        self.is_replica = True
        self.__update_replica(temp_buffer, self.shard)
        self.is_replica = False

        if not only_rank_0 or self.pg_rank == 0:
            for (_, c_info), load_tensor in zip(self.tensors_info.items(), tensor_list):
                if load_tensor is None:
                    continue
                temp_buffer[c_info.offset:c_info.end].copy_(load_tensor.data.flatten())
        # synchronize
        dist.barrier()
        if only_rank_0:
            dist.broadcast(temp_buffer, src=0, group=self.torch_pg)

        # cheat the assertion in __update_shard
        self.is_replica = True
        self.__update_shard(temp_buffer, self.shard)
        self.is_replica = False

    def __update_replica(self, replica: torch.Tensor, shard: torch.Tensor):
        assert self.is_replica
        assert replica.numel() == self.chunk_size
        assert shard.numel() == self.shard_size

        buffer = replica[self.shard_begin:self.shard_end]
        buffer.copy_(shard)
        dist.all_gather_into_tensor(replica, buffer, group=self.torch_pg)

    def __update_shard(self, replica: torch.Tensor, shard: torch.Tensor):
        assert self.is_replica
        assert replica.numel() == self.chunk_size
        assert shard.numel() == self.shard_size

        shard.copy_(replica[self.shard_begin:self.shard_end])

    def __paired_shard(self):
        assert self.paired_chunk is not None, 'chunks should be paired before training'
        optim_chunk: Chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only called when the optimizer state is in CPU memory
        # the grad and param should be on the same device
        assert self.shard_device.type == 'cpu'
        return optim_chunk.shard.to(gpu_device())

    def __remove_tensors_ptr(self) -> None:
        # sanity check
        # each tensor should point to its fake data before scatter
        assert self.is_replica
        for tensor, info in self.tensors_info.items():
            tensor.data = info.fake_data

    def __update_tensors_ptr(self) -> None:
        # sanity check
        # the chunk should be replicated to get the correct pointer
        assert self.is_replica
        payload = self.rcb.payload
        for tensor, info in self.tensors_info.items():
            tensor.data = payload[info.offset:info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        self.tensor_state_cnter[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensor_state_cnter[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        return self.chunk_id

    def __lt__(self, other: object) -> bool:
        return self.chunk_id < other.chunk_id

    def __eq__(self, other: object) -> bool:
        return self.chunk_id == other.chunk_id

    def __repr__(self, detailed: bool = True):
        if self.is_init:
            state = 'initialization'
        elif self.in_rcache:
            state = 'replicated'
        else:
            state = 'scattered'

        output = [
            f'Chunk {self.chunk_id} details: state -> {state}\n',
            f' length: {self.chunk_size}, dtype: {self.chunk_dtype}, group_size: {self.pg_size}, tensors: {self.num_tensors}\n'
            f' utilized size: {self.utilized_size}, utilized percentage: {100 * (self.utilized_size / self.chunk_size):.0f}%\n'
        ]

        memory_info = self.memory_usage
        output.append(' memory usage: (cuda -> {}, cpu -> {})\n'.format(memory_info['cuda'], memory_info['cpu']))

        def print_tensor(name, tensor, prefix=''):
            output.append(f'{prefix}{name}: (shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device})\n')

        if self.is_init:
            print_tensor(name='temp', tensor=self.chunk_temp, prefix=' ')
        if self.in_rcache:
            print_tensor(name='block', tensor=self.rcb.payload, prefix=' ')
        if self.shard is not None:
            print_tensor(name='shard', tensor=self.shard, prefix=' ')

        if detailed:
            output.append(' tensor state monitor:\n')
            for st in TensorState:
                output.append(' # of {}: {}\n'.format(st, self.tensor_state_cnter[st]))

        return ''.join(output)
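

# Checkpointing sketch (illustrative; `chunk` is a closed, scattered Chunk and `saved`
# is whatever the caller stored earlier; these names are assumptions, not part of
# this module):
#     saved = chunk.get_cpu_copy(only_rank_0=True)     # CPU copies on rank 0, None elsewhere
#     ...
#     chunk.load_tensors(saved, only_rank_0=True)      # rank 0 loads, then the buffer is broadcast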