
# Captured from a web file viewer (549 lines, 21 KiB; last change 2022-08-09 08:39:48 +00:00).

import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
# project-local (ColossalAI) imports
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
class TensorState(Enum):
    """Lifecycle state of a tensor stored inside a chunk.

    ``FREE`` and ``HOLD`` keep the values present in the source. The other
    three members are referenced by ``STATE_TRANS`` and by the chunk's state
    monitor but were missing from the extracted text.
    """
    FREE = 0
    # NOTE(review): value inferred from the gap between FREE (0) and HOLD (2).
    COMPUTE = 1
    HOLD = 2
    # NOTE(review): values 3 and 4 are inferred; confirm against upstream.
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4
# Whitelist of (current_state, next_state) pairs. A requested transition that
# is not listed here is silently ignored by ``Chunk.tensor_trans_state``.
STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    # The literal was cut off after the pair above. The next pair follows from
    # the "compute -> hold_after_bwd -> ready_for_reduce" order described in
    # ``Chunk.tensor_trans_state``.
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    # NOTE(review): inferred, not visible in the extracted text; confirm upstream.
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
@dataclass
class TensorInfo:
    """Bookkeeping record for one tensor stored in a chunk.

    Instances are built positionally as ``TensorInfo(state, offset, end)``, so
    the class needs the generated dataclass ``__init__``. It is deliberately
    not frozen because ``state`` is reassigned on every state transition.
    """
    # current lifecycle state of the tensor
    state: TensorState
    # start index of the tensor's slice in the flat chunk buffer
    offset: int
    # end index (exclusive) of the tensor's slice in the flat chunk buffer
    end: int
class ChunkFullError(Exception):
    """Raised when appending a tensor would exceed the chunk's capacity."""
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True when the storage backing ``tensor`` holds zero elements.

    The left-hand side of the comparison was missing from the extracted text
    (``return == 0``).

    Args:
        tensor: the tensor whose underlying storage is inspected.
    """
    # NOTE(review): operand reconstructed from the function name and its two
    # callers (free_storage / alloc_storage); confirm against upstream.
    return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
    """Release the memory behind ``tensor`` while keeping the tensor object alive.

    Does nothing if the storage is already empty.

    Args:
        tensor: the tensor whose storage should be shrunk to zero elements.
    """
    if not is_storage_empty(tensor):
        # NOTE(review): the body of this `if` was missing from the extracted
        # text; resizing the storage to 0 is the inferred counterpart of
        # alloc_storage. Confirm against upstream.
        tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
    """Re-allocate storage for a tensor whose storage is currently empty.

    Does nothing if the tensor already has non-empty storage. The new memory
    is uninitialized.

    Args:
        tensor: the tensor whose storage should be grown back to ``numel()`` elements.
    """
    if is_storage_empty(tensor):
        # NOTE(review): the body of this `if` was missing from the extracted
        # text; sizing the storage to numel() is inferred. Confirm against upstream.
        tensor.storage().resize_(tensor.numel())
class Chunk:
    """A flat, contiguous buffer that packs several tensors and can be sharded over a process group.

    NOTE(review): this class was recovered from an extraction that lost all
    indentation, the ``@property`` decorators, the ``else:`` branches and a
    number of statements. Decorators, ``else:`` branches and every statement
    tagged ``# reconstructed`` were inferred from the surrounding code and
    must be confirmed against the upstream file before merging.
    """

    def __init__(self,
                 chunk_size: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 keep_gathered: bool = False,
                 pin_memory: bool = False) -> None:
        """
        Chunk: A container owning a piece of contiguous memory space for tensors.
        Here we use the all-gather operation to gather the whole chunk.
        Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
        It is designed to make full use of communication and PCIe bandwidth.

        Args:
            chunk_size (int): the number of elements in the chunk
            process_group (ColoProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device where the tensor is initialized.
                The default value is None, which is the current GPU
            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
        """
        self.chunk_size = chunk_size
        self.utilized_size = 0
        # Here, we use torch process group,
        # since ColoProcessGroup might get deprecated soon
        self.torch_pg = process_group.dp_process_group()
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)

        # the chunk size must be divisible by the number of ranks,
        # unless the chunk is never sharded
        if not keep_gathered:
            assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size
        self.valid_end = self.shard_size

        self.dtype = dtype
        device = init_device or get_current_device()
        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
        self.chunk_total = None    # we force chunk_total located in CUDA
        self.cuda_shard = None    # using two attributes for the better interpretation
        self.cpu_shard = None
        self.is_gathered = True

        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of all tensors
        self.num_tensors = 0
        # monitor the states of all tensors
        self.tensors_state_monitor: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensors_state_monitor[state] = 0

        # some chunks can keep gathered all the time
        # so their computation patterns are the same as that of the parameters in DDP
        self.keep_gathered = keep_gathered
        if self.keep_gathered:
            pin_memory = False    # since this chunk is gathered, it doesn't need to pin

        # if pin_memory is True, we allocate a piece of CPU pin-memory
        # for it all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here
        # it refers to another chunk having the same parameters
        # but with different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during the training step, the flag is True
        self.cpu_vis_flag = False

    # ``__repr__`` reads ``self.memory_usage[...]`` without calling it, so this must be a property.
    @property
    def memory_usage(self) -> Dict[str, int]:
        """Return the bytes this chunk currently accounts for, as ``{'cuda': ..., 'cpu': ...}``."""
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    # ``optim_update`` compares ``chunk.device_type == 'cuda'`` without calling it, so this must be a property.
    @property
    def device_type(self) -> str:
        """Return ``'cuda'`` or ``'cpu'`` depending on where the chunk's data currently lives."""
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        if self.is_gathered or self.cuda_shard is not None:
            return 'cuda'
        return 'cpu'

    @property    # reconstructed
    def payload(self) -> torch.Tensor:
        """Return the tensor that currently holds this chunk's data (closed chunks only)."""
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.chunk_total
        elif self.cuda_shard is not None:
            return self.cuda_shard
        else:
            return self.cpu_shard

    @property    # reconstructed
    def payload_mem(self) -> int:
        """Return the size in bytes of the current payload (closed chunks only)."""
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.chunk_mem
        else:
            return self.shard_mem

    @property    # reconstructed
    def can_move(self) -> bool:
        """Return True when the chunk is sharded, i.e. its shard may be moved between devices."""
        return not self.is_gathered

    @property    # reconstructed
    def can_release(self) -> bool:
        """Return True when every tensor is in HOLD or HOLD_AFTER_BWD and the chunk is not kept gathered."""
        if self.keep_gathered:
            return False
        else:
            return self.tensors_state_monitor[TensorState.HOLD] + \
                   self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors

    @property    # reconstructed
    def can_reduce(self):
        """Return True when every tensor in the chunk is READY_FOR_REDUCE."""
        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

    @property    # reconstructed
    def has_inf_or_nan(self) -> bool:
        """Check if the chunk has inf or nan values in CUDA.
        """
        if self.is_gathered:
            valid_tensor = self.chunk_total[:self.utilized_size]
        else:
            assert self.cuda_shard is not None    # only check in CUDA
            valid_tensor = self.cuda_shard[:self.valid_end]

        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space.
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())    # reconstructed
        # the exact-type check is deliberate: the buffer must be a plain torch.Tensor, not a subclass
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        # re-point the tensor at its slice of the chunk buffer ("tensor.data" target reconstructed)
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensors_state_monitor[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self, shard_dev: Optional[torch.device] = None):
        """Close the chunk. Any tensor can't be appended to a closed chunk later.

        Args:
            shard_dev: the device where the shard locates
        """
        # sanity check
        assert self.chunk_temp is not None

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        if self.chunk_temp.device.type == 'cpu':
            self.chunk_total = self.chunk_temp.to(get_current_device())    # reconstructed right-hand side
            self.__update_tensors_ptr()    # reconstructed
        else:
            self.chunk_total = self.chunk_temp
        self.chunk_temp = None

        # the copy below needs cuda_shard to exist, so the chunk must be scattered first
        self.__scatter()    # reconstructed

        if self.keep_gathered:
            if shard_dev is None:
                shard_dev = get_current_device()
            else:
                assert shard_dev.type == 'cuda'
        elif shard_dev is None:
            shard_dev = torch.device('cpu')

        if self.pin_memory or shard_dev.type == 'cpu':
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)    # reconstructed
            self.cpu_vis_flag = True    # cpu_shard has been visited

        if shard_dev.type == 'cpu':
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        """Move the shard tensor in the chunk.

        Args:
            device: the device to which the shard will move
            force_copy: if True, copy function is called mandatorily

        Raises:
            NotImplementedError: if ``device`` is neither a CUDA nor a CPU device.
        """
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer
        # just use another way for the movement
        if not self.optim_sync_flag:
            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
            self.__paired_shard_move()    # reconstructed
            self.optim_sync_flag = True
            return    # reconstructed

        if device.type == 'cuda':
            assert device == get_current_device(), "can't move chunk to another device"

            # Fixed: the original tested `if self.cuda_shard:`, i.e. the tensor's truth value.
            # That raises for tensors with more than one element and depends on the stored
            # value for one-element tensors. The intent is "already on CUDA".
            if self.cuda_shard is not None:
                return    # reconstructed

            self.cuda_shard = self.cpu_shard.to(get_current_device())    # reconstructed right-hand side

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == 'cpu':
            if self.cuda_shard is None:
                return    # reconstructed

            if self.pin_memory:
                # if cpu_shard has been visited
                # copy operation is not needed
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)    # reconstructed
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()    # reconstructed
            self.__update_tensors_ptr()    # reconstructed

    def release_chunk(self):
        """Release the usable chunk. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()    # reconstructed

    def reduce(self):
        """Reduce scatter all the gradients. It's an operation done in CUDA.
        """
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here
            # just move chunk_total to cuda_shard
            # the communication is not necessary
            self.__scatter()    # reconstructed
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.chunk_total, group=self.torch_pg)
        else:
            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

            free_storage(self.chunk_total)    # reconstructed
            self.is_gathered = False
        # NOTE(review): inferred; no state reset is visible in the extracted text
        self.__update_tensors_state(TensorState.HOLD)    # reconstructed

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only applies valid state transformations
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return    # reconstructed
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())    # reconstructed
        # "tensor.data" assignment target reconstructed
        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def get_valid_length(self) -> int:
        """Get the valid length of the chunk's payload.
        """
        if self.keep_gathered:
            return self.utilized_size
        else:
            return self.valid_end

    def init_pair(self, friend_chunk: 'Chunk') -> None:
        """Initialize the paired chunk.

        If neither chunk is paired yet, link them to each other. Otherwise
        assert that they are already paired with one another.
        """
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
        """
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk = self.paired_chunk
        if self.is_gathered is True:
            assert friend_chunk.is_gathered is True
            self.chunk_total.copy_(friend_chunk.chunk_total)    # reconstructed
            self.optim_sync_flag = True
        elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
            self.cuda_shard.copy_(friend_chunk.cuda_shard)    # reconstructed
            self.optim_sync_flag = True
            self.cpu_vis_flag = False
        else:
            # optim_sync_flag is set to False
            # see shard_move function for more details
            assert friend_chunk.device_type == 'cpu'
            assert self.device_type == 'cpu'
            self.optim_sync_flag = False
            self.cpu_vis_flag = False

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the tensors registered in this chunk, in insertion order."""
        return list(self.tensors_info.keys())

    def __gather(self):
        """All-gather every rank's shard into ``chunk_total`` (no-op if already gathered)."""
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            alloc_storage(self.chunk_total)    # reconstructed
            gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        """Keep only this rank's shard of ``chunk_total`` (no-op for keep-gathered chunks)."""
        if self.keep_gathered:
            return    # reconstructed

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)

            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])    # reconstructed

            free_storage(self.chunk_total)    # reconstructed
            self.is_gathered = False

    def __paired_shard_move(self):
        """Rebuild this chunk's CUDA shard from the paired (optimizer) chunk's CPU shard."""
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only be called when optimizer state is in CPU memory
        # the grad and param should be in the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_current_device())    # reconstructed right-hand side
        # avoid converting FP32 on the CPU
        self.cuda_shard = temp.to(self.dtype)    # reconstructed right-hand side

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        """Re-point every registered tensor at its slice of ``chunk_total``."""
        # sanity check
        assert self.is_gathered
        # the exact-type check is deliberate: the buffer must be a plain torch.Tensor, not a subclass
        assert type(self.chunk_total) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            # "tensor.data" assignment target reconstructed
            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        """Move one tensor to ``next_state`` and keep the per-state counters in sync."""
        self.tensors_state_monitor[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensors_state_monitor[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        """Move tensors to ``next_state``; if ``prev_state`` is given, only those currently in it."""
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        # identity-based hashing, consistent with __eq__ below
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        # chunks compare by identity, never by content
        return self is __o

    def __repr__(self, detailed: bool = True):
        """Return a multi-line, human-readable description of the chunk.

        Args:
            detailed: if True, also list the number of tensors in each state.
        """
        output = [
            "Chunk Information:\n",
            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
                                                                                 self.pg_size),    # reconstructed arg
            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
        ]

        def print_tensor(tensor, prefix=''):
            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
                                                                        tensor.device))    # reconstructed arg

        if self.chunk_temp is not None:
            output.append("\tchunk temp:\n")
            print_tensor(tensor=self.chunk_temp, prefix='\t\t')

        # left operand of "> 0" reconstructed: skip chunk_total once its storage has been freed
        if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
            output.append("\tchunk total:\n")
            print_tensor(tensor=self.chunk_total, prefix='\t\t')

        if self.cuda_shard is not None:
            output.append("\tcuda shard:\n")
            print_tensor(tensor=self.cuda_shard, prefix='\t\t')

        if self.cpu_shard is not None:
            output.append("\tcpu shard:\n")
            print_tensor(tensor=self.cpu_shard, prefix='\t\t')

        memory_info = self.memory_usage
        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))

        if detailed:
            output.append("\ttensor state monitor:\n")
            for st in TensorState:
                output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))

        return ''.join(output)