mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5733 from Hz188/feature/prefetch
[Gemini] implement auto placement policy prefetch and make minor modifications to the original code (pull/5738/head)
commit f5b7de38a4
@@ -50,7 +50,12 @@ class GeminiZeROHook(ColoParamOpHook):
             self._chunk_manager.access_chunk(chunk)
 
         # get possible chunks to prefetch
-        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
+            is_warmup=self._gemini_manager.is_warmup(),
+            compute_list=self._gemini_manager.compute_list,
+            compute_idx=self._gemini_manager.compute_idx,
+            async_works=self._gemini_manager.async_works,
+        )
 
         # prefetch
         for chunk in chunks_fetch_async:
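As a quick illustration of the new calling convention (the hook now hands the manager's scheduling state to the policy as explicit keyword arguments instead of letting the policy reach back into the manager), here is a self-contained toy sketch; ToyPolicy and ToyManager are made-up stand-ins, not ColossalAI classes:

    from typing import Dict, List, Tuple

    class ToyPolicy:
        # same keyword interface as the new PlacementPolicy.get_prefetch_chunks
        def get_prefetch_chunks(
            self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[str, object]
        ) -> List[str]:
            if is_warmup:  # mirrors the real policies: nothing to prefetch while compute_list is being recorded
                return []
            upcoming = [c for step in compute_list[compute_idx + 1:] for c in step]
            return [c for c in upcoming if c not in async_works][:2]

    class ToyManager:
        def __init__(self) -> None:
            self.placement_policy = ToyPolicy()
            self.compute_list: List[Tuple[str, ...]] = [("c0",), ("c1", "c2"), ("c3",)]
            self.compute_idx = 0
            self.async_works: Dict[str, object] = {}

        def is_warmup(self) -> bool:
            return False

    mgr = ToyManager()
    print(mgr.placement_policy.get_prefetch_chunks(
        is_warmup=mgr.is_warmup(),
        compute_list=mgr.compute_list,
        compute_idx=mgr.compute_idx,
        async_works=mgr.async_works,
    ))  # -> ['c1', 'c2']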
@@ -45,7 +45,7 @@ class GeminiManager:
         self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
-        self._async_works: Dict[Chunk, dist.work] = {}
+        self._async_works: Dict[Chunk, dist.Work] = {}
 
         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -183,6 +183,10 @@ class GeminiManager:
     def compute_idx(self) -> int:
         return self._compute_idx
 
+    @property
+    def async_works(self) -> Dict[Chunk, dist.Work]:
+        return self._async_works
+
     @property
     def placement_policy(self) -> PlacementPolicy:
         return self._placement_policy
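The new property exposes the in-flight asynchronous prefetch works so that callers can budget new prefetches as max_prefetch minus the number of outstanding works, without touching the private _async_works dict. A small self-contained sketch of that bookkeeping, using a FakeWork that mimics only the is_completed()/wait() part of torch.distributed.Work:

    class FakeWork:
        # stands in for torch.distributed.Work in this toy example
        def __init__(self, done: bool) -> None:
            self._done = done

        def is_completed(self) -> bool:
            return self._done

        def wait(self) -> None:
            self._done = True

    async_works = {"chunk_a": FakeWork(done=True), "chunk_b": FakeWork(done=False)}
    max_prefetch = 4

    # drop works that have already finished so they no longer count against the budget
    for name in [n for n, w in async_works.items() if w.is_completed()]:
        async_works.pop(name)

    print(max_prefetch - len(async_works))  # remaining prefetch budget -> 3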
@@ -5,6 +5,7 @@ from time import time
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -19,13 +20,11 @@ class PlacementPolicy(ABC):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
         **kwargs,
     ) -> None:
-        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
         self.max_prefetch = max_prefetch
@@ -40,14 +39,15 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
+    def get_prefetch_chunks(
+        self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
         return []  # no prefetch by default
 
 
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -56,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -109,21 +107,22 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            break_flag = False
-            for chunk in self.gemini_manager.compute_list[i]:
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch:
-                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if break_flag:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
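The break_flag bookkeeping is replaced with Python's for/else idiom: the else suite runs only when the inner loop finishes without hitting break, so "else: continue" followed by "break" exits the outer loop exactly when the inner loop broke. A standalone demonstration of the equivalence (function and data are made up for illustration):

    def first_two_evens(rows):
        picked = []
        for row in rows:
            for x in row:
                if len(picked) >= 2:
                    break
                if x % 2 == 0 and x not in picked:
                    picked.append(x)
            else:
                continue  # inner loop ran to completion: keep scanning rows
            break  # inner loop hit break: stop the outer loop too
        return picked

    print(first_two_evens([[1, 2, 3], [4, 5, 6], [8, 9]]))  # -> [2, 4]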
@@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
         # modified from self.evict_tensors
         cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
@@ -246,19 +244,18 @@ class AutoPlacementPolicy(PlacementPolicy):
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
 
         prefetch_chunk_memory = 0
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            break_flag = False
-            for chunk in self.gemini_manager.compute_list[i]:
-                chunk: Chunk
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
-                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch_chunk_memory += chunk.chunk_mem
                     prefetch.append(chunk)
-            if break_flag:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
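Compared with the static policy, the auto policy also enforces a byte budget (avail_cuda_model_data) and, with this diff, actually accumulates prefetch_chunk_memory as chunks are picked. A standalone sketch of that greedy selection, with made-up chunk names and sizes:

    from typing import Dict, List, Tuple

    def pick_prefetch(compute_list: List[Tuple[str, ...]], compute_idx: int,
                      chunk_mem: Dict[str, int], accessed: set,
                      can_prefetch: int, avail_bytes: int) -> List[str]:
        prefetch: List[str] = []
        used = 0
        for i in range(compute_idx + 1, len(compute_list)):
            for chunk in compute_list[i]:
                # stop when either the count budget or the byte budget would be exceeded
                if len(prefetch) >= can_prefetch or used + chunk_mem[chunk] > avail_bytes:
                    break
                if chunk not in prefetch and chunk not in accessed:
                    used += chunk_mem[chunk]
                    prefetch.append(chunk)
            else:
                continue
            break
        return prefetch

    mem = {"c1": 100, "c2": 300, "c3": 100}
    print(pick_prefetch([("c0",), ("c1", "c2"), ("c3",)], 0, mem, {"c0"}, 2, 350))
    # -> ['c1'] : c2 would push the running total past the 350-byte budget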
@@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
 export GPUNUM=${GPUNUM:-1}
 export BATCH_SIZE=${BATCH_SIZE:-16}
 export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
-export TRAIN_STEP=${TRAIN_STEP:-10}
+export TRAIN_STEP=${TRAIN_STEP:-2}
 # export PYTHONPATH=$PWD:$PYTHONPATH
 
 
@@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module):
 
 
 def get_cpu_mem():
-    return psutil.Process().memory_info().rss / 1024**2
+    return psutil.Process().memory_info().rss / 1024**2  # MB unit
 
 
 def get_gpu_mem():
-    return torch.cuda.memory_allocated() / 1024**2
+    return torch.cuda.memory_allocated() / 1024**2  # MB unit
 
 
 def get_mem_info(prefix=""):
@@ -78,6 +78,7 @@ def get_mem_info(prefix=""):
 
 
 def get_model_size(model: nn.Module):
+    # get the number of parameter of the model
     total_numel = 0
     for module in model.modules():
         for p in module.parameters(recurse=False):
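The added comment documents what get_model_size computes: the total number of parameters. A quick self-contained check of the same counting pattern on a tiny model (requires torch):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    print(total_numel)  # (4*8 + 8) + (8*2 + 2) = 58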