diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 9803d7f6d..9e9fb1f58 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -20,7 +20,7 @@ class PlacementPolicy(ABC): def __init__( self, - gemini_manager: "GeminiManager", + gemini_manager: "GeminiManager", # TODO @botbw: solve circular import chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -41,9 +41,8 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError - @abstractmethod def get_prefetch_chunks(self) -> List[Chunk]: - raise NotImplementedError + return [] # no prefetch by default class StaticPlacementPolicy(PlacementPolicy):