diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index c0d03ba3b..9803d7f6d 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -11,6 +11,7 @@ from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
+from .gemini_mgr import GeminiManager
 from .memory_tracer import ChunkMemStatsCollector
 
 
@@ -123,8 +124,9 @@ class StaticPlacementPolicy(PlacementPolicy):
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= can_prefetch:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
@@ -133,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
+        gemini_manager: GeminiManager,
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -234,10 +236,33 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        # TODO @haze188 @botbw: implement prefetching for auto
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        # modified from self.evict_tensors
+        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
+            get_accelerator().get_current_device()
+        )
+        max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
+        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
 
-        return []
+        prefetch_chunk_memory = 0
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                chunk: Chunk
+                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    # track the running prefetch footprint so the budget check above is cumulative
+                    prefetch_chunk_memory += chunk.chunk_mem
+                    prefetch.append(chunk)
+            else:
+                continue
+            break
+        return prefetch
 
 
 class PlacementPolicyFactory:
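Reviewer note on the control flow above: both rewritten loops use Python's for/else idiom. The else suite runs only when the inner loop finishes without hitting break, so continue there advances the outer loop, while a break in the inner loop falls through to the trailing outer break and stops scanning entirely. A minimal, self-contained sketch of the same pattern follows (toy names such as pick_until_budget are invented for illustration and are not part of ColossalAI):

from typing import List

def pick_until_budget(schedule: List[List[int]], budget: int) -> List[int]:
    """Collect unique items from a nested schedule until the budget is hit."""
    picked: List[int] = []
    for step in schedule:
        for item in step:
            if len(picked) >= budget:
                break  # stop condition met: skips the else clause below
            if item not in picked:
                picked.append(item)
        else:
            continue  # inner loop exhausted without break: try the next step
        break  # reached only via the inner break: stop the outer scan too
    return picked

assert pick_until_budget([[1, 2], [3, 4, 5]], budget=3) == [1, 2, 3]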