
Merge pull request #5733 from Hz188/feature/prefetch

[Gemini] Implement auto-policy prefetch, with minor modifications to the original code.
botbw, 6 months ago, committed by GitHub
commit f5b7de38a4
Changed files:

1. colossalai/zero/gemini/gemini_hook.py (7 lines changed)
2. colossalai/zero/gemini/gemini_mgr.py (6 lines changed)
3. colossalai/zero/gemini/placement_policy.py (53 lines changed)
4. examples/language/gpt/gemini/run_gemini.sh (2 lines changed)
5. examples/language/gpt/gemini/train_gpt_demo.py (5 lines changed)

colossalai/zero/gemini/gemini_hook.py

@@ -50,7 +50,12 @@ class GeminiZeROHook(ColoParamOpHook):
            self._chunk_manager.access_chunk(chunk)

        # get possible chunks to prefetch
-       chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+       chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
+           is_warmup=self._gemini_manager.is_warmup(),
+           compute_list=self._gemini_manager.compute_list,
+           compute_idx=self._gemini_manager.compute_idx,
+           async_works=self._gemini_manager.async_works,
+       )

        # prefetch
        for chunk in chunks_fetch_async:
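For orientation, here is a minimal, self-contained sketch of the call pattern this hunk introduces: the hook reads the manager's runtime state and passes it to the policy by value, so the policy no longer needs a back-reference to the manager (see the removal of the gemini_manager field in placement_policy.py below). The class names are illustrative stand-ins, not the real ColossalAI types.

# Toy stand-ins only; the real GeminiManager / PlacementPolicy live in colossalai.zero.gemini.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyManager:
    compute_list: List[Tuple[str, ...]] = field(default_factory=list)  # chunks touched by each upcoming op
    compute_idx: int = -1                                               # index of the op currently executing
    async_works: Dict[str, object] = field(default_factory=dict)       # in-flight prefetch handles
    warmup: bool = True

    def is_warmup(self) -> bool:
        return self.warmup


class ToyPolicy:
    # mirrors the new signature: all state arrives as explicit arguments
    def get_prefetch_chunks(self, is_warmup, compute_list, compute_idx, async_works) -> List[str]:
        if is_warmup:
            return []
        return [c for op in compute_list[compute_idx + 1:] for c in op]


manager = ToyManager(compute_list=[("A",), ("B", "C")], compute_idx=0, warmup=False)
policy = ToyPolicy()
chunks_fetch_async = policy.get_prefetch_chunks(
    is_warmup=manager.is_warmup(),
    compute_list=manager.compute_list,
    compute_idx=manager.compute_idx,
    async_works=manager.async_works,
)
print(chunks_fetch_async)  # ['B', 'C']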

colossalai/zero/gemini/gemini_mgr.py

@@ -45,7 +45,7 @@ class GeminiManager:
        self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
        self._compute_list: List[Tuple[Chunk, ...]] = []
        self._compute_idx: int = -1
-       self._async_works: Dict[Chunk, dist.work] = {}
+       self._async_works: Dict[Chunk, dist.Work] = {}
        self._h2d_volume = 0
        self._d2h_volume = 0
@@ -183,6 +183,10 @@ class GeminiManager:
    def compute_idx(self) -> int:
        return self._compute_idx

+   @property
+   def async_works(self) -> Dict[Chunk, dist.Work]:
+       return self._async_works
+
    @property
    def placement_policy(self) -> PlacementPolicy:
        return self._placement_policy
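The corrected annotation, dist.Work, is the handle PyTorch returns from an asynchronous collective; the manager keeps one per chunk so the policies can see how many transfers are still in flight (max_prefetch - len(async_works) in the hunks below). A hedged, single-process sketch of that bookkeeping, using gloo purely for illustration (Gemini itself runs NCCL collectives on GPU chunks):

import os
import torch
import torch.distributed as dist

# single-process group, only to obtain a real dist.Work handle
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

tensor = torch.ones(4)
work = dist.all_reduce(tensor, async_op=True)   # async collective -> dist.Work

# manager-style bookkeeping (illustrative): one in-flight handle per chunk id
async_works = {"chunk_0": work}
max_prefetch = 2
can_prefetch = max_prefetch - len(async_works)  # budget left for new prefetches

work.wait()                                     # block until the collective completes
print(can_prefetch, tensor)

dist.destroy_process_group()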

colossalai/zero/gemini/placement_policy.py

@@ -5,6 +5,7 @@ from time import time
from typing import Dict, List, Optional, Tuple, Type

import torch
+import torch.distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -19,13 +20,11 @@ class PlacementPolicy(ABC):
    def __init__(
        self,
-       gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
        **kwargs,
    ) -> None:
-       self.gemini_manager = gemini_manager
        self.chunk_manager = chunk_manager
        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
        self.max_prefetch = max_prefetch
@@ -40,14 +39,15 @@ class PlacementPolicy(ABC):
    ) -> None:
        raise NotImplementedError

-   def get_prefetch_chunks(self) -> List[Chunk]:
+   def get_prefetch_chunks(
+       self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+   ) -> List[Chunk]:
        return []  # no prefetch by default


class StaticPlacementPolicy(PlacementPolicy):
    def __init__(
        self,
-       gemini_manager: "GeminiManager",
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
@@ -56,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
        offload_param_frac: float = 0.0,
        **kwargs,
    ) -> None:
-       super().__init__(
-           gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-       )
+       super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
        if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
            warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
            offload_param_frac = 0.0
@@ -109,20 +107,21 @@ class StaticPlacementPolicy(PlacementPolicy):
        self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
        self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

-   def get_prefetch_chunks(self) -> List[Chunk]:
-       if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+   def get_prefetch_chunks(
+       self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+   ) -> List[Chunk]:
+       if is_warmup:  # no prefetch during warmup since we need compute_list
            return []

-       can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+       can_prefetch = self.max_prefetch - len(async_works)
        prefetch = []
-       for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-           break_flag = False
-           for chunk in self.gemini_manager.compute_list[i]:
+       for i in range(compute_idx + 1, len(compute_list)):
+           for chunk in compute_list[i]:
                if len(prefetch) >= can_prefetch:
-                   break_flag = True
                    break
                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                    prefetch.append(chunk)
-           if break_flag:
-               break
+           else:
+               continue
+           break
        return prefetch
@@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
    def __init__(
        self,
-       gemini_manager: "GeminiManager",
        chunk_manager: ChunkManager,
        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
        max_prefetch: int = 0,
@@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
        steady_cuda_cap_ratio: float = 0.9,
        **kwargs,
    ) -> None:
-       super().__init__(
-           gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-       )
+       super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
        # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
        # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
        # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
            else:
                grads_device_map[p] = torch.device("cpu")

-   def get_prefetch_chunks(self) -> List[Chunk]:
-       if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+   def get_prefetch_chunks(
+       self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+   ) -> List[Chunk]:
+       if is_warmup:  # no prefetch during warmup since we need compute_list
            return []

        # modified from self.evict_tensors
        cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
@@ -246,18 +244,17 @@ class AutoPlacementPolicy(PlacementPolicy):
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        prefetch_chunk_memory = 0
-       can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+       can_prefetch = self.max_prefetch - len(async_works)
        prefetch = []
-       for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-           break_flag = False
-           for chunk in self.gemini_manager.compute_list[i]:
-               chunk: Chunk
+       for i in range(compute_idx + 1, len(compute_list)):
+           for chunk in compute_list[i]:
                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
-                   break_flag = True
                    break
                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                   prefetch_chunk_memory += chunk.chunk_mem
                    prefetch.append(chunk)
-           if break_flag:
-               break
+           else:
+               continue
+           break
        return prefetch
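Both policies now share the same scanning idea, and the rewrite replaces the old break_flag bookkeeping with Python's for/else. Below is a minimal, self-contained sketch of that selection loop; chunk ids are plain strings and accessed_chunks / can_prefetch are stand-ins for the Gemini types, so only the control flow mirrors the diff.

from typing import List, Sequence, Set


def select_prefetch_chunks(
    compute_list: Sequence[Sequence[str]],  # chunks touched by each upcoming op
    compute_idx: int,                       # index of the op currently executing
    accessed_chunks: Set[str],              # chunks already resident on the device
    can_prefetch: int,                      # max_prefetch minus in-flight async works
) -> List[str]:
    prefetch: List[str] = []
    for i in range(compute_idx + 1, len(compute_list)):
        for chunk in compute_list[i]:
            if len(prefetch) >= can_prefetch:
                break                        # budget exhausted: leave the inner loop
            if chunk not in prefetch and chunk not in accessed_chunks:
                prefetch.append(chunk)
        else:
            continue                         # inner loop ran out of chunks: try the next op
        break                                # inner loop was broken: stop scanning entirely
    return prefetch


ops = [("A", "B"), ("B", "C"), ("D", "E")]
print(select_prefetch_chunks(ops, compute_idx=0, accessed_chunks={"A", "B"}, can_prefetch=2))
# -> ['C', 'D']

AutoPlacementPolicy adds a second budget to the same loop: it also stops once the accumulated prefetch_chunk_memory would exceed the CUDA memory still available for model data.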

examples/language/gpt/gemini/run_gemini.sh

@@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
export GPUNUM=${GPUNUM:-1}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
-export TRAIN_STEP=${TRAIN_STEP:-10}
+export TRAIN_STEP=${TRAIN_STEP:-2}

# export PYTHONPATH=$PWD:$PYTHONPATH

examples/language/gpt/gemini/train_gpt_demo.py

@@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module):
def get_cpu_mem():
-   return psutil.Process().memory_info().rss / 1024**2
+   return psutil.Process().memory_info().rss / 1024**2  # MB unit


def get_gpu_mem():
-   return torch.cuda.memory_allocated() / 1024**2
+   return torch.cuda.memory_allocated() / 1024**2  # MB unit


def get_mem_info(prefix=""):
@@ -78,6 +78,7 @@ def get_mem_info(prefix=""):
def get_model_size(model: nn.Module):
+   # get the number of parameter of the model
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
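The added comment documents what get_model_size computes. The hunk only shows the top of the function; assuming the body simply accumulates p.numel() (the usual pattern, not confirmed by the visible lines), a runnable stand-alone version looks like this:

import torch.nn as nn


def get_model_size(model: nn.Module) -> int:
    # get the number of parameters of the model, counting each tensor exactly once
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


print(get_model_size(nn.Linear(4, 2)))  # 4*2 weights + 2 biases = 10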
