@@ -9,6 +9,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
 class ChunkV2:
 
     def __init__(self,
                  chunk_size: int,
                  process_group: ColoProcessGroup,
@@ -49,9 +50,9 @@ class ChunkV2:
         self.dtype = dtype
         device = init_device or get_current_device()
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
         self.chunk_total = None    # we force chunk_total located in CUDA
         self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
         self.is_gathered = True
@@ -71,7 +72,7 @@ class ChunkV2:
         # so their computation patterns are the same as that of the parameters in DDP
         self.keep_gathered = keep_gathered
         if self.keep_gathered:
             pin_memory = False    # since this chunk is gathered, it doesn't need to pin
 
         # if pin_memory is True, we allocate a piece of CPU pin-memory
         # for it all the time
@@ -137,9 +138,9 @@ class ChunkV2:
         if new_utilized_size > self.chunk_size:
             raise ChunkFullError
 
-        self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.data.flatten())
+        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
         assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
-        tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
+        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
 
         # record all the information about the tensor
         self.num_tensors += 1
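Reviewer aid (not part of the patch): the two slice edits above only drop the space after the colon. For readers new to the chunk design, the minimal standalone sketch below shows what this append logic does: copy the flattened tensor into the chunk's flat buffer, then re-point tensor.data at a view of that same storage. All names are local to the example.

import torch

# a flat chunk buffer and its fill pointer, mirroring chunk_temp / utilized_size
chunk_temp = torch.zeros(16, dtype=torch.float32)
utilized_size = 0

param = torch.randn(2, 3)    # a tensor to be appended
new_utilized_size = utilized_size + param.numel()
assert new_utilized_size <= chunk_temp.numel()    # otherwise the chunk is full

# copy the data in, then alias the parameter to the chunk's storage
chunk_temp[utilized_size:new_utilized_size].copy_(param.data.flatten())
param.data = chunk_temp[utilized_size:new_utilized_size].view(param.shape)

# the parameter now shares memory with the chunk
assert param.data_ptr() == chunk_temp[utilized_size:new_utilized_size].data_ptr()
utilized_size = new_utilized_size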
@@ -177,11 +178,9 @@ class ChunkV2:
             shard_dev = torch.device('cpu')
 
         if self.pin_memory or shard_dev.type == 'cpu':
-            self.cpu_shard = torch.empty(self.shard_size,
-                                         dtype=self.dtype,
-                                         pin_memory=self.pin_memory)
+            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
             self.cpu_shard.copy_(self.cuda_shard)
             self.cpu_vis_flag = True    # cpu_shard has been visited
 
         if shard_dev.type == 'cpu':
             self.cuda_shard = None
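Reviewer aid (not part of the patch): the collapsed torch.empty call above allocates the optional pinned CPU shard. The short sketch below, with an assumed shard size, illustrates the same pattern of keeping a page-locked CPU copy of a CUDA shard so that later host/device copies can be fast and non-blocking.

import torch

shard_size = 1024    # hypothetical shard size, for illustration only

if torch.cuda.is_available():
    cuda_shard = torch.randn(shard_size, dtype=torch.float16, device='cuda')
    # pin_memory=True gives a page-locked CPU buffer, as in close_chunk
    cpu_shard = torch.empty(shard_size, dtype=torch.float16, pin_memory=True)
    cpu_shard.copy_(cuda_shard)    # device-to-host copy into the pinned buffer
    cuda_shard = None              # the CUDA shard can now be dropped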
@@ -260,8 +259,7 @@ class ChunkV2:
             # we use all-reduce here
             dist.all_reduce(self.chunk_total, group=self.torch_pg)
         else:
-            self.cuda_shard = torch.empty(
-                self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
 
             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
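Reviewer aid (not part of the patch): this hunk sits in the multi-rank reduce path. The single-process sketch below simulates what the reduce_scatter call computes, namely that each rank ends up holding the element-wise sum of its own slice of the gathered chunk; the rank count and sizes are made up for illustration.

import torch

pg_size = 4
chunk_size = 16
# pretend these are the gathered chunk_total buffers held by the pg_size ranks
chunk_totals = [torch.full((chunk_size,), float(rank)) for rank in range(pg_size)]

# dist.reduce_scatter(cuda_shard, input_list, group=...) over torch.chunk'ed
# inputs amounts to: reduce everything element-wise, then keep one slice per rank
reduced = torch.stack(chunk_totals).sum(dim=0)
shards = list(torch.chunk(reduced, chunks=pg_size, dim=0))

rank = 2
cuda_shard = shards[rank]    # what this rank would keep after reduce_scatter
assert cuda_shard.numel() == chunk_size // pg_size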
@@ -330,10 +328,10 @@ class ChunkV2:
         Check if the chunk has inf or nan values in CUDA.
         """
         if self.is_gathered:
-            valid_tensor = self.chunk_total[: self.utilized_size]
+            valid_tensor = self.chunk_total[:self.utilized_size]
         else:
             assert self.cuda_shard is not None    # only check in CUDA
-            valid_tensor = self.cuda_shard[: self.valid_end]
+            valid_tensor = self.cuda_shard[:self.valid_end]
 
         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
@@ -346,8 +344,7 @@ class ChunkV2:
                 self.chunk_total = self.cuda_shard
             else:
                 alloc_storage(self.chunk_total)
-                gather_list = list(torch.chunk(
-                    input=self.chunk_total, chunks=self.pg_size, dim=0))
+                gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
                 dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
 
             self.cuda_shard = None
@@ -361,11 +358,9 @@ class ChunkV2:
             # sanity check
             assert self.cuda_shard is None
 
-            self.cuda_shard = torch.empty(self.shard_size,
-                                          dtype=self.dtype,
-                                          device=self.chunk_total.device)
+            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
 
-            self.cuda_shard.copy_(self.chunk_total[self.shard_begin: self.shard_end])
+            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
 
             free_storage(self.chunk_total)
             self.is_gathered = False
@@ -412,15 +407,15 @@ class ChunkV2:
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
-            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
-                self.chunk_size, self.dtype, self.pg_size),
+            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
+                                                                                 self.pg_size),
             "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                 self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
         ]
 
         def print_tensor(tensor, prefix=''):
-            output.append("{}shape: {}, dtype: {}, device: {}\n".format(
-                prefix, tensor.shape, tensor.dtype, tensor.device))
+            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
+                                                                        tensor.device))
 
         if self.chunk_temp is not None:
             output.append("\tchunk temp:\n")