diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py
index 5ae1dfaa1..ec6afbc07 100644
--- a/colossalai/gemini/placement_policy.py
+++ b/colossalai/gemini/placement_policy.py
@@ -58,13 +58,13 @@ class CUDAPlacementPolicy(PlacementPolicy):
 class AutoPlacementPolicy(PlacementPolicy):
 
     need_mem_stats: bool = True
+    # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
+    # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
+    _warmup_non_model_data_ratio: float = 0.8
+    _steady_cuda_cap_ratio: float = 0.9
 
     def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
-        # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
-        # TODO(ver217): make these args configurable
-        self._warmup_non_model_data_ratio: float = 0.8
-        self._steady_cuda_cap_ratio: float = 0.9
 
     def evict_tensors(self,
                       can_evict_chunks: List[Chunk],
@@ -94,11 +94,11 @@ class AutoPlacementPolicy(PlacementPolicy):
         used_cuda_model_data = self.chunk_manager.total_mem['cuda']
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
-            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
+            max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
             max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
-            cuda_capacity *= self._steady_cuda_cap_ratio
+            cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
         freed_cuda_model_data = 0
@@ -133,6 +133,42 @@ class AutoPlacementPolicy(PlacementPolicy):
         next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
         return [t for (t, idx) in next_compute_idx]
 
+    @staticmethod
+    def set_warmup_non_model_data_ratio(ratio: float) -> None:
+        """Set the share of CUDA capacity budgeted for non-model data during warmup.
+
+        The value is stored on the class, so it is shared by all AutoPlacementPolicy instances.
+
+        Args:
+            ratio (float): new ratio, strictly between 0 and 1 (converted with ``float()``).
+
+        Raises:
+            ValueError: if ``ratio`` is not strictly between 0 and 1.
+        """
+        ratio = float(ratio)
+        # raise instead of assert so the check is not stripped under ``python -O``
+        if not 0.0 < ratio < 1.0:
+            raise ValueError(f'ratio must be in the open interval (0, 1), got {ratio}')
+        AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
+
+    @staticmethod
+    def set_steady_cuda_cap_ratio(ratio: float) -> None:
+        """Set the factor applied to the CUDA capacity once warmup is over.
+
+        The value is stored on the class, so it is shared by all AutoPlacementPolicy instances.
+
+        Args:
+            ratio (float): new ratio, strictly between 0 and 1 (converted with ``float()``).
+
+        Raises:
+            ValueError: if ``ratio`` is not strictly between 0 and 1.
+        """
+        ratio = float(ratio)
+        # raise instead of assert so the check is not stripped under ``python -O``
+        if not 0.0 < ratio < 1.0:
+            raise ValueError(f'ratio must be in the open interval (0, 1), got {ratio}')
+        AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
+
 
 class PlacementPolicyFactory:
     policies: Dict[str, PlacementPolicy] = {