diff --git a/colossalai/gemini/update/chunkv2.py b/colossalai/gemini/update/chunkv2.py
index b1bbfce5e..25f7858ea 100644
--- a/colossalai/gemini/update/chunkv2.py
+++ b/colossalai/gemini/update/chunkv2.py
@@ -9,6 +9,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
 
 
 class ChunkV2:
+
     def __init__(self,
                  chunk_size: int,
                  process_group: ColoProcessGroup,
@@ -49,9 +50,9 @@ class ChunkV2:
 
         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
-        self.chunk_total = None  # we force chunk_total located in CUDA
-        self.cuda_shard = None  # using two attributes for the better interpretation
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
+        self.chunk_total = None    # we force chunk_total located in CUDA
+        self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
         self.is_gathered = True
 
@@ -71,7 +72,7 @@ class ChunkV2:
         # so their computation patterns are the same as that of the parameters in DDP
         self.keep_gathered = keep_gathered
         if self.keep_gathered:
-            pin_memory = False  # since this chunk is gathered, it doesn't need to pin
+            pin_memory = False    # since this chunk is gathered, it doesn't need to pin
 
         # if pin_memory is True, we allocate a piece of CPU pin-memory
         # for it all the time
@@ -137,9 +138,9 @@ class ChunkV2:
         if new_utilized_size > self.chunk_size:
             raise ChunkFullError
 
-        self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.data.flatten())
+        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
         assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
-        tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
+        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
 
         # record all the information about the tensor
         self.num_tensors += 1
@@ -177,11 +178,9 @@ class ChunkV2:
             shard_dev = torch.device('cpu')
 
         if self.pin_memory or shard_dev.type == 'cpu':
-            self.cpu_shard = torch.empty(self.shard_size,
-                                         dtype=self.dtype,
-                                         pin_memory=self.pin_memory)
+            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
             self.cpu_shard.copy_(self.cuda_shard)
-            self.cpu_vis_flag = True  # cpu_shard has been visited
+            self.cpu_vis_flag = True    # cpu_shard has been visited
 
         if shard_dev.type == 'cpu':
             self.cuda_shard = None
@@ -260,8 +259,7 @@ class ChunkV2:
             # we use all-reduce here
             dist.all_reduce(self.chunk_total, group=self.torch_pg)
         else:
-            self.cuda_shard = torch.empty(
-                self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
 
             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -330,10 +328,10 @@ class ChunkV2:
         Check if the chunk has inf or nan values in CUDA.
         """
         if self.is_gathered:
-            valid_tensor = self.chunk_total[: self.utilized_size]
+            valid_tensor = self.chunk_total[:self.utilized_size]
         else:
-            assert self.cuda_shard is not None  # only check in CUDA
-            valid_tensor = self.cuda_shard[: self.valid_end]
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
 
         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
 
@@ -346,8 +344,7 @@ class ChunkV2:
                 self.chunk_total = self.cuda_shard
             else:
                 alloc_storage(self.chunk_total)
-                gather_list = list(torch.chunk(
-                    input=self.chunk_total, chunks=self.pg_size, dim=0))
+                gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
                 dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
 
             self.cuda_shard = None
@@ -361,11 +358,9 @@ class ChunkV2:
             # sanity check
             assert self.cuda_shard is None
 
-            self.cuda_shard = torch.empty(self.shard_size,
-                                          dtype=self.dtype,
-                                          device=self.chunk_total.device)
+            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
 
-            self.cuda_shard.copy_(self.chunk_total[self.shard_begin: self.shard_end])
+            self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
 
             free_storage(self.chunk_total)
             self.is_gathered = False
@@ -412,15 +407,15 @@ class ChunkV2:
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
-            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
-                self.chunk_size, self.dtype, self.pg_size),
+            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
+                                                                                 self.pg_size),
             "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                 self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
         ]
 
         def print_tensor(tensor, prefix=''):
-            output.append("{}shape: {}, dtype: {}, device: {}\n".format(
-                prefix, tensor.shape, tensor.dtype, tensor.device))
+            output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype,
+                                                                        tensor.device))
 
         if self.chunk_temp is not None:
             output.append("\tchunk temp:\n")