diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 964cd302a..aeef14487 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
         chunk_init_device: Optional[torch.device] = None,
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            max_prefetch=max_prefetch,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index cad2622f2..299ea0518 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -357,14 +357,14 @@ class Chunk:
         else:
             raise NotImplementedError

-    def access_chunk(self):
+    def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
-
         if not self.is_gathered:
-            self.__gather()
+            return self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
+        return None

     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
@@ -498,17 +498,19 @@ class Chunk:
     def get_tensors(self) -> List[torch.Tensor]:
         return list(self.tensors_info.keys())

-    def __gather(self):
+    def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
         if not self.is_gathered:
             # sanity check
             assert self.cuda_shard is not None

             alloc_storage(self.cuda_global_chunk)
             gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)

             self.cuda_shard = None
             self.is_gathered = True
+            return work
+        return None

     def __scatter(self):
         if self.keep_gathered:
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 333a3f224..c7bdd5e1f 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -111,15 +111,16 @@ class ChunkManager:
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])

-    def access_chunk(self, chunk: Chunk) -> None:
+    def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
-            return
+            return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
-        self.__add_accessed_chunk(chunk)
+        maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
+        return maybe_work

     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@
         for k, v in usage.items():
             self.total_mem[k] += v

-    def __add_accessed_chunk(self, chunk: Chunk):
-        chunk.access_chunk()
+    def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
+        maybe_work = chunk.access_chunk(async_access=async_access)
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
+        return maybe_work

     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index c1029097a..3a0ae59fc 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -130,6 +131,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 480a14511..27a19c132 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -5,6 +5,7 @@ from typing import List

 import torch

+from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -16,6 +17,9 @@ class TrainingPhase(Enum):
     BACKWARD = 1


+logger = DistributedLogger("gemini_hook")
+
+
 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
@@ -24,16 +28,37 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD

     def pre_op(self, params):
+        # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
-        chunks = self._chunk_manager.get_chunks(params)
+        all_chunks = self._chunk_manager.get_chunks(params)
+
+        # wait for prefetched chunks and filter out those that were not prefetched
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+
+        # transfer state
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks)
-        for chunk in chunks:
+
+        # evict chunks, taking asynchronously fetched chunks into account
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )
+
+        # fetch the rest synchronously
+        for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)

-        # record cuda model data of the current OP
+        # get possible chunks to prefetch
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+
+        # prefetch
+        for chunk in chunks_fetch_async:
+            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+            if maybe_work is not None:
+                self._gemini_manager.add_work(chunk, maybe_work)
+
+        # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()

     def post_op(self, params):
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 150932e3d..11bde789c 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -1,12 +1,13 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
+import torch.distributed as dist

 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory


 class GeminiManager:
@@ -41,9 +42,10 @@ class GeminiManager:
         self._mem_stats_collector = (
             ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
         )
-        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
+        self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
+        self._async_works: Dict[Chunk, dist.Work] = {}

         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -91,18 +93,20 @@ class GeminiManager:
             self._warmup = False
             self.reset_attributes()

-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
-        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
+        cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        # don't evict chunks that are asynchronously fetched
+        can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start

         vol, evict_time = self._placement_policy.evict_tensors(
-            can_evict_chunks=hold_cuda_tensor_list,
+            can_evict_chunks=can_evict_chunks,
             cuda_demand=cuda_demand,
             warmup=self._warmup,
             compute_list=self._compute_list,
@@ -114,6 +118,21 @@
         # move COMPUTE tensors to CUDA
         self._h2d_volume += cuda_demand

+    def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk, ...]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return tuple(non_prefetched_chunks)
+
+    def add_work(self, chunk: Chunk, work: dist.Work):
+        assert work is not None
+        assert chunk not in self._async_works
+        self._async_works[chunk] = work
+
     @functools.lru_cache(maxsize=None)
     def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
         start = time()
@@ -133,9 +152,9 @@
         can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks

-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)

     def sample_overall_data(self):
@@ -156,6 +175,18 @@
             return self._mem_stats_collector.cuda_margin_mem
         return None

+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 388999549..e9e871b46 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -18,10 +18,17 @@ class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
+        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
+        self.max_prefetch = max_prefetch

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
         raise NotImplementedError
@@ -33,18 +40,26 @@
     ) -> None:
         raise NotImplementedError

+    @abstractmethod
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        raise NotImplementedError
+

 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,
         offload_optim_frac: float = 0.0,
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -95,19 +110,38 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                if len(prefetch) >= can_prefetch:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch.append(chunk)
+            if len(prefetch) >= can_prefetch:
+                break
+        return prefetch
+

 class AutoPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = True

     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         warmup_non_model_data_ratio: float = 0.8,
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -198,6 +232,9 @@
             else:
                 grads_device_map[p] = torch.device("cpu")

+    def get_prefetch_chunks(self) -> List[Chunk]:
+        return []  # TODO @botbw: implement prefetching for auto
+

 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py
index 7ed5fdb92..03054c0a2 100644
--- a/examples/language/gpt/gemini/commons/utils.py
+++ b/examples/language/gpt/gemini/commons/utils.py
@@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
             on_trace_ready=tensorboard_trace_handler(save_dir),
-            record_shapes=True,
-            profile_memory=True,
+            # record_shapes=True,
+            # profile_memory=True,
+            with_stack=True,
         )
     else:
         return nullcontext(DummyProfiler())
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 4911ff124..6db74231a 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -129,7 +129,7 @@ def main():
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
     assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
-    PROF_FLAG = False  # The flag of profiling, False by default
+    PROF_FLAG = True  # The flag of profiling, False by default

     disable_existing_loggers()
     colossalai.launch_from_torch()
@@ -166,7 +166,7 @@
             stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
         )
     elif args.distplan == "CAI_Gemini":
-        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
+        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
     else:
         raise RuntimeError

@@ -248,7 +248,7 @@
             prof.step()

     tflops_list.sort()
-    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
     torch.cuda.synchronize()
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index a9366e7bc..1c914ca0e 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -40,9 +40,7 @@ EXAMPLE_MODELS = [
 ]

 # bfloat16 cannot represent them exactly
-BF16_IGNORED_KEYS = [
-    "masked_bias",
-]
+BF16_IGNORED_KEYS = ["masked_bias"]


 def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
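
Usage note (not part of the patch): a minimal sketch of how the new max_prefetch knob is exercised through the booster API already used in train_gpt_demo.py above. The toy model and optimizer below are placeholders, not code from this patch.

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()

    # max_prefetch=N lets the Gemini ZeRO hook issue asynchronous all-gathers for up
    # to N upcoming chunks while the current op runs; max_prefetch=0 keeps the old
    # synchronous behaviour. Prefetching only starts after the warmup iteration,
    # once the compute order (compute_list) has been recorded.
    plugin = GeminiPlugin(placement_policy="static", max_prefetch=2)
    booster = Booster(plugin=plugin)

    model = nn.Linear(1024, 1024)          # placeholder model
    optimizer = HybridAdam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)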