From 6e38eafebec514cd670758a0fb06b37bd01d224e Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 15 May 2024 16:51:44 +0800 Subject: [PATCH 1/4] [gemini] prefetch chunks --- colossalai/zero/gemini/chunk/chunk.py | 12 ++-- colossalai/zero/gemini/chunk/manager.py | 10 +-- colossalai/zero/gemini/gemini_ddp.py | 4 +- colossalai/zero/gemini/gemini_hook.py | 87 +++++++++++++++++++++++-- 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index cad2622f2..299ea0518 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -357,14 +357,14 @@ class Chunk: else: raise NotImplementedError - def access_chunk(self): + def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None - if not self.is_gathered: - self.__gather() + return self.__gather(async_op=async_access) self.__update_tensors_ptr() + return None def release_chunk(self): """Release the usable chunk. It's an operation done in CUDA.""" @@ -498,17 +498,19 @@ class Chunk: def get_tensors(self) -> List[torch.Tensor]: return list(self.tensors_info.keys()) - def __gather(self): + def __gather(self, async_op: bool = False) -> Optional[dist.Work]: if not self.is_gathered: # sanity check assert self.cuda_shard is not None alloc_storage(self.cuda_global_chunk) gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op) self.cuda_shard = None self.is_gathered = True + return work + return None def __scatter(self): if self.keep_gathered: diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 333a3f224..9cee5223e 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -111,15 +111,16 @@ class ChunkManager: for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) - def access_chunk(self, chunk: Chunk) -> None: + def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": chunk.shard_move(get_accelerator().get_current_device()) - self.__add_accessed_chunk(chunk) + maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access) self.__add_memory_usage(chunk.memory_usage) + return maybe_work def release_chunk(self, chunk: Chunk) -> None: """Scatter the chunk in CUDA.""" @@ -251,10 +252,11 @@ class ChunkManager: for k, v in usage.items(): self.total_mem[k] += v - def __add_accessed_chunk(self, chunk: Chunk): - chunk.access_chunk() + def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: + maybe_work = chunk.access_chunk(async_access=async_access) self.accessed_chunks.add(chunk) self.accessed_mem += chunk.chunk_mem + return maybe_work def __sub_accessed_chunk(self, chunk: Chunk): chunk.release_chunk() diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c1029097a..21448bdae 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -78,6 +78,7 @@ class 
GeminiDDP(ModelWrapper): chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper): steady_cuda_cap_ratio=steady_cuda_cap_ratio, ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() @@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper): for p in module.parameters(): param_order.append(p) + self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch) + for name, param in module.named_parameters(): self.param2name[param] = name for m_name, m_var in module.named_modules(): diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 480a14511..7f75f2471 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,14 +1,18 @@ +from chunk import Chunk from contextlib import contextmanager from enum import Enum from functools import partial -from typing import List +from typing import Dict, List import torch +import torch.distributed as dist +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator class TrainingPhase(Enum): @@ -16,23 +20,92 @@ class TrainingPhase(Enum): BACKWARD = 1 +DEBUG = True # TODO @botbw: remove + + class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: + def __init__( + self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0 + ) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD + self._cur_param = None + # param_visited_order might be updated somewhere else + self._param_visited_order = param_order.param_visited_order + self._max_prefetch = max_prefetch + self._async_works: Dict[Chunk, dist.work] = {} - def pre_op(self, params): - params = [p for p in params if not is_ddp_ignored(p)] - chunks = self._chunk_manager.get_chunks(params) + # used by get_prefetch_chunks to track current param + self._cur_param_idx = 0 + + def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]: + chunks_to_prefetch = set() + if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase + self._cur_param_idx += len(all_params) # need to update first + idx = self._cur_param_idx + 1 + # still have params and prefetched chunks don't exceed the limit + while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch: + param = self._param_visited_order[idx] + if is_ddp_ignored(param): + idx += 1 + continue + chunk = self._chunk_manager.get_chunk(param) + chunks_to_prefetch.add(chunk) + idx += 1 + else: + assert self._training_phase == TrainingPhase.BACKWARD + self._cur_param_idx -= 
len(all_params) + idx = self._cur_param_idx - 1 + chunks_to_prefetch = set() + while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch: + param = self._param_visited_order[idx] + if is_ddp_ignored(param): + idx -= 1 + continue + chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) + chunks_to_prefetch.add(chunk) + idx -= 1 + return list(chunks_to_prefetch) + + def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: + non_prefetched_chunks = [] + for chunk in chunks: + if chunk in self._async_works: + self._async_works[chunk].wait() + del self._async_works[chunk] + else: + non_prefetched_chunks.append(chunk) + return non_prefetched_chunks + + def pre_op(self, all_params): + if DEBUG: # TODO @botbw: remove + idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params)) + mx = max(idxs) + idxs = sorted(map(lambda x: x - mx, idxs)) + assert list(range(len(idxs))) == idxs, f"{idxs=}" + + # deal with current needed chunks + params = [p for p in all_params if not is_ddp_ignored(p)] + all_chunks = self._chunk_manager.get_chunks(params) + chunks_wo_work = self.wait_chunks(all_chunks) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks) - for chunk in chunks: + self._gemini_manager.adjust_layout(chunks_wo_work) + + # deal with chunks that are to be async fetched + prefetch_chunks = self.get_prefetch_chunks(all_params) + + # deal with chunks that are to be fetched now + for chunk in chunks_wo_work: self._chunk_manager.access_chunk(chunk) + # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? + for chunk in prefetch_chunks: + self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True) + # record cuda model data of the current OP self._gemini_manager.record_model_data_volume() From b2e97458883c64e4f357059f585ff2585fa12edd Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 04:45:06 +0000 Subject: [PATCH 2/4] [chore] sync --- colossalai/booster/plugin/gemini_plugin.py | 2 + colossalai/zero/gemini/gemini_hook.py | 57 ++++++++++++++-------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 964cd302a..b1f8ea24a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase): chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch:int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase): memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], master_weights=master_weights, + max_prefetch=max_prefetch, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 7f75f2471..01d9c9d07 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,4 +1,3 @@ -from chunk import Chunk from contextlib import contextmanager from enum import Enum from functools import partial @@ -13,15 +12,16 @@ from colossalai.utils import is_ddp_ignored from 
colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator +from colossalai.logging import DistributedLogger +from .chunk import Chunk class TrainingPhase(Enum): FORWARD = 0 BACKWARD = 1 -DEBUG = True # TODO @botbw: remove - +logger = DistributedLogger("gemini_hook") class GeminiZeROHook(ColoParamOpHook): def __init__( @@ -31,16 +31,14 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - self._cur_param = None # param_visited_order might be updated somewhere else self._param_visited_order = param_order.param_visited_order self._max_prefetch = max_prefetch self._async_works: Dict[Chunk, dist.work] = {} - # used by get_prefetch_chunks to track current param self._cur_param_idx = 0 - def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]: + def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]: chunks_to_prefetch = set() if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase self._cur_param_idx += len(all_params) # need to update first @@ -52,10 +50,10 @@ class GeminiZeROHook(ColoParamOpHook): idx += 1 continue chunk = self._chunk_manager.get_chunk(param) - chunks_to_prefetch.add(chunk) + if chunk not in cur_chunks: + chunks_to_prefetch.add(chunk) idx += 1 else: - assert self._training_phase == TrainingPhase.BACKWARD self._cur_param_idx -= len(all_params) idx = self._cur_param_idx - 1 chunks_to_prefetch = set() @@ -65,14 +63,17 @@ class GeminiZeROHook(ColoParamOpHook): idx -= 1 continue chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) - chunks_to_prefetch.add(chunk) + if chunk not in cur_chunks: + chunks_to_prefetch.add(chunk) idx -= 1 + print(f"cur id {self._cur_param_idx}") return list(chunks_to_prefetch) def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: non_prefetched_chunks = [] for chunk in chunks: if chunk in self._async_works: + print(f"prefetched {chunk.count_id}") self._async_works[chunk].wait() del self._async_works[chunk] else: @@ -80,31 +81,42 @@ class GeminiZeROHook(ColoParamOpHook): return non_prefetched_chunks def pre_op(self, all_params): - if DEBUG: # TODO @botbw: remove - idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params)) - mx = max(idxs) - idxs = sorted(map(lambda x: x - mx, idxs)) - assert list(range(len(idxs))) == idxs, f"{idxs=}" + # def find_idx(param): + # for i, p in enumerate(self._param_visited_order): + # if param is p: + # return i + # assert False + + # idxs = [find_idx(p) for p in all_params] + # max_id = min(idxs) + # idxs = [i - max_id for i in idxs] + # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}' # deal with current needed chunks params = [p for p in all_params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) - chunks_wo_work = self.wait_chunks(all_chunks) + chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks)) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks_wo_work) + self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync) # deal with chunks that are to be async fetched - prefetch_chunks = self.get_prefetch_chunks(all_params) + chunks_can_be_fetch_async = 
self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync) + print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}") # deal with chunks that are to be fetched now - for chunk in chunks_wo_work: + for chunk in chunks_need_to_fetch_sync: self._chunk_manager.access_chunk(chunk) # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? - for chunk in prefetch_chunks: - self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True) + for chunk in chunks_can_be_fetch_async: + if chunk in self._async_works: + continue + maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) + if maybe_work is not None: + print(f"prefetch {chunk.count_id}") + self._async_works[chunk] = maybe_work # record cuda model data of the current OP self._gemini_manager.record_model_data_volume() @@ -133,6 +145,11 @@ class GeminiZeROHook(ColoParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + if training_phase == TrainingPhase.FORWARD: + self._cur_param_idx = 0 + else: + self._cur_param_idx = len(self._param_visited_order) - 1 + old_training_phase = self._training_phase try: self._training_phase = training_phase From 4148ceed9f17446a6c247b49c33805b5abd17984 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 13:17:26 +0800 Subject: [PATCH 3/4] [gemini] use compute_chunk to find next chunk --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_ddp.py | 3 +- colossalai/zero/gemini/gemini_hook.py | 85 ++++------------------ colossalai/zero/gemini/gemini_mgr.py | 22 ++++-- colossalai/zero/gemini/placement_policy.py | 19 +++++ 5 files changed, 52 insertions(+), 79 deletions(-) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 9cee5223e..c7bdd5e1f 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -114,7 +114,7 @@ class ChunkManager: def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: - return + return None self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": chunk.shard_move(get_accelerator().get_current_device()) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 21448bdae..b75f69a3b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper): steady_cuda_cap_ratio=steady_cuda_cap_ratio, ) self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() @@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper): for p in module.parameters(): param_order.append(p) - self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch) - for name, param in module.named_parameters(): self.param2name[param] = name for m_name, m_var in module.named_modules(): diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 01d9c9d07..82d890975 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ 
b/colossalai/zero/gemini/gemini_hook.py @@ -6,16 +6,15 @@ from typing import Dict, List import torch import torch.distributed as dist -from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager -from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator -from colossalai.logging import DistributedLogger from .chunk import Chunk + class TrainingPhase(Enum): FORWARD = 0 BACKWARD = 1 @@ -23,51 +22,15 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") + class GeminiZeROHook(ColoParamOpHook): - def __init__( - self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0 - ) -> None: + def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - # param_visited_order might be updated somewhere else - self._param_visited_order = param_order.param_visited_order self._max_prefetch = max_prefetch self._async_works: Dict[Chunk, dist.work] = {} - # used by get_prefetch_chunks to track current param - self._cur_param_idx = 0 - - def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]: - chunks_to_prefetch = set() - if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase - self._cur_param_idx += len(all_params) # need to update first - idx = self._cur_param_idx + 1 - # still have params and prefetched chunks don't exceed the limit - while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch: - param = self._param_visited_order[idx] - if is_ddp_ignored(param): - idx += 1 - continue - chunk = self._chunk_manager.get_chunk(param) - if chunk not in cur_chunks: - chunks_to_prefetch.add(chunk) - idx += 1 - else: - self._cur_param_idx -= len(all_params) - idx = self._cur_param_idx - 1 - chunks_to_prefetch = set() - while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch: - param = self._param_visited_order[idx] - if is_ddp_ignored(param): - idx -= 1 - continue - chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) - if chunk not in cur_chunks: - chunks_to_prefetch.add(chunk) - idx -= 1 - print(f"cur id {self._cur_param_idx}") - return list(chunks_to_prefetch) def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: non_prefetched_chunks = [] @@ -80,45 +43,25 @@ class GeminiZeROHook(ColoParamOpHook): non_prefetched_chunks.append(chunk) return non_prefetched_chunks - def pre_op(self, all_params): - # def find_idx(param): - # for i, p in enumerate(self._param_visited_order): - # if param is p: - # return i - # assert False - - # idxs = [find_idx(p) for p in all_params] - # max_id = min(idxs) - # idxs = [i - max_id for i in idxs] - # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}' - - # deal with current needed chunks - params = [p for p in all_params if not is_ddp_ignored(p)] + def pre_op(self, params): + params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) - chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks)) + # wait for prefetched chunks, filter those are not prefetched + chunks_fetch_sync = 
tuple(self.wait_chunks(all_chunks)) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync) - - # deal with chunks that are to be async fetched - chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync) - - print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}") - # deal with chunks that are to be fetched now - for chunk in chunks_need_to_fetch_sync: + self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0) + # fetch the rest chunks synchronously + for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) - - # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? - for chunk in chunks_can_be_fetch_async: - if chunk in self._async_works: - continue + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch) + for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) if maybe_work is not None: - print(f"prefetch {chunk.count_id}") self._async_works[chunk] = maybe_work - # record cuda model data of the current OP + # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() def post_op(self, params): diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 150932e3d..0362d6523 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -6,7 +6,7 @@ import torch from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector, MemStats -from .placement_policy import PlacementPolicyFactory +from .placement_policy import PlacementPolicy, PlacementPolicyFactory class GeminiManager: @@ -91,13 +91,13 @@ class GeminiManager: self._warmup = False self.reset_attributes() - def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: + def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. 
""" # find stateful tensor in state COMPUTE start = time() - self._record_chunks_order(chunks) + self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start @@ -133,9 +133,9 @@ class GeminiManager: can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() return cuda_demand, can_evict_chunks - def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: + def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - if self._warmup and self._placement_policy.need_mem_stats: + if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) def sample_overall_data(self): @@ -156,6 +156,18 @@ class GeminiManager: return self._mem_stats_collector.cuda_margin_mem return None + @property + def compute_list(self) -> List[Tuple[Chunk, ...]]: + return self._compute_list + + @property + def compute_idx(self) -> int: + return self._compute_idx + + @property + def placement_policy(self) -> PlacementPolicy: + return self._placement_policy + @property def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 388999549..452687b7d 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -33,6 +33,10 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError + @abstractmethod + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + raise NotImplementedError + class StaticPlacementPolicy(PlacementPolicy): def __init__( @@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + prefetch = [] + for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)): + for chunk in self.chunk_manager.compute_list[i]: + if len(prefetch) >= max_prefetch: + break + if chunk not in prefetch: + prefetch.append(chunk) + if len(prefetch) >= max_prefetch: + break + return prefetch + class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True @@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + return [] # TODO @botbw: implement prefetching for auto + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { From 5bedea6e10db044c8395c8410d0fcab9f4fbb3d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 May 2024 05:20:00 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/gemini_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index b1f8ea24a..aeef14487 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -329,7 +329,7 @@ class GeminiPlugin(DPPluginBase): chunk_init_device: Optional[torch.device] = None, 
placement_policy: str = "static", enable_gradient_accumulation: bool = False, - max_prefetch:int = 0, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement
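The core mechanism introduced in PATCH 1 is that Chunk.__gather can now launch its all-gather with async_op=True and hand the resulting torch.distributed Work handle back through access_chunk, while GeminiZeROHook.wait_chunks calls .wait() on that handle only when the chunk is actually needed for compute. The standalone sketch below shows just that communication pattern; it uses the gloo backend with a single process purely so it can run on one CPU without a launcher, and the tensor names (and the assumption that port 29500 is free) are illustrative, not taken from the patch.

import os
import torch
import torch.distributed as dist

# Single-process process group so the sketch runs without torchrun; Gemini
# itself issues the same collective over the chunk's real process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

shard = torch.arange(4, dtype=torch.float32)  # stands in for Chunk.cuda_shard
gather_list = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

# "Prefetch": launch the collective without blocking and keep the Work handle,
# mirroring __gather(async_op=True) returning dist.Work.
work = dist.all_gather(gather_list, shard, async_op=True)

# ... compute for the current op can overlap with the communication here ...

# "wait_chunks": block on the handle right before the gathered chunk is used.
work.wait()
full_chunk = torch.cat(gather_list)
print(full_chunk)

dist.destroy_process_group()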
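PATCH 3 drops the parameter-order bookkeeping of PATCH 1 and instead selects prefetch candidates inside the placement policy: StaticPlacementPolicy.get_prefetch_chunks walks the chunk compute order recorded during the warmup iteration, starting right after the current compute index, and collects up to max_prefetch distinct chunks. The toy below reproduces only that traversal, with plain strings standing in for Chunk objects, so the selection logic can be run and inspected without any of the Gemini machinery.

from typing import List, Tuple

def get_prefetch_chunks(compute_list: List[Tuple[str, ...]], compute_idx: int, max_prefetch: int) -> List[str]:
    """Collect up to `max_prefetch` distinct chunks used by ops after the current one."""
    prefetch: List[str] = []
    for i in range(compute_idx + 1, len(compute_list)):
        for chunk in compute_list[i]:
            if len(prefetch) >= max_prefetch:
                break
            if chunk not in prefetch:
                prefetch.append(chunk)
        if len(prefetch) >= max_prefetch:
            break
    return prefetch

# Ops 0..3 touch these chunks; we are at op 1, so chunks "B" and "C" get prefetched.
compute_list = [("A",), ("B",), ("B", "C"), ("D",)]
print(get_prefetch_chunks(compute_list, compute_idx=1, max_prefetch=2))  # ['B', 'C']

Because ChunkManager.access_chunk now returns None early for chunks that are already accessible, prefetching a chunk the current op also needs is harmless: pre_op only registers an async work when access_chunk actually returns a handle.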
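From the user's side, the feature is exposed through the single max_prefetch argument that PATCH 2 adds to GeminiPlugin (PATCH 4 only reformats its annotation). A minimal usage sketch follows, assuming the usual ColossalAI booster workflow and an already-initialized distributed environment; the model and optimizer are placeholders, not part of the patches.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# max_prefetch > 0 overlaps the all-gather of upcoming chunks with compute;
# 0 (the default) keeps the previous, fully synchronous behaviour.
plugin = GeminiPlugin(placement_policy="static", max_prefetch=2)
booster = Booster(plugin=plugin)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10))
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

Note that after PATCH 3 only the static placement policy returns prefetch candidates; AutoPlacementPolicy.get_prefetch_chunks is still a stub that returns an empty list, as its TODO indicates.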