mirror of https://github.com/hpcaitech/ColossalAI
commit f5a5287f87

@@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
+        gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
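The quoted "GeminiManager" annotation and the new TODO comment point at an import cycle between the placement policy and the manager that owns it. A common way to break such a cycle, shown here only as a hedged sketch (the module path is an assumption, not taken from this commit), is typing.TYPE_CHECKING:

    # Sketch: break a circular import with a type-checking-only import.
    # The module path below is assumed for illustration.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by static type checkers only, never at runtime,
        # so importing GeminiManager here cannot create a cycle.
        from colossalai.zero.gemini.gemini_mgr import GeminiManager

    class PlacementPolicy:
        def __init__(self, gemini_manager: "GeminiManager") -> None:
            self.gemini_manager = gemini_manager

The string form of the annotation keeps the name unresolved until a type checker reads it, which is why this runs even though GeminiManager is never imported at runtime.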
@@ -40,9 +40,8 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError

-    @abstractmethod
     def get_prefetch_chunks(self) -> List[Chunk]:
-        raise NotImplementedError
+        return []  # no prefetch by default


 class StaticPlacementPolicy(PlacementPolicy):
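Dropping @abstractmethod in favor of a concrete return [] means subclasses no longer have to implement get_prefetch_chunks to stay instantiable; only policies that actually prefetch override it. A minimal self-contained sketch of that pattern (Policy, NoPrefetchPolicy, and evict are illustrative names, not from this file):

    from abc import ABC, abstractmethod
    from typing import List

    class Policy(ABC):
        @abstractmethod
        def evict(self) -> None:   # still mandatory: subclasses must override
            raise NotImplementedError

        def get_prefetch_chunks(self) -> List[object]:
            return []              # concrete default: no prefetch

    class NoPrefetchPolicy(Policy):
        def evict(self) -> None:   # only the abstract method needs a body
            pass

    assert NoPrefetchPolicy().get_prefetch_chunks() == []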
@@ -116,12 +115,14 @@ class StaticPlacementPolicy(PlacementPolicy):
         can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
         prefetch = []
         for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            break_flag = False
             for chunk in self.gemini_manager.compute_list[i]:
                 if len(prefetch) >= can_prefetch:
+                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= can_prefetch:
+            if break_flag:
                 break
         return prefetch

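This hunk replaces the repeated len(prefetch) >= can_prefetch test in the outer loop with an explicit break_flag set by the inner loop, the usual flag pattern for propagating a nested break. A runnable sketch of the pattern with hypothetical names (select, steps, budget, seen):

    # Flag pattern for breaking out of nested loops; names are illustrative.
    def select(steps: list, budget: int, seen: set) -> list:
        picked = []
        for step in steps:                # outer loop: upcoming compute steps
            stop = False
            for item in step:             # inner loop: chunks within one step
                if len(picked) >= budget:
                    stop = True           # remember that the cap was hit ...
                    break
                if item not in picked and item not in seen:
                    picked.append(item)
            if stop:                      # ... so the outer loop can stop too
                break
        return picked

    print(select([[1, 2], [3, 4], [5]], budget=3, seen={2}))  # [1, 3, 4]

The flag matters even more in the AutoPlacementPolicy version below, where the inner break condition also involves a memory budget, so the outer loop could not simply re-test the count.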
@@ -232,8 +233,33 @@ class AutoPlacementPolicy(PlacementPolicy):
         else:
             grads_device_map[p] = torch.device("cpu")

-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        return []  # TODO @botbw: implement prefetching for auto
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        # modified from self.evict_tensors
+        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
+            get_accelerator().get_current_device()
+        )
+        max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
+        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+
+        prefetch_chunk_memory = 0
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            break_flag = False
+            for chunk in self.gemini_manager.compute_list[i]:
+                chunk: Chunk
+                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
+                    break_flag = True
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch_chunk_memory += chunk.chunk_mem
+                    prefetch.append(chunk)
+            if break_flag:
+                break
+        return prefetch


 class PlacementPolicyFactory:
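The new AutoPlacementPolicy.get_prefetch_chunks caps prefetching twice: by the number of in-flight async works (can_prefetch) and by a CUDA memory budget (capacity scaled by the steady-state ratio, minus projected non-model data, minus resident model data). A hedged, self-contained sketch of that dual cap with made-up numbers (FakeChunk and prefetch_under_budget are illustrative, not ColossalAI APIs):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class FakeChunk:
        name: str
        chunk_mem: int  # bytes, mirroring Chunk.chunk_mem

    def prefetch_under_budget(compute_list, start, accessed, max_prefetch, avail_mem):
        picked, picked_mem = [], 0
        for step in compute_list[start:]:
            stop = False
            for chunk in step:
                # stop on either cap: selection count or memory budget
                if len(picked) >= max_prefetch or picked_mem + chunk.chunk_mem > avail_mem:
                    stop = True
                    break
                if chunk not in picked and chunk not in accessed:
                    picked_mem += chunk.chunk_mem  # the accumulation the budget check relies on
                    picked.append(chunk)
            if stop:
                break
        return picked

    a, b, c = FakeChunk("a", 4 << 20), FakeChunk("b", 8 << 20), FakeChunk("c", 16 << 20)
    # budget = capacity * steady ratio - projected non-model data - used model data
    avail = int(32e6 * 0.8) - 10_000_000 - 2_000_000  # 13.6 MB to spend
    print([ch.name for ch in prefetch_under_budget([[a], [b], [c]], 0, set(), 4, avail)])
    # ['a', 'b']: prefetching "c" as well would overflow the budget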