# ColossalAI/colossalai/gemini/update/chunkv2.py
# (448 lines, 18 KiB, Python; last changed 2022-08-09 08:39:48 +00:00)
import torch
import torch.distributed as dist
from typing import Optional, Dict, List
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
free_storage, alloc_storage
class ChunkV2:
    """A container owning a piece of contiguous memory space for tensors.

    While the chunk is open, tensors are added with ``append_tensor`` and their
    ``.data`` is redirected to views of the chunk's buffer. ``close_chunk`` freezes
    the chunk; from then on each rank normally keeps only its own slice
    ``[shard_begin, shard_end)`` of the buffer (its "shard"). ``access_chunk``
    re-assembles the full buffer (all-gather when the process group has more than
    one rank), and ``release_chunk`` / ``reduce`` shard it again. A chunk created
    with ``keep_gathered=True`` is never sharded.
    """

    def __init__(self,
                 chunk_size: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 keep_gathered: bool = False,
                 pin_memory: bool = False) -> None:
        """
        Chunk: A container owning a piece of contiguous memory space for tensors.
        ChunkV2 uses the all-gather operation to gather the whole chunk.
        This kind of chunk is exclusively used for DDP and ZeRO DDP.
        It is designed to make the full use of communication and PCIE bandwidth.

        Args:
            chunk_size (int): the number of elements in a chunk; must be divisible by
                the size of the data-parallel process group
            process_group (ColoProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device where the tensor is initialized
                The default value is None, which is the current GPU
            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory;
                forced to False when keep_gathered is True
        """
        self.chunk_size = chunk_size
        self.utilized_size = 0
        # Here, we use torch process group,
        # since ColoProcessGroup might get deprecated soon
        self.torch_pg = process_group.dp_process_group()
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)

        # the chunk size should be divisible by the size of the process group
        assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size
        # number of utilized elements inside this rank's shard; finalized in close_chunk
        self.valid_end = self.shard_size

        self.dtype = dtype
        device = init_device or get_current_device()
        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
        self.chunk_total = None    # we force chunk_total located in CUDA
        self.cuda_shard = None    # using two attributes for the better interpretation
        self.cpu_shard = None
        self.is_gathered = True

        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of all tensors
        self.num_tensors = 0
        # monitor the states of all tensors: how many tensors are currently in each state
        self.tensors_state_monitor: Dict[TensorState, int] = {state: 0 for state in TensorState}

        # some chunks can keep gathered all the time
        # so their computation patterns are the same as that of the parameters in DDP
        self.keep_gathered = keep_gathered
        if self.keep_gathered:
            pin_memory = False    # since this chunk is gathered, it doesn't need to pin

        # if pin_memory is True, we allocate a piece of CPU pin-memory
        # for it all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here
        # it refers to another chunk having the same parameters
        # but with different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if the gradient of this chunk is reduced, the flag is True
        # so the flag is False for unused parameters
        self.grad_reduced_flag = False
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during the training step, the flag is True
        self.cpu_vis_flag = False

    @property
    def memory_usage(self) -> Dict[str, int]:
        """Bytes currently held by this chunk, returned as ``dict(cuda=..., cpu=...)``."""
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def device_type(self) -> str:
        """Device type string of the storage that currently holds this chunk's data.

        An open chunk reports the device of ``chunk_temp``. A closed chunk reports
        'cuda' when it is gathered or has a CUDA shard, and 'cpu' otherwise.
        """
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        else:
            if self.is_gathered:
                return 'cuda'
            elif self.cuda_shard is not None:
                return 'cuda'
            else:
                return 'cpu'

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        The tensor's values are copied into the chunk, and ``tensor.data`` is
        replaced by a view of the chunk's buffer, so the tensor afterwards shares
        memory with the chunk. The tensor starts in the HOLD state.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk; its dtype must match the chunk's

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensors_state_monitor[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self, shard_dev: Optional[torch.device] = None):
        """Close the chunk. Any tensor can't be appended to a closed chunk.

        The buffer is moved to the current CUDA device (if it was on CPU) and,
        unless the chunk is keep_gathered, split so this rank keeps only its shard.

        Args:
            shard_dev (torch.device): optional, where the shard should live after closing.
                It defaults to CPU for ordinary chunks and to the current CUDA device for
                keep_gathered chunks (which only accept a CUDA device here).
        """
        # sanity check
        assert self.chunk_temp is not None

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        if self.chunk_temp.device.type == 'cpu':
            self.chunk_total = self.chunk_temp.to(get_current_device())
        else:
            self.chunk_total = self.chunk_temp
        self.chunk_temp = None

        self.__scatter()

        if self.keep_gathered:
            if shard_dev is None:
                shard_dev = get_current_device()
            else:
                assert shard_dev.type == 'cuda'
        elif shard_dev is None:
            shard_dev = torch.device('cpu')

        if self.pin_memory or shard_dev.type == 'cpu':
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)
            self.cpu_vis_flag = True    # cpu_shard has been visited

        if shard_dev.type == 'cpu':
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        """Move this rank's shard to ``device`` (the current CUDA device or CPU).

        If the chunk is not synchronized with the optimizer (``optim_sync_flag`` is
        False), the CUDA shard is instead rebuilt from the paired chunk's CPU shard.

        Args:
            device (torch.device): target device; must be the current CUDA device or CPU
            force_copy (bool): optional, when moving to pinned CPU memory, copy the data
                even if ``cpu_shard`` has already been visited

        Raises:
            NotImplementedError: for device types other than 'cuda' and 'cpu'
        """
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer
        # just use another way for the movement
        if not self.optim_sync_flag:
            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
            self.__paired_shard_move()
            self.optim_sync_flag = True
            return

        if device.type == 'cuda':
            assert device == get_current_device(), "can't move chunk to another device"

            # already on CUDA: compare against None explicitly, because the truth value
            # of a multi-element tensor is ambiguous and raises a RuntimeError
            if self.cuda_shard is not None:
                return

            self.cuda_shard = self.cpu_shard.to(get_current_device())

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == 'cpu':
            if self.cuda_shard is None:
                return

            if self.pin_memory:
                # if cpu_shard has been visited and no copy is forced,
                # the copy operation is not needed
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it.

        Gathers the full chunk if it is sharded, then points every tensor's
        ``.data`` at its slice of ``chunk_total``.
        It is an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()
        self.__update_tensors_ptr()

    def release_chunk(self):
        """Release the usable chunk, keeping only this rank's shard.

        It is a no-op for keep_gathered chunks.
        It is an operation done in CUDA.
        """
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()

    def reduce(self):
        """Reduce scatter all the gradients.

        - With a single rank, the chunk is just sharded locally (no communication).
        - keep_gathered chunks are all-reduced in place.
        - Otherwise each rank receives its reduced shard in ``cuda_shard``, and the
          storage of ``chunk_total`` is freed.

        All tensors are reset to HOLD and ``grad_reduced_flag`` is set.
        It is an operation done in CUDA.
        """
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here
            # just move chunk_total to cuda_shard
            # the communication is not necessary
            self.__scatter()
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.chunk_total, group=self.torch_pg)
        else:
            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

            free_storage(self.chunk_total)
            self.is_gathered = False
        self.__update_tensors_state(TensorState.HOLD)
        self.grad_reduced_flag = True

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only applies valid state transformations
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # invalid transition: silently ignored (see above)
            return
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_move(self) -> bool:
        """A chunk's shard can only be moved while the chunk is not gathered."""
        return not self.is_gathered

    @property
    def can_release(self) -> bool:
        """True when every tensor is in HOLD or HOLD_AFTER_BWD; always False for keep_gathered chunks."""
        if self.keep_gathered:
            return False
        else:
            return self.tensors_state_monitor[TensorState.HOLD] + \
                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors

    @property
    def can_reduce(self) -> bool:
        """True when every tensor is in READY_FOR_REDUCE."""
        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the chunk has inf or nan values in CUDA.

        Only the utilized part is inspected: ``chunk_total[:utilized_size]`` when
        gathered, otherwise ``cuda_shard[:valid_end]`` (the shard must be on CUDA).
        """
        if self.is_gathered:
            valid_tensor = self.chunk_total[:self.utilized_size]
        else:
            assert self.cuda_shard is not None    # only check in CUDA
            valid_tensor = self.cuda_shard[:self.valid_end]

        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

    def __gather(self):
        # Rebuild chunk_total from the per-rank shards and drop the local shard.
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            if self.pg_size == 1:
                self.chunk_total = self.cuda_shard
            else:
                alloc_storage(self.chunk_total)
                gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
                dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        # Keep only this rank's slice of chunk_total (as cuda_shard) and free the rest.
        # keep_gathered chunks are never scattered.
        if self.keep_gathered:
            return

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])

            free_storage(self.chunk_total)
            self.is_gathered = False

    def __paired_shard_move(self):
        # Rebuild this chunk's CUDA shard from the paired chunk's CPU shard,
        # converting it to this chunk's dtype after the transfer.
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only be called when optimizer state is in CPU memory
        # the grad and param should be in the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_current_device())
        # avoid the dtype conversion (e.g. from FP32) on the CPU
        self.cuda_shard = temp.to(self.dtype)

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        # Point every tensor's .data at its slice of the gathered chunk.
        # sanity check
        assert self.is_gathered
        assert type(self.chunk_total) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        # Move one tensor to next_state, keeping tensors_state_monitor counts in sync.
        self.tensors_state_monitor[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensors_state_monitor[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        # Move every tensor (or only those currently in prev_state, if given) to next_state.
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        # identity-based hashing, consistent with the identity-based __eq__ below
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def __repr__(self, detailed: bool = False):
        # `repr(chunk)` never passes `detailed`; call `chunk.__repr__(detailed=True)`
        # to include the per-state tensor counts.
        output = [
            "AgChunk Information:\n",
            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
                                                                              self.pg_size),
            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
        ]

        def print_tensor(tensor, prefix=''):
            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
                                                                      tensor.device))

        if self.chunk_temp is not None:
            output.append("\tchunk temp:\n")
            print_tensor(tensor=self.chunk_temp, prefix='\t\t')

        if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
            output.append("\tchunk total:\n")
            print_tensor(tensor=self.chunk_total, prefix='\t\t')

        if self.cuda_shard is not None:
            output.append("\tcuda shard:\n")
            print_tensor(tensor=self.cuda_shard, prefix='\t\t')

        if self.cpu_shard is not None:
            output.append("\tcpu shard:\n")
            print_tensor(tensor=self.cpu_shard, prefix='\t\t')

        memory_info = self.memory_usage
        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))

        if detailed:
            output.append("\ttensor state monitor:\n")
            for st in TensorState:
                output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))

        return ''.join(output)

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the tensors stored in this chunk, in the order they were appended."""
        return list(self.tensors_info.keys())