diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index a48f8d0d0..cae5cc202 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -237,12 +237,14 @@ class AutoPlacementPolicy(PlacementPolicy): if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] # modified from self.evict_tensors - cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(get_accelerator().get_current_device()) + cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( + get_accelerator().get_current_device() + ) max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") used_cuda_model_data = self.chunk_manager.total_mem["cuda"] total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data - + prefetch_chunk_memory = 0 can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] @@ -259,6 +261,7 @@ class AutoPlacementPolicy(PlacementPolicy): break return prefetch + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { "auto": AutoPlacementPolicy,