diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 27a19c132..cab26c822 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -50,7 +50,12 @@ class GeminiZeROHook(ColoParamOpHook):
             self._chunk_manager.access_chunk(chunk)
 
         # get possible chunks to prefetch
-        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
+            is_warmup=self._gemini_manager.is_warmup(),
+            compute_list=self._gemini_manager.compute_list,
+            compute_idx=self._gemini_manager.compute_idx,
+            async_works=self._gemini_manager.async_works,
+        )
 
         # prefetch
         for chunk in chunks_fetch_async:
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 11bde789c..5b309c7a1 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -45,7 +45,7 @@ class GeminiManager:
         self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
-        self._async_works: Dict[Chunk, dist.work] = {}
+        self._async_works: Dict[Chunk, dist.Work] = {}
 
         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -183,6 +183,10 @@ class GeminiManager:
     def compute_idx(self) -> int:
         return self._compute_idx
 
+    @property
+    def async_works(self) -> Dict[Chunk, dist.Work]:
+        return self._async_works
+
     @property
     def placement_policy(self) -> PlacementPolicy:
         return self._placement_policy
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index cae5cc202..9b1d1a6ab 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -5,6 +5,7 @@ from time import time
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -19,13 +20,11 @@ class PlacementPolicy(ABC):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",  # TODO @botbw: solve circular import
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
         **kwargs,
     ) -> None:
-        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
         self.max_prefetch = max_prefetch
@@ -40,14 +39,15 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
+    def get_prefetch_chunks(
+        self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
         return []  # no prefetch by default
 
 
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -56,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -109,21 +107,22 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            break_flag = False
-            for chunk in self.gemini_manager.compute_list[i]:
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch:
-                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if break_flag:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
@@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
 
     def __init__(
         self,
-        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(
-            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
-        )
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
-    def get_prefetch_chunks(self) -> List[Chunk]:
-        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
             return []
         # modified from self.evict_tensors
         cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
@@ -246,19 +244,18 @@ class AutoPlacementPolicy(PlacementPolicy):
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
 
         prefetch_chunk_memory = 0
-        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        can_prefetch = self.max_prefetch - len(async_works)
         prefetch = []
-        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
-            break_flag = False
-            for chunk in self.gemini_manager.compute_list[i]:
-                chunk: Chunk
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
                 if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
-                    break_flag = True
                     break
                 if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch_chunk_memory += chunk.chunk_mem
                     prefetch.append(chunk)
-            if break_flag:
-                break
+            else:
+                continue
+            break
         return prefetch
 
 
diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh
index 5eaa4af4d..bffd26f59 100644
--- a/examples/language/gpt/gemini/run_gemini.sh
+++ b/examples/language/gpt/gemini/run_gemini.sh
@@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
 export GPUNUM=${GPUNUM:-1}
 export BATCH_SIZE=${BATCH_SIZE:-16}
 export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
-export TRAIN_STEP=${TRAIN_STEP:-10}
+export TRAIN_STEP=${TRAIN_STEP:-2}
 # export PYTHONPATH=$PWD:$PYTHONPATH
 
 
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 6db74231a..bf1be87ba 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module):
 
 
 def get_cpu_mem():
-    return psutil.Process().memory_info().rss / 1024**2
+    return psutil.Process().memory_info().rss / 1024**2  # MB unit
 
 
 def get_gpu_mem():
-    return torch.cuda.memory_allocated() / 1024**2
+    return torch.cuda.memory_allocated() / 1024**2  # MB unit
 
 
 def get_mem_info(prefix=""):
@@ -78,6 +78,7 @@ def get_mem_info(prefix=""):
 
 
 def get_model_size(model: nn.Module):
+    # get the number of parameters of the model
     total_numel = 0
     for module in model.modules():
         for p in module.parameters(recurse=False):