mirror of https://github.com/hpcaitech/ColossalAI
[zero] add AgChunk (#1417)
parent
d209aff684
commit
c577ed016e
|
@ -0,0 +1,366 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional, Dict
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
|
||||
free_storage, alloc_storage
|
||||
|
||||
|
||||
class AgChunk:
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
process_group: ColoProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
keep_gathered: bool = False,
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
Chunk: A container owning a piece of contiguous memory space for tensors
|
||||
AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk.
|
||||
This kind of chunk is exclusively used for DDP and ZeRO DDP.
|
||||
It is designed to make the full use of communication and PCIE bandwidth.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in a chunk
|
||||
process_group (ColoProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized
|
||||
The default value is None, which is the current GPU
|
||||
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
|
||||
pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory
|
||||
"""
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
self.utilized_size = 0
|
||||
# Here, we use torch process group,
|
||||
# since ColoProcessGroup might get deprecated soon
|
||||
self.torch_pg = process_group.dp_process_group
|
||||
self.pg_size = dist.get_world_size(self.torch_pg)
|
||||
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||
|
||||
# the chunk size should be able to be divied by the size of GPU
|
||||
assert chunk_size % self.pg_size == 0
|
||||
self.shard_size = chunk_size // self.pg_size
|
||||
self.shard_begin = self.shard_size * self.pg_rank
|
||||
self.shard_end = self.shard_begin + self.shard_size
|
||||
|
||||
self.dtype = dtype
|
||||
device = init_device or get_current_device()
|
||||
self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||
self.chunk_total = None # we force chunk_total located in CUDA
|
||||
self.cuda_shard = None # using two attributes for the better interpretation
|
||||
self.cpu_shard = None
|
||||
self.is_gathered = True
|
||||
|
||||
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
|
||||
self.shard_mem = self.chunk_mem // self.pg_size
|
||||
|
||||
# each tensor is associated with a TensorInfo to track meta info
|
||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||
# the total number of all tensors
|
||||
self.num_tensors = 0
|
||||
# monitor the states of all tensors
|
||||
self.tensors_state_monitor: Dict[TensorState, int] = dict()
|
||||
for state in TensorState:
|
||||
self.tensors_state_monitor[state] = 0
|
||||
|
||||
# some chunks can keep gathered all the time
|
||||
# so their computation patterns are the same as that of the parameters in DDP
|
||||
self.keep_gathered = keep_gathered
|
||||
|
||||
# if pin_memory is True, we allocate a piece of CPU pin-memory
|
||||
# for it all the time
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
# we introduce the paired chunk here
|
||||
# it refers to another chunk having the same parameters
|
||||
# but with different dtype(such as fp16_chunk.mapping_chunk -> fp32_chunk
|
||||
self.paired_chunk = None
|
||||
# if the the gradient of this chunk is reduced, the flag is True
|
||||
# so the flag is False for unused parameters
|
||||
self.grad_reduced_flag = False
|
||||
# if this chunk is synchronized with the optimizer, the flag is True
|
||||
self.optim_sync_flag = True
|
||||
# if the cpu_shard has been visited during the training step, the flag is True
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
@property
|
||||
def memory_usage(self):
|
||||
cuda_memory = 0
|
||||
cpu_memory = 0
|
||||
|
||||
if self.chunk_temp is not None:
|
||||
# this chunk is not closed
|
||||
if self.chunk_temp.device.type == 'cuda':
|
||||
cuda_memory += self.chunk_mem
|
||||
else:
|
||||
cpu_memory += self.chunk_mem
|
||||
else:
|
||||
if self.is_gathered:
|
||||
cuda_memory += self.chunk_mem
|
||||
if self.cuda_shard is not None:
|
||||
cuda_memory += self.shard_mem
|
||||
if self.cpu_shard is not None:
|
||||
cpu_memory += self.shard_mem
|
||||
|
||||
return dict(cuda=cuda_memory, cpu=cpu_memory)
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
if self.chunk_temp is not None:
|
||||
return self.chunk_temp.device.type
|
||||
else:
|
||||
if self.chunk_total is not None:
|
||||
return 'cuda'
|
||||
elif self.cuda_shard is not None:
|
||||
return 'cuda'
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor):
|
||||
"""Add a tensor to the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a tensor to be added to the chunk
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is not None
|
||||
assert tensor.dtype == self.dtype
|
||||
|
||||
new_utilized_size = self.utilized_size + tensor.numel()
|
||||
# raise exception when the chunk size is exceeded
|
||||
if new_utilized_size > self.chunk_size:
|
||||
raise ChunkFullError
|
||||
|
||||
self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.flatten())
|
||||
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
||||
tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
|
||||
|
||||
# record all the information about the tensor
|
||||
self.num_tensors += 1
|
||||
tensor_state = TensorState.HOLD
|
||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||
self.tensors_state_monitor[tensor_state] += 1
|
||||
self.utilized_size = new_utilized_size
|
||||
|
||||
def close_chunk(self, shard_dev: torch.device):
|
||||
"""Close the chunk. Any tensor can't be appended to a closed chunk.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is not None
|
||||
|
||||
if self.chunk_temp.device.type == 'cpu':
|
||||
self.chunk_total = self.chunk_temp.to(get_current_device())
|
||||
else:
|
||||
self.chunk_total = self.chunk_temp
|
||||
self.chunk_temp = None
|
||||
|
||||
self.__scatter()
|
||||
|
||||
if self.pin_memory or shard_dev.type == 'cpu':
|
||||
self.cpu_shard = torch.empty(self.shard_size,
|
||||
dtype=self.dtype,
|
||||
pin_memory=self.pin_memory)
|
||||
self.cpu_shard.copy_(self.cuda_shard)
|
||||
self.cpu_vis_flag = True # cpu_shard has been visited
|
||||
|
||||
if shard_dev.type == 'cpu':
|
||||
self.cuda_shard = None
|
||||
|
||||
def shard_move(self, device: torch.device, force_copy: bool = False):
|
||||
# sanity check
|
||||
assert not self.is_gathered
|
||||
# when the current chunk is not synchronized with the optimizer
|
||||
# just use another way for the movement
|
||||
if not self.optim_sync_flag:
|
||||
assert device.type == 'cuda', "each chunk should first be moved to CUDA"
|
||||
self.__paired_shard_move()
|
||||
self.optim_sync_flag = True
|
||||
return
|
||||
|
||||
if device.type == 'cuda':
|
||||
assert device == get_current_device(), "can't move chunk to another device"
|
||||
|
||||
if self.cuda_shard:
|
||||
return
|
||||
|
||||
self.cuda_shard = self.cpu_shard.to(get_current_device())
|
||||
|
||||
if not self.pin_memory:
|
||||
self.cpu_shard = None
|
||||
elif device.type == 'cpu':
|
||||
if self.cuda_shard is None:
|
||||
return
|
||||
|
||||
if self.pin_memory:
|
||||
if force_copy or not self.cpu_vis_flag:
|
||||
self.cpu_shard.copy_(self.cuda_shard)
|
||||
# if cpu_shard has been visited
|
||||
# copy operation is not need
|
||||
else:
|
||||
self.cpu_shard = self.cuda_shard.cpu()
|
||||
self.cpu_vis_flag = True
|
||||
self.cuda_shard = None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
"""Make the chunk usable for the parameters inside it.
|
||||
It is an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if not self.is_gathered:
|
||||
self.__gather()
|
||||
self.__update_tensors_ptr()
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk.
|
||||
It is an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if self.is_gathered:
|
||||
self.__scatter()
|
||||
|
||||
def reduce(self):
|
||||
"""Reduce scatter all the gradients.
|
||||
It is an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
||||
if self.pg_size == 1:
|
||||
# tricky code here
|
||||
# just move chunk_total to cuda_shard
|
||||
# the communication is not necessary
|
||||
self.__scatter()
|
||||
elif self.keep_gathered:
|
||||
# we use all-reduce here
|
||||
dist.all_reduce(self.chunk_total, group=self.torch_pg)
|
||||
else:
|
||||
self.cuda_shard = torch.empty(
|
||||
self.shard_size, dtype=self.dtype, device=get_current_device())
|
||||
|
||||
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
|
||||
|
||||
free_storage(self.chunk_total)
|
||||
self.is_gathered = False
|
||||
self.__update_tensors_state(TensorState.HOLD)
|
||||
self.grad_reduced_flag = True
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
"""
|
||||
Make a transition of the tensor into the next state.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a torch Tensor object.
|
||||
tensor_state (TensorState): the target state for transition.
|
||||
"""
|
||||
|
||||
# As the gradient hook can be triggered either before or after post-backward
|
||||
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
||||
# or compute -> ready_for_reduce -> hold_after_bwd
|
||||
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
|
||||
# this function only apply valid state transformation
|
||||
# invalid calls will be ignored and nothing changes
|
||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||
# print(
|
||||
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||
# )
|
||||
return
|
||||
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrive meta information
|
||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
||||
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
def __gather(self):
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
if self.pg_size == 1:
|
||||
self.chunk_total = self.cuda_shard
|
||||
else:
|
||||
alloc_storage(self.chunk_total)
|
||||
gather_list = list(torch.chunk(
|
||||
input=self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
|
||||
|
||||
def __scatter(self):
|
||||
if self.keep_gathered:
|
||||
return
|
||||
|
||||
if self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is None
|
||||
|
||||
self.cuda_shard = torch.empty(self.shard_size,
|
||||
dtype=self.dtype,
|
||||
device=self.chunk_total.device)
|
||||
|
||||
self.cuda_shard.copy_(self.chunk_total[self.shard_begin: self.shard_end])
|
||||
|
||||
free_storage(self.chunk_total)
|
||||
self.is_gathered = False
|
||||
|
||||
def __paired_shard_move(self):
|
||||
assert self.paired_chunk is not None, "chunks should be paired before training"
|
||||
optim_chunk = self.paired_chunk
|
||||
assert self.chunk_size == optim_chunk.chunk_size
|
||||
|
||||
# only be called when optimizer state is in CPU memory
|
||||
# the grad and param should be in the same device
|
||||
assert self.cuda_shard is None
|
||||
temp = optim_chunk.cpu_shard.to(get_current_device())
|
||||
# avoid to transform FP32 in CPU
|
||||
self.cuda_shard = temp.to(self.dtype)
|
||||
|
||||
if not self.pin_memory:
|
||||
self.cpu_shard = None
|
||||
|
||||
def __update_tensors_ptr(self) -> None:
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
assert type(self.chunk_total) == torch.Tensor
|
||||
|
||||
for tensor, tensor_info in self.tensors_info.items():
|
||||
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
|
||||
self.tensors_state_monitor[tensor_info.state] -= 1
|
||||
tensor_info.state = next_state
|
||||
self.tensors_state_monitor[tensor_info.state] += 1
|
||||
|
||||
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if prev_state is None or tensor_info.state == prev_state:
|
||||
self.__update_one_tensor_info(tensor_info, next_state)
|
Loading…
Reference in New Issue