mirror of https://github.com/hpcaitech/ColossalAI
implement auto policy prefetch and slightly modify the original code.
parent c5ddf17c76
commit d22bf30ca6
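All four hunks below appear to touch `colossalai/zero/gemini/placement_policy.py`, judging from the `StaticPlacementPolicy`/`AutoPlacementPolicy` classes and the relative `.chunk`, `.gemini_mgr`, and `.memory_tracer` imports.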
```diff
@@ -11,6 +11,7 @@ from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
+from .gemini_mgr import GeminiManager
 from .memory_tracer import ChunkMemStatsCollector
 
 
```
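This import makes `GeminiManager` available at module scope; the `__init__` hunk further down relies on it to drop the quoted forward-reference annotation.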
```diff
@@ -123,8 +124,9 @@ class StaticPlacementPolicy(PlacementPolicy):
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= can_prefetch:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
```
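The replacement relies on Python's `for ... else`: the `else` suite runs only when the inner loop finishes without hitting `break`, so `continue` moves on to the next compute step, while an inner `break` (prefetch limit reached) now falls through to the outer `break` and ends the scan entirely. A minimal, self-contained sketch of the pattern (the names are illustrative, not from the ColossalAI source):

```python
# Minimal illustration of Python's for/else, the construct the new code uses.
def find_first_over(batches, limit):
    """Scan batches in order and return the first value above `limit`."""
    for batch in batches:
        for value in batch:
            if value > limit:
                break      # found one: leave the inner loop early
        else:
            continue       # inner loop finished without break: try the next batch
        break              # the inner break happened: stop the outer scan too
    else:
        return None        # every batch was exhausted without a hit
    return value

assert find_first_over([[1, 2], [3, 9, 4]], limit=5) == 9
assert find_first_over([[1, 2]], limit=5) is None
```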
```diff
@@ -133,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
+        gemini_manager: GeminiManager,
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
```
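The quoted `"GeminiManager"` annotation was presumably a guard against an import cycle, so replacing it with the bare name is safe only if `gemini_mgr` does not import this module back (it plausibly does, e.g. to build policies through `PlacementPolicyFactory`); if it does, the new module-level import could fail at load time.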
```diff
@@ -234,10 +236,32 @@ class AutoPlacementPolicy(PlacementPolicy):
         else:
             grads_device_map[p] = torch.device("cpu")
 
-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
-        # TODO @haze188 @botbw: implement prefetching for auto
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        # modified from self.evict_tensors
+        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
+            get_accelerator().get_current_device()
+        )
+        max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda")
+        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
 
-        return []
+        prefetch_chunk_memory = 0
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                chunk: Chunk
+                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch.append(chunk)
+            else:
+                continue
+            break
+        return prefetch
 
 
 class PlacementPolicyFactory:
```
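The budget arithmetic in the new method follows the same accounting `evict_tensors` uses (per the inline comment): scale device capacity by `_steady_cuda_cap_ratio`, reserve the collector's projected non-model data for the next period, and subtract the chunk memory already resident. A worked run of that arithmetic with made-up numbers (every value below is illustrative, not from the source):

```python
# Illustrative numbers only; none of these values come from the commit.
steady_cuda_cap_ratio = 0.8                 # fraction of device memory Gemini plans against (assumed)
device_capacity = 80e9                      # e.g. an 80 GB accelerator (assumed)

cuda_capacity = steady_cuda_cap_ratio * device_capacity              # 64 GB planning capacity
max_non_model_data = 6e9                    # projected activations/temp buffers for the next period
used_cuda_model_data = 50e9                 # chunk memory already on the device

total_cuda_model_data = cuda_capacity - max_non_model_data           # 58 GB may hold model data
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data  # 8 GB left for prefetch

print(f"prefetch budget: {avail_cuda_model_data / 1e9:.0f} GB")      # -> prefetch budget: 8 GB
```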
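One thing stands out in the hunk as committed: `prefetch_chunk_memory` is initialized and used in the guard but never incremented, so the memory check compares each candidate chunk against the full budget rather than the remaining one. Presumably the intent was to accumulate the sizes of chunks already queued. A self-contained sketch of that reading, with hypothetical stand-ins (`FakeChunk` and `pick_prefetch` are not ColossalAI APIs):

```python
from dataclasses import dataclass
from typing import List, Set

@dataclass(eq=False)  # identity-based equality/hashing, like real chunk objects
class FakeChunk:
    chunk_mem: int  # payload size in bytes

def pick_prefetch(compute_list: List[List[FakeChunk]], start: int,
                  can_prefetch: int, avail_budget: int,
                  accessed: Set[FakeChunk]) -> List[FakeChunk]:
    """Same traversal as the committed loop, plus the missing accumulation."""
    prefetch: List[FakeChunk] = []
    prefetch_chunk_memory = 0
    for step in compute_list[start:]:
        for chunk in step:
            # stop once either the slot count or the memory budget is exhausted
            if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_budget:
                break
            if chunk not in prefetch and chunk not in accessed:
                prefetch_chunk_memory += chunk.chunk_mem  # the line the commit omits
                prefetch.append(chunk)
        else:
            continue
        break
    return prefetch

a, b, c = FakeChunk(3), FakeChunk(3), FakeChunk(3)
# With a 5-byte budget only the first chunk fits; without the accumulation,
# each 3-byte chunk would individually pass the check and all three would be picked.
assert pick_prefetch([[a], [b], [c]], start=0, can_prefetch=8, avail_budget=5, accessed=set()) == [a]
```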