diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 964cd302a..b1f8ea24a 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
         chunk_init_device: Optional[torch.device] = None,
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch:int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            max_prefetch=max_prefetch,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 7f75f2471..01d9c9d07 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -1,4 +1,3 @@
-from chunk import Chunk
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
@@ -13,15 +12,16 @@ from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
 from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
+from colossalai.logging import DistributedLogger
+from .chunk import Chunk
 
 
 class TrainingPhase(Enum):
     FORWARD = 0
     BACKWARD = 1
 
 
-DEBUG = True  # TODO @botbw: remove
-
+logger = DistributedLogger("gemini_hook")
 
 class GeminiZeROHook(ColoParamOpHook):
     def __init__(
@@ -31,16 +31,14 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
-        self._cur_param = None
         # param_visited_order might be updated somewhere else
         self._param_visited_order = param_order.param_visited_order
         self._max_prefetch = max_prefetch
         self._async_works: Dict[Chunk, dist.work] = {}
 
-        # used by get_prefetch_chunks to track current param
         self._cur_param_idx = 0
 
-    def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]:
+    def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]:
         chunks_to_prefetch = set()
         if self._training_phase == TrainingPhase.FORWARD:  # forward phrase: increase
             self._cur_param_idx += len(all_params)  # need to update first
@@ -52,10 +50,10 @@ class GeminiZeROHook(ColoParamOpHook):
                     idx += 1
                     continue
                 chunk = self._chunk_manager.get_chunk(param)
-                chunks_to_prefetch.add(chunk)
+                if chunk not in cur_chunks:
+                    chunks_to_prefetch.add(chunk)
                 idx += 1
         else:
-            assert self._training_phase == TrainingPhase.BACKWARD
             self._cur_param_idx -= len(all_params)
             idx = self._cur_param_idx - 1
             chunks_to_prefetch = set()
@@ -65,14 +63,17 @@ class GeminiZeROHook(ColoParamOpHook):
                     idx -= 1
                     continue
                 chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
-                chunks_to_prefetch.add(chunk)
+                if chunk not in cur_chunks:
+                    chunks_to_prefetch.add(chunk)
                 idx -= 1
+        print(f"cur id {self._cur_param_idx}")
         return list(chunks_to_prefetch)
 
     def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
         non_prefetched_chunks = []
         for chunk in chunks:
             if chunk in self._async_works:
+                print(f"prefetched {chunk.count_id}")
                 self._async_works[chunk].wait()
                 del self._async_works[chunk]
             else:
@@ -80,31 +81,42 @@ class GeminiZeROHook(ColoParamOpHook):
         return non_prefetched_chunks
 
     def pre_op(self, all_params):
-        if DEBUG:  # TODO @botbw: remove
-            idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params))
-            mx = max(idxs)
-            idxs = sorted(map(lambda x: x - mx, idxs))
-            assert list(range(len(idxs))) == idxs, f"{idxs=}"
+        # def find_idx(param):
+        #     for i, p in enumerate(self._param_visited_order):
+        #         if param is p:
+        #             return i
+        #     assert False
+
+        # idxs = [find_idx(p) for p in all_params]
+        # max_id = min(idxs)
+        # idxs = [i - max_id for i in idxs]
+        # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}'
 
         # deal with current needed chunks
         params = [p for p in all_params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
-        chunks_wo_work = self.wait_chunks(all_chunks)
+        chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks))
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks_wo_work)
+        self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync)
 
         # deal with chunks that are to be async fetched
-        prefetch_chunks = self.get_prefetch_chunks(all_params)
+        chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync)
+        print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}")
 
         # deal with chunks that are to be fetched now
-        for chunk in chunks_wo_work:
+        for chunk in chunks_need_to_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
 
         # deal with chunks that are to be pre fetched TODO @botbw: the order here matters?
-        for chunk in prefetch_chunks:
-            self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True)
+        for chunk in chunks_can_be_fetch_async:
+            if chunk in self._async_works:
+                continue
+            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+            if maybe_work is not None:
+                print(f"prefetch {chunk.count_id}")
+                self._async_works[chunk] = maybe_work
 
         # record cuda model data of the current OP
         self._gemini_manager.record_model_data_volume()
@@ -133,6 +145,11 @@ class GeminiZeROHook(ColoParamOpHook):
 
     @contextmanager
     def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
+        if training_phase == TrainingPhase.FORWARD:
+            self._cur_param_idx = 0
+        else:
+            self._cur_param_idx = len(self._param_visited_order) - 1
+
         old_training_phase = self._training_phase
         try:
             self._training_phase = training_phase
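
For reviewers who want to exercise the new `max_prefetch` option end to end, below is a minimal usage sketch. It is not part of this diff: the toy model, optimizer, `max_prefetch=2` value, and the `launch_from_torch(config={})` call are placeholders assumed from the usual ColossalAI booster workflow of this era, so adjust them to your local version.

```python
# Hypothetical smoke test for the new `max_prefetch` knob (illustrative, not from this PR).
# Launch with e.g.: torchrun --nproc_per_node=1 demo_prefetch.py
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# max_prefetch=0 keeps the previous synchronous behaviour; a positive value
# lets GeminiZeROHook issue async access_chunk() calls for upcoming chunks.
plugin = GeminiPlugin(placement_policy="static", max_prefetch=2)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```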
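The core change in `pre_op` is the bookkeeping around `self._async_works`: a prefetch is launched with `access_chunk(chunk, async_access=True)` only if the chunk is not already in flight, the returned handle is stored per chunk, and `wait_chunks` later blocks only on chunks that were actually prefetched. A distilled sketch of that pattern follows, using plain `torch.distributed` work handles in place of `ChunkManager`; the `prefetch` and `wait_before_compute` helpers are illustrative stand-ins, not ColossalAI APIs.

```python
# Distilled prefetch/wait pattern (assumes torch.distributed is already initialized).
from typing import Dict

import torch
import torch.distributed as dist

_async_works: Dict[int, object] = {}  # chunk id -> in-flight communication handle


def prefetch(chunk_id: int, buf: torch.Tensor) -> None:
    """Launch a non-blocking collective for a chunk and remember its handle."""
    if chunk_id in _async_works:  # already in flight, do not launch twice
        return
    work = dist.broadcast(buf, src=0, async_op=True)
    if work is not None:
        _async_works[chunk_id] = work


def wait_before_compute(chunk_id: int) -> None:
    """Block only if this chunk was prefetched and has not completed yet."""
    work = _async_works.pop(chunk_id, None)
    if work is not None:
        work.wait()
```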