From 2e68eebdfe9a8f68c0cb9a260ce7f2de4e6b6e27 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 07:22:10 +0000 Subject: [PATCH 1/2] [chore] refactor & sync --- colossalai/zero/gemini/chunk/chunk.py | 1 + colossalai/zero/gemini/gemini_ddp.py | 3 +- colossalai/zero/gemini/gemini_hook.py | 55 ++++++++++--------- colossalai/zero/gemini/gemini_mgr.py | 27 +++++++-- colossalai/zero/gemini/placement_policy.py | 31 +++++++---- examples/language/gpt/gemini/commons/utils.py | 5 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- 7 files changed, 82 insertions(+), 46 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 299ea0518..81b6192ad 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -567,6 +567,7 @@ class Chunk: return self is __o def __repr__(self, detailed: bool = True): + return f"Chunk({self.count_id})" output = [ "Chunk Information:\n", "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index b75f69a3b..6bf0b4019 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -131,9 +131,10 @@ class GeminiDDP(ModelWrapper): offload_param_frac=offload_param_frac, warmup_non_model_data_ratio=warmup_non_model_data_ratio, steady_cuda_cap_ratio=steady_cuda_cap_ratio, + max_prefetch=max_prefetch ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch) + self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 82d890975..1d734bd83 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Dict, List +from typing import Dict, List, Iterable, Tuple import torch import torch.distributed as dist @@ -22,45 +22,55 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") +import os +rank = int(os.environ['RANK']) class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None: + def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - self._max_prefetch = max_prefetch - self._async_works: Dict[Chunk, dist.work] = {} - def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: - non_prefetched_chunks = [] - for chunk in chunks: - if chunk in self._async_works: - print(f"prefetched {chunk.count_id}") - self._async_works[chunk].wait() - del self._async_works[chunk] - else: - non_prefetched_chunks.append(chunk) - return non_prefetched_chunks def pre_op(self, params): + # map params to chunks params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) + # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = tuple(self.wait_chunks(all_chunks)) + unique_chunks = set(all_chunks) + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) + + # transfer state for p in params: 
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0) - # fetch the rest chunks synchronously + + # evit chunks, aware of async fetched + self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0) + + # fetch the rest synchronously for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) - chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch) + + # get possible chunks to prefetch + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() + if rank == 0 and not self._gemini_manager.is_warmup(): + print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}") + print(f"{all_chunks=}") + print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") + print(f"{chunks_fetch_sync=}") + print(f"{chunks_fetch_async=}") + print(f"works={list(self._gemini_manager._async_works.keys())}") + + # prefetch for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) if maybe_work is not None: - self._async_works[chunk] = maybe_work - + self._gemini_manager.add_work(chunk, maybe_work) + if rank == 0 and not self._gemini_manager.is_warmup(): + print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}") # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() @@ -88,11 +98,6 @@ class GeminiZeROHook(ColoParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): - if training_phase == TrainingPhase.FORWARD: - self._cur_param_idx = 0 - else: - self._cur_param_idx = len(self._param_visited_order) - 1 - old_training_phase = self._training_phase try: self._training_phase = training_phase diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 0362d6523..6640bf03b 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -1,8 +1,9 @@ import functools from time import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Iterable import torch +import torch.distributed as dist from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector, MemStats @@ -41,9 +42,10 @@ class GeminiManager: self._mem_stats_collector = ( ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None ) - self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) + self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 + self._async_works: Dict[Chunk, dist.work] = {} self._h2d_volume = 0 self._d2h_volume = 0 @@ -98,11 +100,13 @@ class GeminiManager: # find stateful tensor in state COMPUTE start = time() self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) - cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) + cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks) + # don't evict chunks that are asynchronously fetched + can_evict_chunks = [chunk for chunk in 
can_evict_chunks if chunk not in self._async_works] self._layout_time += time() - start vol, evict_time = self._placement_policy.evict_tensors( - can_evict_chunks=hold_cuda_tensor_list, + can_evict_chunks=can_evict_chunks, cuda_demand=cuda_demand, warmup=self._warmup, compute_list=self._compute_list, @@ -114,6 +118,21 @@ class GeminiManager: # move COMPUTE tensors to CUDA self._h2d_volume += cuda_demand + def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]: + non_prefetched_chunks = [] + for chunk in chunks: + if chunk in self._async_works: + self._async_works[chunk].wait() + del self._async_works[chunk] + else: + non_prefetched_chunks.append(chunk) + return tuple(non_prefetched_chunks) + + def add_work(self, chunk: Chunk, work: dist.Work): + assert work is not None + assert chunk not in self._async_works + self._async_works[chunk] = work + @functools.lru_cache(maxsize=None) def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]): start = time() diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 452687b7d..aad97321c 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -13,15 +13,16 @@ from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector - class PlacementPolicy(ABC): need_mem_stats: bool = False def __init__( - self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs + self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs ) -> None: + self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector + self.max_prefetch = max_prefetch @abstractmethod def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -34,21 +35,25 @@ class PlacementPolicy(ABC): raise NotImplementedError @abstractmethod - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + def get_prefetch_chunks(self) -> List[Chunk]: raise NotImplementedError +import os +rank = int(os.environ["RANK"]) class StaticPlacementPolicy(PlacementPolicy): def __init__( self, + gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, shard_param_frac: float = 1.0, offload_optim_frac: float = 0.0, offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -99,15 +104,17 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + def get_prefetch_chunks(self) -> List[Chunk]: + if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + return [] prefetch = [] - for i in 
range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)): - for chunk in self.chunk_manager.compute_list[i]: - if len(prefetch) >= max_prefetch: + for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + for chunk in self.gemini_manager.compute_list[i]: + if len(prefetch) >= self.max_prefetch: break - if chunk not in prefetch: + if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= max_prefetch: + if len(prefetch) >= self.max_prefetch: break return prefetch @@ -117,13 +124,15 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, + gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, warmup_non_model_data_ratio: float = 0.8, steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 7ed5fdb92..03054c0a2 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True, + # record_shapes=True, + # profile_memory=True, + with_stack=True, ) else: return nullcontext(DummyProfiler()) diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 4911ff124..6db74231a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -129,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default + PROF_FLAG = True # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch() @@ -166,7 +166,7 @@ def main(): stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True ) elif args.distplan == "CAI_Gemini": - plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1) else: raise RuntimeError @@ -248,7 +248,7 @@ def main(): prof.step() tflops_list.sort() - median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1) logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() From 6bbe956316d62202f1b19ea1033419d6a3ee91ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 May 2024 07:26:19 +0000 Subject: [PATCH 2/2] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 20 ++++++++++--------- colossalai/zero/gemini/gemini_mgr.py | 4 ++-- colossalai/zero/gemini/placement_policy.py | 23 +++++++++++++++++----- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 6bf0b4019..3a0ae59fc 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -131,7 +131,7 @@ class GeminiDDP(ModelWrapper): offload_param_frac=offload_param_frac, warmup_non_model_data_ratio=warmup_non_model_data_ratio, steady_cuda_cap_ratio=steady_cuda_cap_ratio, - max_prefetch=max_prefetch + max_prefetch=max_prefetch, ) self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = GeminiZeROHook(self.gemini_manager) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 1d734bd83..e6b8cf8ef 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,10 +1,9 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Dict, List, Iterable, Tuple +from typing import List import torch -import torch.distributed as dist from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook @@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager -from .chunk import Chunk - class TrainingPhase(Enum): FORWARD = 0 @@ -23,7 +20,9 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") import os -rank = int(os.environ['RANK']) + +rank = int(os.environ["RANK"]) + class GeminiZeROHook(ColoParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: @@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook): self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - def pre_op(self, params): # map params to chunks params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - unique_chunks = set(all_chunks) + set(all_chunks) chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # transfer state @@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0) + self._gemini_manager.adjust_layout( + all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 + ) # fetch the rest synchronously for chunk in chunks_fetch_sync: @@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook): # get possible chunks to prefetch chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() if rank == 0 and not self._gemini_manager.is_warmup(): - print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}") + print( + f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}" + ) print(f"{all_chunks=}") print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") print(f"{chunks_fetch_sync=}") 
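
For reference, the pre_op flow that the gemini_hook.py hunks above converge on reduces to the standalone sketch below. The classes here are illustrative stand-ins only (they replace the real Chunk, ChunkManager and GeminiManager types), and layout adjustment, tensor-state transitions and the rank-0 debug prints are left out:

from typing import Dict, List, Optional


class FakeWork:
    """Stand-in for the torch.distributed.Work handle returned by an async all-gather."""

    def wait(self) -> None:
        pass


class FakeChunkManager:
    """Stand-in chunk manager: access_chunk returns a handle only for async access."""

    def __init__(self) -> None:
        self.accessed_chunks = set()

    def access_chunk(self, chunk: str, async_access: bool = False) -> Optional[FakeWork]:
        self.accessed_chunks.add(chunk)
        return FakeWork() if async_access else None


class FakeGeminiManager:
    """Stand-in manager that owns the async-work table, as GeminiManager does after this patch."""

    def __init__(self, chunk_manager: FakeChunkManager, prefetch_plan: List[str]) -> None:
        self.chunk_manager = chunk_manager
        self._async_works: Dict[str, FakeWork] = {}
        self._prefetch_plan = prefetch_plan  # what the placement policy would hand back

    def wait_chunks(self, chunks: List[str]) -> List[str]:
        """Wait on chunks that were prefetched; return those that still need a sync fetch."""
        pending = []
        for chunk in chunks:
            if chunk in self._async_works:
                self._async_works.pop(chunk).wait()
            else:
                pending.append(chunk)
        return pending

    def add_work(self, chunk: str, work: FakeWork) -> None:
        assert chunk not in self._async_works, "a chunk is only prefetched once at a time"
        self._async_works[chunk] = work

    def get_prefetch_chunks(self) -> List[str]:
        return self._prefetch_plan


def pre_op(mgr: FakeGeminiManager, all_chunks: List[str]) -> None:
    """Simplified mirror of GeminiZeROHook.pre_op: wait, fetch the rest, then prefetch ahead."""
    chunks_fetch_sync = mgr.wait_chunks(all_chunks)        # consume finished prefetches
    for chunk in chunks_fetch_sync:                        # fetch whatever was not prefetched
        mgr.chunk_manager.access_chunk(chunk)
    for chunk in mgr.get_prefetch_chunks():                # start async fetches for upcoming ops
        work = mgr.chunk_manager.access_chunk(chunk, async_access=True)
        if work is not None:
            mgr.add_work(chunk, work)


if __name__ == "__main__":
    cm = FakeChunkManager()
    mgr = FakeGeminiManager(cm, prefetch_plan=["c2"])
    pre_op(mgr, all_chunks=["c0", "c1"])
    print(sorted(cm.accessed_chunks))  # ['c0', 'c1', 'c2'], with 'c2' fetched asynchronously

The real hook additionally transitions the parameters to the COMPUTE state and calls adjust_layout before the synchronous fetch, as shown in the hunks above.
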
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 6640bf03b..11bde789c 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -1,6 +1,6 @@ import functools from time import time -from typing import Dict, List, Optional, Tuple, Iterable +from typing import Dict, Iterable, List, Optional, Tuple import torch import torch.distributed as dist @@ -101,7 +101,7 @@ class GeminiManager: start = time() self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks) - # don't evict chunks that are asynchronously fetched + # don't evict chunks that are asynchronously fetched can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works] self._layout_time += time() - start diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index aad97321c..e5f61a033 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector + class PlacementPolicy(ABC): need_mem_stats: bool = False def __init__( - self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs + self, + gemini_manager: "GeminiManager", + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, + **kwargs, ) -> None: self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager @@ -38,13 +44,16 @@ class PlacementPolicy(ABC): def get_prefetch_chunks(self) -> List[Chunk]: raise NotImplementedError + import os + rank = int(os.environ["RANK"]) + class StaticPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: 'GeminiManager', + gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy): offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) + super().__init__( + gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch + ) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: 'GeminiManager', + gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy): steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) + super().__init__( + gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch + ) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # 
and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
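
Taken together, StaticPlacementPolicy.get_prefetch_chunks is a bounded look-ahead over the compute order recorded during warmup: starting just past the current compute index, it collects up to max_prefetch chunks that are not already resident. A minimal sketch of that selection logic, with plain strings standing in for Chunk objects:

from typing import List, Sequence, Set, Tuple


def select_prefetch_chunks(
    compute_list: Sequence[Tuple[str, ...]],  # chunks used by each recorded op, in execution order
    compute_idx: int,                         # index of the op currently being executed
    accessed_chunks: Set[str],                # chunks already gathered on the device
    max_prefetch: int,
) -> List[str]:
    """Pick up to max_prefetch not-yet-resident chunks needed by upcoming ops."""
    prefetch: List[str] = []
    for op_chunks in compute_list[compute_idx + 1:]:
        for chunk in op_chunks:
            if len(prefetch) >= max_prefetch:
                return prefetch
            if chunk not in prefetch and chunk not in accessed_chunks:
                prefetch.append(chunk)
    return prefetch


# With max_prefetch=2 and 'a', 'b' already resident, the next two missing chunks are chosen:
print(select_prefetch_chunks(
    compute_list=[("a",), ("b", "c"), ("d", "e")],
    compute_idx=0,
    accessed_chunks={"a", "b"},
    max_prefetch=2,
))  # ['c', 'd']

During warmup the compute list is still being filled in, so the policy returns an empty list and prefetching only starts once a full iteration has been recorded; the example script enables the feature through GeminiPlugin(..., max_prefetch=1) in the train_gpt_demo.py hunk above.
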