mirror of https://github.com/hpcaitech/ColossalAI
547 lines
21 KiB
Python
547 lines
21 KiB
Python
import torch
|
|
import torch.distributed as dist
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Optional, Dict, List
|
|
|
|
from colossalai.utils import get_current_device
|
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
|
|
|
|
|
class TensorState(Enum):
    """Lifecycle states tracked for every tensor stored in a chunk."""

    # the tensor's memory is not in use
    FREE = 0
    # the tensor is currently used by a computation
    COMPUTE = 1
    # the tensor is held but not being computed on
    HOLD = 2
    # the tensor is held after its backward computation
    HOLD_AFTER_BWD = 3
    # the tensor is ready to take part in the reduce step
    READY_FOR_REDUCE = 4
|
|
|
|
|
|
# All legal (current_state, next_state) transitions for a tensor in a chunk.
# Any pair that is not listed here is treated as an invalid transition.
STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
|
|
|
|
|
|
@dataclass
class TensorInfo:
    """Meta information recorded for one tensor stored inside a chunk."""

    # current lifecycle state of the tensor
    state: TensorState
    # index of the tensor's first element inside the chunk
    offset: int
    # index one past the tensor's last element inside the chunk
    end: int
|
|
|
|
|
|
class ChunkFullError(Exception):
    """Raised when a tensor does not fit into the remaining space of a chunk."""
|
|
|
|
|
|
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True when the storage backing ``tensor`` holds zero elements."""
    storage_len = tensor.storage().size()
    return storage_len == 0
|
|
|
|
|
|
def free_storage(tensor: torch.Tensor) -> None:
    """Shrink the storage backing ``tensor`` to zero elements.

    Nothing happens when the storage is already empty.
    """
    if tensor.storage().size() == 0:
        return
    tensor.storage().resize_(0)
|
|
|
|
|
|
def alloc_storage(tensor: torch.Tensor) -> None:
    """Grow an empty storage so that it can hold ``tensor.numel()`` elements.

    Nothing happens when the storage is not empty.
    """
    if tensor.storage().size() != 0:
        return
    tensor.storage().resize_(tensor.numel())
|
|
|
|
|
|
class Chunk:
    """A container owning a piece of contiguous memory space for tensors.

    Before ``close_chunk`` is called, tensors are appended into a temporary
    buffer (``chunk_temp``). After closing, the data lives either in the
    gathered buffer (``chunk_total``) or in a per-rank shard (``cuda_shard``
    and/or ``cpu_shard``).
    """

    def __init__(self,
                 chunk_size: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 keep_gathered: bool = False,
                 pin_memory: bool = False) -> None:
        """
        Chunk: A container owning a piece of contiguous memory space for tensors
        Here we use all-gather operation to gather the whole chunk.
        Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
        It is designed to make the full use of communication and PCIE bandwidth.

        Args:
            chunk_size (int): the number of elements in the chunk
            process_group (ColoProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device where the tensor is initialized
                The default value is None, which is the current GPU
            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
        """
        self.chunk_size = chunk_size
        self.utilized_size = 0
        # Here, we use torch process group,
        # since ColoProcessGroup might get deprecated soon
        self.torch_pg = process_group.dp_process_group()
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)

        # the chunk size must be divisible by the process group size
        # unless the chunk is never scattered
        if not keep_gathered:
            assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size
        # number of valid elements in this rank's shard; refined in close_chunk
        self.valid_end = self.shard_size

        self.dtype = dtype
        device = init_device or get_current_device()
        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
        self.chunk_total = None    # we force chunk_total located in CUDA
        self.cuda_shard = None    # using two attributes for the better interpretation
        self.cpu_shard = None
        self.is_gathered = True

        # memory footprint in bytes of the whole chunk and of one shard
        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of all tensors
        self.num_tensors = 0
        # monitor the states of all tensors
        self.tensors_state_monitor: Dict[TensorState, int] = {state: 0 for state in TensorState}

        # some chunks can keep gathered all the time
        # so their computation patterns are the same as that of the parameters in DDP
        self.keep_gathered = keep_gathered
        if self.keep_gathered:
            pin_memory = False    # since this chunk is gathered, it doesn't need to pin

        # if pin_memory is True, we allocate a piece of CPU pin-memory
        # for it all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here
        # it refers to another chunk having the same parameters
        # but with different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during the training step, the flag is True
        self.cpu_vis_flag = False

    @property
    def memory_usage(self) -> Dict[str, int]:
        """Return the bytes accounted to this chunk as ``{'cuda': ..., 'cpu': ...}``."""
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed: only the temporary buffer counts
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def device_type(self) -> str:
        """Return ``'cuda'`` or ``'cpu'`` depending on where the chunk data currently lives."""
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        if self.is_gathered or self.cuda_shard is not None:
            return 'cuda'
        return 'cpu'

    @property
    def payload(self) -> torch.Tensor:
        """Return the tensor currently holding the data of this closed chunk."""
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.chunk_total
        if self.cuda_shard is not None:
            return self.cuda_shard
        return self.cpu_shard

    @property
    def payload_mem(self) -> int:
        """Return the size in bytes of the current payload of this closed chunk."""
        # sanity check
        assert self.chunk_temp is None

        return self.chunk_mem if self.is_gathered else self.shard_mem

    @property
    def can_move(self) -> bool:
        """Return True when the chunk is not gathered."""
        return not self.is_gathered

    @property
    def can_release(self) -> bool:
        """Return True when every tensor is in HOLD or HOLD_AFTER_BWD and the chunk is not kept gathered."""
        if self.keep_gathered:
            return False
        held = self.tensors_state_monitor[TensorState.HOLD] + \
            self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD]
        return held == self.num_tensors

    @property
    def can_reduce(self) -> bool:
        """Return True when every tensor in the chunk is READY_FOR_REDUCE."""
        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

    @property
    def has_inf_or_nan(self) -> bool:
        """Check if the chunk has inf or nan values in CUDA.
        """
        if self.is_gathered:
            valid_tensor = self.chunk_total[:self.utilized_size]
        else:
            assert self.cuda_shard is not None    # only check in CUDA
            valid_tensor = self.cuda_shard[:self.valid_end]

        # short-circuit: the nan check is skipped when an inf was already found
        return torch.isinf(valid_tensor).any().item() or torch.isnan(valid_tensor).any().item()

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit in the remaining space
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        # make the tensor a view into the chunk's buffer
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensors_state_monitor[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self, shard_dev: Optional[torch.device] = None):
        """Close the chunk. Any tensor can't be appended to a closed chunk later.

        Args:
            shard_dev: the device where the shard locates
        """
        # sanity check
        assert self.chunk_temp is not None

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        if self.chunk_temp.device.type == 'cpu':
            # the data moved to a new buffer, so tensors must be re-pointed
            self.chunk_total = self.chunk_temp.to(get_current_device())
            self.__update_tensors_ptr()
        else:
            self.chunk_total = self.chunk_temp
        self.chunk_temp = None

        self.__scatter()

        if self.keep_gathered:
            if shard_dev is None:
                shard_dev = get_current_device()
            else:
                assert shard_dev.type == 'cuda'
        elif shard_dev is None:
            shard_dev = torch.device('cpu')

        if self.pin_memory or shard_dev.type == 'cpu':
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)
            self.cpu_vis_flag = True    # cpu_shard has been visited

        if shard_dev.type == 'cpu':
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        """Move the shard tensor in the chunk.

        Args:
            device: the device to which the shard will move
            force_copy: if True, copy function is called mandatorily

        Raises:
            NotImplementedError: if ``device`` is neither a CUDA nor a CPU device
        """
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer
        # just use another way for the movement
        if not self.optim_sync_flag:
            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
            self.__paired_shard_move()
            self.optim_sync_flag = True
            return

        if device.type == 'cuda':
            assert device == get_current_device(), "can't move chunk to another device"

            # Test for presence, not truthiness: bool() of a tensor with more
            # than one element raises a RuntimeError.
            if self.cuda_shard is not None:
                return

            self.cuda_shard = self.cpu_shard.to(get_current_device())

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == 'cpu':
            if self.cuda_shard is None:
                return

            if self.pin_memory:
                # if cpu_shard has been visited
                # copy operation is not needed
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()
            self.__update_tensors_ptr()

    def release_chunk(self):
        """Release the usable chunk. It's an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()

    def reduce(self):
        """Reduce scatter all the gradients. It's an operation done in CUDA.
        """
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here
            # just move chunk_total to cuda_shard
            # the communication is not necessary
            self.__scatter()
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.chunk_total, group=self.torch_pg)
        else:
            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

            free_storage(self.chunk_total)
            self.is_gathered = False
        self.__update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only applies valid state transformation
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def get_valid_length(self) -> int:
        """Get the valid length of the chunk's payload.
        """
        if self.keep_gathered:
            return self.utilized_size
        return self.valid_end

    def init_pair(self, friend_chunk: 'Chunk') -> None:
        """Initialize the paired chunk.

        If neither chunk is paired yet, they are linked to each other;
        otherwise the existing pairing must already be this exact pair.
        """
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
        """
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk = self.paired_chunk
        if self.is_gathered is True:
            assert friend_chunk.is_gathered is True
            self.chunk_total.copy_(friend_chunk.chunk_total)
            self.optim_sync_flag = True
        elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
            self.cuda_shard.copy_(friend_chunk.cuda_shard)
            self.optim_sync_flag = True
            self.cpu_vis_flag = False
        else:
            # both shards are on CPU: no copy happens here, the chunk is only
            # flagged as out of sync (shard_move handles that flag later)
            assert friend_chunk.device_type == 'cpu'
            assert self.device_type == 'cpu'
            self.optim_sync_flag = False
            self.cpu_vis_flag = False

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the list of tensors stored in this chunk."""
        return list(self.tensors_info.keys())

    def __gather(self):
        """All-gather the shards of every rank into ``chunk_total``."""
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            alloc_storage(self.chunk_total)
            gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        """Keep only this rank's shard and free the gathered buffer."""
        if self.keep_gathered:
            return

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)

            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])

            free_storage(self.chunk_total)
            self.is_gathered = False

    def __paired_shard_move(self):
        """Build this chunk's CUDA shard from the CPU shard of its paired chunk."""
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only be called when optimizer state is in CPU memory
        # the grad and param should be in the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_current_device())
        # avoid to transform FP32 in CPU
        self.cuda_shard = temp.to(self.dtype)

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        """Re-point every tensor's data at its slice of ``chunk_total``."""
        # sanity check
        assert self.is_gathered
        assert type(self.chunk_total) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        """Set the state of one tensor and keep the per-state counters consistent."""
        self.tensors_state_monitor[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensors_state_monitor[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        """Move tensors to ``next_state``; if ``prev_state`` is given, only those currently in it."""
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        # chunks are hashed by identity
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        # chunks are compared by identity
        return self is __o

    def __repr__(self, detailed: bool = True):
        """Return a multi-line description of the chunk.

        Args:
            detailed: if True, the per-state tensor counters are appended
        """
        output = [
            "Chunk Information:\n",
            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
                                                                                 self.pg_size),
            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
        ]

        def print_tensor(tensor, prefix=''):
            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
                                                                        tensor.device))

        if self.chunk_temp is not None:
            output.append("\tchunk temp:\n")
            print_tensor(tensor=self.chunk_temp, prefix='\t\t')

        if self.chunk_total is not None and not is_storage_empty(self.chunk_total):
            output.append("\tchunk total:\n")
            print_tensor(tensor=self.chunk_total, prefix='\t\t')

        if self.cuda_shard is not None:
            output.append("\tcuda shard:\n")
            print_tensor(tensor=self.cuda_shard, prefix='\t\t')

        if self.cpu_shard is not None:
            output.append("\tcpu shard:\n")
            print_tensor(tensor=self.cpu_shard, prefix='\t\t')

        memory_info = self.memory_usage
        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))

        if detailed:
            output.append("\ttensor state monitor:\n")
            for st in TensorState:
                output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))

        return ''.join(output)
|