From f45f8a2aa74e6734f0f7ad89729d7e00b5d3d985 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 16:12:53 +0800 Subject: [PATCH] [gemini] maxprefetch means maximum work to keep --- colossalai/zero/gemini/placement_policy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index e5f61a033..c0f92fa50 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -118,14 +118,15 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] + can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: - if len(prefetch) >= self.max_prefetch: + if len(prefetch) >= can_prefetch: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= self.max_prefetch: + if len(prefetch) >= can_prefetch: break return prefetch