diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 964cd302a..aeef14487 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase): chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase): memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], master_weights=master_weights, + max_prefetch=max_prefetch, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index cad2622f2..299ea0518 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -357,14 +357,14 @@ class Chunk: else: raise NotImplementedError - def access_chunk(self): + def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None - if not self.is_gathered: - self.__gather() + return self.__gather(async_op=async_access) self.__update_tensors_ptr() + return None def release_chunk(self): """Release the usable chunk. It's an operation done in CUDA.""" @@ -498,17 +498,19 @@ class Chunk: def get_tensors(self) -> List[torch.Tensor]: return list(self.tensors_info.keys()) - def __gather(self): + def __gather(self, async_op: bool = False) -> Optional[dist.Work]: if not self.is_gathered: # sanity check assert self.cuda_shard is not None alloc_storage(self.cuda_global_chunk) gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op) self.cuda_shard = None self.is_gathered = True + return work + return None def __scatter(self): if self.keep_gathered: diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 333a3f224..c7bdd5e1f 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -111,15 +111,16 @@ class ChunkManager: for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) - def access_chunk(self, chunk: Chunk) -> None: + def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: - return + return None self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": chunk.shard_move(get_accelerator().get_current_device()) - self.__add_accessed_chunk(chunk) + maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access) self.__add_memory_usage(chunk.memory_usage) + return maybe_work def release_chunk(self, chunk: Chunk) -> None: """Scatter the chunk in CUDA.""" @@ -251,10 +252,11 @@ class ChunkManager: for k, v in usage.items(): self.total_mem[k] += v - def __add_accessed_chunk(self, chunk: Chunk): - chunk.access_chunk() + def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: + maybe_work = chunk.access_chunk(async_access=async_access) self.accessed_chunks.add(chunk) self.accessed_mem += chunk.chunk_mem + return maybe_work def __sub_accessed_chunk(self, chunk: Chunk): chunk.release_chunk() diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c1029097a..b75f69a3b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper): chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -132,7 +133,7 @@ class GeminiDDP(ModelWrapper): steady_cuda_cap_ratio=steady_cuda_cap_ratio, ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(self.gemini_manager) + self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 480a14511..82d890975 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,39 +1,67 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import List +from typing import Dict, List import torch +import torch.distributed as dist +from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager +from .chunk import Chunk + class TrainingPhase(Enum): FORWARD = 0 BACKWARD = 1 +logger = DistributedLogger("gemini_hook") + + class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: + def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD + self._max_prefetch = max_prefetch + self._async_works: Dict[Chunk, dist.work] = {} + + def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: + non_prefetched_chunks = [] + for chunk in chunks: + if chunk in self._async_works: + print(f"prefetched {chunk.count_id}") + self._async_works[chunk].wait() + del self._async_works[chunk] + else: + non_prefetched_chunks.append(chunk) + return non_prefetched_chunks def pre_op(self, params): params = [p for p in params if not is_ddp_ignored(p)] - chunks = self._chunk_manager.get_chunks(params) + all_chunks = self._chunk_manager.get_chunks(params) + # wait for prefetched chunks, filter those are not prefetched + chunks_fetch_sync = tuple(self.wait_chunks(all_chunks)) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks) - for chunk in chunks: + self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0) + # fetch the rest chunks synchronously + for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch) + for chunk in chunks_fetch_async: + maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) + if maybe_work is not None: + self._async_works[chunk] = maybe_work - # record cuda model data of the current OP + # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() def post_op(self, params): @@ -60,6 +88,11 @@ class GeminiZeROHook(ColoParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + if training_phase == TrainingPhase.FORWARD: + self._cur_param_idx = 0 + else: + self._cur_param_idx = len(self._param_visited_order) - 1 + old_training_phase = self._training_phase try: self._training_phase = training_phase diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 150932e3d..0362d6523 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -6,7 +6,7 @@ import torch from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector, MemStats -from .placement_policy import PlacementPolicyFactory +from .placement_policy import PlacementPolicy, PlacementPolicyFactory class GeminiManager: @@ -91,13 +91,13 @@ class GeminiManager: self._warmup = False self.reset_attributes() - def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: + def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE start = time() - self._record_chunks_order(chunks) + self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start @@ -133,9 +133,9 @@ class GeminiManager: can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() return cuda_demand, can_evict_chunks - def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: + def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - if self._warmup and self._placement_policy.need_mem_stats: + if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) def sample_overall_data(self): @@ -156,6 +156,18 @@ class GeminiManager: return self._mem_stats_collector.cuda_margin_mem return None + @property + def compute_list(self) -> List[Tuple[Chunk, ...]]: + return self._compute_list + + @property + def compute_idx(self) -> int: + return self._compute_idx + + @property + def placement_policy(self) -> PlacementPolicy: + return self._placement_policy + @property def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 388999549..452687b7d 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -33,6 +33,10 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError + @abstractmethod + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + raise NotImplementedError + class StaticPlacementPolicy(PlacementPolicy): def __init__( @@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + prefetch = [] + for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)): + for chunk in self.chunk_manager.compute_list[i]: + if len(prefetch) >= max_prefetch: + break + if chunk not in prefetch: + prefetch.append(chunk) + if len(prefetch) >= max_prefetch: + break + return prefetch + class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True @@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + return [] # TODO @botbw: implement prefetching for auto + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = {