diff --git a/colossalai/gemini/update/chunkv2.py b/colossalai/gemini/update/chunkv2.py
index b1bbfce5e..25f7858ea 100644
--- a/colossalai/gemini/update/chunkv2.py
+++ b/colossalai/gemini/update/chunkv2.py
@@ -9,6 +9,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
 
 
 class ChunkV2:
+
     def __init__(self,
                  chunk_size: int,
                  process_group: ColoProcessGroup,
@@ -49,9 +50,9 @@ class ChunkV2:
 
         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
-        self.chunk_total = None  # we force chunk_total located in CUDA
-        self.cuda_shard = None  # using two attributes for the better interpretation
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
+        self.chunk_total = None    # we force chunk_total located in CUDA
+        self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
         self.is_gathered = True
 
@@ -71,7 +72,7 @@ class ChunkV2:
         # so their computation patterns are the same as that of the parameters in DDP
         self.keep_gathered = keep_gathered
         if self.keep_gathered:
-            pin_memory = False  # since this chunk is gathered, it doesn't need to pin
+            pin_memory = False    # since this chunk is gathered, it doesn't need to pin
 
         # if pin_memory is True, we allocate a piece of CPU pin-memory
         # for it all the time
@@ -137,9 +138,9 @@ class ChunkV2:
         if new_utilized_size > self.chunk_size:
             raise ChunkFullError
 
-        self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.data.flatten())
+        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
         assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
-        tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
+        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
 
         # record all the information about the tensor
         self.num_tensors += 1
@@ -177,11 +178,9 @@ class ChunkV2:
             shard_dev = torch.device('cpu')
 
         if self.pin_memory or shard_dev.type == 'cpu':
-            self.cpu_shard = torch.empty(self.shard_size,
-                                         dtype=self.dtype,
-                                         pin_memory=self.pin_memory)
+            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
             self.cpu_shard.copy_(self.cuda_shard)
-            self.cpu_vis_flag = True  # cpu_shard has been visited
+            self.cpu_vis_flag = True    # cpu_shard has been visited
 
         if shard_dev.type == 'cpu':
             self.cuda_shard = None
@@ -260,8 +259,7 @@ class ChunkV2:
             # we use all-reduce here
             dist.all_reduce(self.chunk_total, group=self.torch_pg)
         else:
-            self.cuda_shard = torch.empty(
-                self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
 
             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -330,10 +328,10 @@ class ChunkV2:
         Check if the chunk has inf or nan values in CUDA.
         """
         if self.is_gathered:
-            valid_tensor = self.chunk_total[: self.utilized_size]
+            valid_tensor = self.chunk_total[:self.utilized_size]
         else:
-            assert self.cuda_shard is not None  # only check in CUDA
-            valid_tensor = self.cuda_shard[: self.valid_end]
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
 
         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
 
@@ -346,8 +344,7 @@ class ChunkV2:
                 self.chunk_total = self.cuda_shard
             else:
                 alloc_storage(self.chunk_total)
-                gather_list = list(torch.chunk(
-                    input=self.chunk_total, chunks=self.pg_size, dim=0))
+                gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
                 dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
 
             self.cuda_shard = None
@@ -361,11 +358,9 @@ class ChunkV2:
 
         # sanity check
        assert self.cuda_shard is None
-        self.cuda_shard = torch.empty(self.shard_size,
-                                      dtype=self.dtype,
-                                      device=self.chunk_total.device)
+        self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
 
-        self.cuda_shard.copy_(self.chunk_total[self.shard_begin: self.shard_end])
+        self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
 
         free_storage(self.chunk_total)
         self.is_gathered = False
@@ -412,15 +407,15 @@ class ChunkV2:
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
-            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
-                self.chunk_size, self.dtype, self.pg_size),
+            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
+                                                                                 self.pg_size),
             "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                 self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
         ]
 
         def print_tensor(tensor, prefix=''):
-            output.append("{}shape: {}, dtype: {}, device: {}\n".format(
-                prefix, tensor.shape, tensor.dtype, tensor.device))
+            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
+                                                                        tensor.device))
 
         if self.chunk_temp is not None:
             output.append("\tchunk temp:\n")