diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 83e475575..d6b539f55 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -57,6 +57,7 @@ class GeminiManager: self._comp_cuda_demand_time = 0 def reset_attributes(self): + assert self._compute_idx + 1 == len(self._compute_list) self._compute_idx = -1 self._h2d_volume = 0 self._d2h_volume = 0 diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 9b1d1a6ab..c26db00e0 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -145,6 +145,8 @@ class AutoPlacementPolicy(PlacementPolicy): self._warmup_non_model_data_ratio = warmup_non_model_data_ratio self._steady_cuda_cap_ratio = steady_cuda_cap_ratio + self.__avail_cuda_model_data_for_prefetch = None + def evict_tensors( self, can_evict_chunks: List[Chunk], @@ -204,6 +206,7 @@ class AutoPlacementPolicy(PlacementPolicy): f"Adjust layout failed! No enough CUDA memory! 
" f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" ) + self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data - freed_cuda_model_data return freed_cuda_model_data, time() - start @staticmethod @@ -234,14 +237,9 @@ class AutoPlacementPolicy(PlacementPolicy): ) -> List[Chunk]: if is_warmup: # no prefetch during warmup since we need compute_list return [] - # modified from self.evict_tensors - cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( - get_accelerator().get_current_device() - ) - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") - used_cuda_model_data = self.chunk_manager.total_mem["cuda"] - total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period - avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + + avail_cuda_model_data = self.__avail_cuda_model_data_for_prefetch + self.__avail_cuda_model_data_for_prefetch = None # in case of double use prefetch_chunk_memory = 0 can_prefetch = self.max_prefetch - len(async_works)