|
|
|
@ -5,6 +5,7 @@ from time import time
|
|
|
|
|
from typing import Dict, List, Optional, Tuple, Type |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
|
from colossalai.accelerator import get_accelerator |
|
|
|
|
from colossalai.legacy.utils.memory import colo_device_memory_capacity |
|
|
|
@ -19,13 +20,11 @@ class PlacementPolicy(ABC):
|
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
gemini_manager: "GeminiManager", # TODO @botbw: solve circular import |
|
|
|
|
chunk_manager: ChunkManager, |
|
|
|
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None, |
|
|
|
|
max_prefetch: int = 0, |
|
|
|
|
**kwargs, |
|
|
|
|
) -> None: |
|
|
|
|
self.gemini_manager = gemini_manager |
|
|
|
|
self.chunk_manager = chunk_manager |
|
|
|
|
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector |
|
|
|
|
self.max_prefetch = max_prefetch |
|
|
|
@ -40,14 +39,15 @@ class PlacementPolicy(ABC):
|
|
|
|
|
) -> None: |
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
def get_prefetch_chunks(self) -> List[Chunk]: |
|
|
|
|
def get_prefetch_chunks( |
|
|
|
|
self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] |
|
|
|
|
) -> List[Chunk]: |
|
|
|
|
return [] # no prefetch by default |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticPlacementPolicy(PlacementPolicy): |
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
gemini_manager: "GeminiManager", |
|
|
|
|
chunk_manager: ChunkManager, |
|
|
|
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None, |
|
|
|
|
max_prefetch: int = 0, |
|
|
|
@ -56,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
|
|
|
|
|
offload_param_frac: float = 0.0, |
|
|
|
|
**kwargs, |
|
|
|
|
) -> None: |
|
|
|
|
super().__init__( |
|
|
|
|
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch |
|
|
|
|
) |
|
|
|
|
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) |
|
|
|
|
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): |
|
|
|
|
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") |
|
|
|
|
offload_param_frac = 0.0 |
|
|
|
@ -109,21 +107,22 @@ class StaticPlacementPolicy(PlacementPolicy):
|
|
|
|
|
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) |
|
|
|
|
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) |
|
|
|
|
|
|
|
|
|
def get_prefetch_chunks(self) -> List[Chunk]: |
|
|
|
|
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list |
|
|
|
|
def get_prefetch_chunks( |
|
|
|
|
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] |
|
|
|
|
) -> List[Chunk]: |
|
|
|
|
if is_warmup: # no prefetch during warmup since we need compute_list |
|
|
|
|
return [] |
|
|
|
|
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) |
|
|
|
|
can_prefetch = self.max_prefetch - len(async_works) |
|
|
|
|
prefetch = [] |
|
|
|
|
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): |
|
|
|
|
break_flag = False |
|
|
|
|
for chunk in self.gemini_manager.compute_list[i]: |
|
|
|
|
for i in range(compute_idx + 1, len(compute_list)): |
|
|
|
|
for chunk in compute_list[i]: |
|
|
|
|
if len(prefetch) >= can_prefetch: |
|
|
|
|
break_flag = True |
|
|
|
|
break |
|
|
|
|
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: |
|
|
|
|
prefetch.append(chunk) |
|
|
|
|
if break_flag: |
|
|
|
|
break |
|
|
|
|
else: |
|
|
|
|
continue |
|
|
|
|
break |
|
|
|
|
return prefetch |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
gemini_manager: "GeminiManager", |
|
|
|
|
chunk_manager: ChunkManager, |
|
|
|
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None, |
|
|
|
|
max_prefetch: int = 0, |
|
|
|
@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|
|
|
|
steady_cuda_cap_ratio: float = 0.9, |
|
|
|
|
**kwargs, |
|
|
|
|
) -> None: |
|
|
|
|
super().__init__( |
|
|
|
|
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch |
|
|
|
|
) |
|
|
|
|
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) |
|
|
|
|
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase |
|
|
|
|
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() |
|
|
|
|
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio() |
|
|
|
@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|
|
|
|
else: |
|
|
|
|
grads_device_map[p] = torch.device("cpu") |
|
|
|
|
|
|
|
|
|
def get_prefetch_chunks(self) -> List[Chunk]: |
|
|
|
|
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list |
|
|
|
|
def get_prefetch_chunks( |
|
|
|
|
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] |
|
|
|
|
) -> List[Chunk]: |
|
|
|
|
if is_warmup: # no prefetch during warmup since we need compute_list |
|
|
|
|
return [] |
|
|
|
|
# modified from self.evict_tensors |
|
|
|
|
cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( |
|
|
|
@ -246,19 +244,18 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|
|
|
|
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data |
|
|
|
|
|
|
|
|
|
prefetch_chunk_memory = 0 |
|
|
|
|
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) |
|
|
|
|
can_prefetch = self.max_prefetch - len(async_works) |
|
|
|
|
prefetch = [] |
|
|
|
|
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): |
|
|
|
|
break_flag = False |
|
|
|
|
for chunk in self.gemini_manager.compute_list[i]: |
|
|
|
|
chunk: Chunk |
|
|
|
|
for i in range(compute_idx + 1, len(compute_list)): |
|
|
|
|
for chunk in compute_list[i]: |
|
|
|
|
if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: |
|
|
|
|
break_flag = True |
|
|
|
|
break |
|
|
|
|
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: |
|
|
|
|
prefetch_chunk_memory += chunk.chunk_mem |
|
|
|
|
prefetch.append(chunk) |
|
|
|
|
if break_flag: |
|
|
|
|
break |
|
|
|
|
else: |
|
|
|
|
continue |
|
|
|
|
break |
|
|
|
|
return prefetch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|