diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 5b309c7a1..0a8e0ae4a 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -42,7 +42,9 @@ class GeminiManager: self._mem_stats_collector = ( ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None ) - self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) + self._placement_policy = policy_cls( + chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs + ) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 self._async_works: Dict[Chunk, dist.Work] = {}