From 6e38eafebec514cd670758a0fb06b37bd01d224e Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 15 May 2024 16:51:44 +0800 Subject: [PATCH 01/36] [gemini] prefetch chunks --- colossalai/zero/gemini/chunk/chunk.py | 12 ++-- colossalai/zero/gemini/chunk/manager.py | 10 +-- colossalai/zero/gemini/gemini_ddp.py | 4 +- colossalai/zero/gemini/gemini_hook.py | 87 +++++++++++++++++++++++-- 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index cad2622f2..299ea0518 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -357,14 +357,14 @@ class Chunk: else: raise NotImplementedError - def access_chunk(self): + def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None - if not self.is_gathered: - self.__gather() + return self.__gather(async_op=async_access) self.__update_tensors_ptr() + return None def release_chunk(self): """Release the usable chunk. It's an operation done in CUDA.""" @@ -498,17 +498,19 @@ class Chunk: def get_tensors(self) -> List[torch.Tensor]: return list(self.tensors_info.keys()) - def __gather(self): + def __gather(self, async_op: bool = False) -> Optional[dist.Work]: if not self.is_gathered: # sanity check assert self.cuda_shard is not None alloc_storage(self.cuda_global_chunk) gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op) self.cuda_shard = None self.is_gathered = True + return work + return None def __scatter(self): if self.keep_gathered: diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 333a3f224..9cee5223e 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -111,15 +111,16 @@ class ChunkManager: for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) - def access_chunk(self, chunk: Chunk) -> None: + def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": chunk.shard_move(get_accelerator().get_current_device()) - self.__add_accessed_chunk(chunk) + maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access) self.__add_memory_usage(chunk.memory_usage) + return maybe_work def release_chunk(self, chunk: Chunk) -> None: """Scatter the chunk in CUDA.""" @@ -251,10 +252,11 @@ class ChunkManager: for k, v in usage.items(): self.total_mem[k] += v - def __add_accessed_chunk(self, chunk: Chunk): - chunk.access_chunk() + def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: + maybe_work = chunk.access_chunk(async_access=async_access) self.accessed_chunks.add(chunk) self.accessed_mem += chunk.chunk_mem + return maybe_work def __sub_accessed_chunk(self, chunk: Chunk): chunk.release_chunk() diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c1029097a..21448bdae 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper): chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper): steady_cuda_cap_ratio=steady_cuda_cap_ratio, ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() @@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper): for p in module.parameters(): param_order.append(p) + self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch) + for name, param in module.named_parameters(): self.param2name[param] = name for m_name, m_var in module.named_modules(): diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 480a14511..7f75f2471 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,14 +1,18 @@ +from chunk import Chunk from contextlib import contextmanager from enum import Enum from functools import partial -from typing import List +from typing import Dict, List import torch +import torch.distributed as dist +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator class TrainingPhase(Enum): @@ -16,23 +20,92 @@ class TrainingPhase(Enum): BACKWARD = 1 +DEBUG = True # TODO @botbw: remove + + class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: + def __init__( + self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0 + ) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD + self._cur_param = None + # param_visited_order might be updated somewhere else + self._param_visited_order = param_order.param_visited_order + self._max_prefetch = max_prefetch + self._async_works: Dict[Chunk, dist.work] = {} - def pre_op(self, params): - params = [p for p in params if not is_ddp_ignored(p)] - chunks = self._chunk_manager.get_chunks(params) + # used by get_prefetch_chunks to track current param + self._cur_param_idx = 0 + + def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]: + chunks_to_prefetch = set() + if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase + self._cur_param_idx += len(all_params) # need to update first + idx = self._cur_param_idx + 1 + # still have params and prefetched chunks don't exceed the limit + while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch: + param = self._param_visited_order[idx] + if is_ddp_ignored(param): + idx += 1 + continue + chunk = self._chunk_manager.get_chunk(param) + chunks_to_prefetch.add(chunk) + idx += 1 + else: + assert self._training_phase == TrainingPhase.BACKWARD + self._cur_param_idx -= len(all_params) + idx = self._cur_param_idx - 1 + chunks_to_prefetch = set() + while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch: + param = self._param_visited_order[idx] + if is_ddp_ignored(param): + idx -= 1 + continue + chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) + chunks_to_prefetch.add(chunk) + idx -= 1 + return list(chunks_to_prefetch) + + def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: + non_prefetched_chunks = [] + for chunk in chunks: + if chunk in self._async_works: + self._async_works[chunk].wait() + del self._async_works[chunk] + else: + non_prefetched_chunks.append(chunk) + return non_prefetched_chunks + + def pre_op(self, all_params): + if DEBUG: # TODO @botbw: remove + idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params)) + mx = max(idxs) + idxs = sorted(map(lambda x: x - mx, idxs)) + assert list(range(len(idxs))) == idxs, f"{idxs=}" + + # deal with current needed chunks + params = [p for p in all_params if not is_ddp_ignored(p)] + all_chunks = self._chunk_manager.get_chunks(params) + chunks_wo_work = self.wait_chunks(all_chunks) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks) - for chunk in chunks: + self._gemini_manager.adjust_layout(chunks_wo_work) + + # deal with chunks that are to be async fetched + prefetch_chunks = self.get_prefetch_chunks(all_params) + + # deal with chunks that are to be fetched now + for chunk in chunks_wo_work: self._chunk_manager.access_chunk(chunk) + # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? + for chunk in prefetch_chunks: + self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True) + # record cuda model data of the current OP self._gemini_manager.record_model_data_volume() From b2e97458883c64e4f357059f585ff2585fa12edd Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 04:45:06 +0000 Subject: [PATCH 02/36] [chore] sync --- colossalai/booster/plugin/gemini_plugin.py | 2 + colossalai/zero/gemini/gemini_hook.py | 57 ++++++++++++++-------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 964cd302a..b1f8ea24a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase): chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", enable_gradient_accumulation: bool = False, + max_prefetch:int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase): memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], master_weights=master_weights, + max_prefetch=max_prefetch, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 7f75f2471..01d9c9d07 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,4 +1,3 @@ -from chunk import Chunk from contextlib import contextmanager from enum import Enum from functools import partial @@ -13,15 +12,16 @@ from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator +from colossalai.logging import DistributedLogger +from .chunk import Chunk class TrainingPhase(Enum): FORWARD = 0 BACKWARD = 1 -DEBUG = True # TODO @botbw: remove - +logger = DistributedLogger("gemini_hook") class GeminiZeROHook(ColoParamOpHook): def __init__( @@ -31,16 +31,14 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - self._cur_param = None # param_visited_order might be updated somewhere else self._param_visited_order = param_order.param_visited_order self._max_prefetch = max_prefetch self._async_works: Dict[Chunk, dist.work] = {} - # used by get_prefetch_chunks to track current param self._cur_param_idx = 0 - def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]: + def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]: chunks_to_prefetch = set() if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase self._cur_param_idx += len(all_params) # need to update first @@ -52,10 +50,10 @@ class GeminiZeROHook(ColoParamOpHook): idx += 1 continue chunk = self._chunk_manager.get_chunk(param) - chunks_to_prefetch.add(chunk) + if chunk not in cur_chunks: + chunks_to_prefetch.add(chunk) idx += 1 else: - assert self._training_phase == TrainingPhase.BACKWARD self._cur_param_idx -= len(all_params) idx = self._cur_param_idx - 1 chunks_to_prefetch = set() @@ -65,14 +63,17 @@ class GeminiZeROHook(ColoParamOpHook): idx -= 1 continue chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) - chunks_to_prefetch.add(chunk) + if chunk not in cur_chunks: + chunks_to_prefetch.add(chunk) idx -= 1 + print(f"cur id {self._cur_param_idx}") return list(chunks_to_prefetch) def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: non_prefetched_chunks = [] for chunk in chunks: if chunk in self._async_works: + print(f"prefetched {chunk.count_id}") self._async_works[chunk].wait() del self._async_works[chunk] else: @@ -80,31 +81,42 @@ class GeminiZeROHook(ColoParamOpHook): return non_prefetched_chunks def pre_op(self, all_params): - if DEBUG: # TODO @botbw: remove - idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params)) - mx = max(idxs) - idxs = sorted(map(lambda x: x - mx, idxs)) - assert list(range(len(idxs))) == idxs, f"{idxs=}" + # def find_idx(param): + # for i, p in enumerate(self._param_visited_order): + # if param is p: + # return i + # assert False + + # idxs = [find_idx(p) for p in all_params] + # max_id = min(idxs) + # idxs = [i - max_id for i in idxs] + # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}' # deal with current needed chunks params = [p for p in all_params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) - chunks_wo_work = self.wait_chunks(all_chunks) + chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks)) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks_wo_work) + self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync) # deal with chunks that are to be async fetched - prefetch_chunks = self.get_prefetch_chunks(all_params) + chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync) + print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}") # deal with chunks that are to be fetched now - for chunk in chunks_wo_work: + for chunk in chunks_need_to_fetch_sync: self._chunk_manager.access_chunk(chunk) # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? - for chunk in prefetch_chunks: - self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True) + for chunk in chunks_can_be_fetch_async: + if chunk in self._async_works: + continue + maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) + if maybe_work is not None: + print(f"prefetch {chunk.count_id}") + self._async_works[chunk] = maybe_work # record cuda model data of the current OP self._gemini_manager.record_model_data_volume() @@ -133,6 +145,11 @@ class GeminiZeROHook(ColoParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + if training_phase == TrainingPhase.FORWARD: + self._cur_param_idx = 0 + else: + self._cur_param_idx = len(self._param_visited_order) - 1 + old_training_phase = self._training_phase try: self._training_phase = training_phase From 4148ceed9f17446a6c247b49c33805b5abd17984 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 13:17:26 +0800 Subject: [PATCH 03/36] [gemini] use compute_chunk to find next chunk --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_ddp.py | 3 +- colossalai/zero/gemini/gemini_hook.py | 85 ++++------------------ colossalai/zero/gemini/gemini_mgr.py | 22 ++++-- colossalai/zero/gemini/placement_policy.py | 19 +++++ 5 files changed, 52 insertions(+), 79 deletions(-) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 9cee5223e..c7bdd5e1f 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -114,7 +114,7 @@ class ChunkManager: def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]: """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: - return + return None self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": chunk.shard_move(get_accelerator().get_current_device()) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 21448bdae..b75f69a3b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper): steady_cuda_cap_ratio=steady_cuda_cap_ratio, ) self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() @@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper): for p in module.parameters(): param_order.append(p) - self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch) - for name, param in module.named_parameters(): self.param2name[param] = name for m_name, m_var in module.named_modules(): diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 01d9c9d07..82d890975 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -6,16 +6,15 @@ from typing import Dict, List import torch import torch.distributed as dist -from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager -from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator -from colossalai.logging import DistributedLogger from .chunk import Chunk + class TrainingPhase(Enum): FORWARD = 0 BACKWARD = 1 @@ -23,51 +22,15 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") + class GeminiZeROHook(ColoParamOpHook): - def __init__( - self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0 - ) -> None: + def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - # param_visited_order might be updated somewhere else - self._param_visited_order = param_order.param_visited_order self._max_prefetch = max_prefetch self._async_works: Dict[Chunk, dist.work] = {} - # used by get_prefetch_chunks to track current param - self._cur_param_idx = 0 - - def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]: - chunks_to_prefetch = set() - if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase - self._cur_param_idx += len(all_params) # need to update first - idx = self._cur_param_idx + 1 - # still have params and prefetched chunks don't exceed the limit - while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch: - param = self._param_visited_order[idx] - if is_ddp_ignored(param): - idx += 1 - continue - chunk = self._chunk_manager.get_chunk(param) - if chunk not in cur_chunks: - chunks_to_prefetch.add(chunk) - idx += 1 - else: - self._cur_param_idx -= len(all_params) - idx = self._cur_param_idx - 1 - chunks_to_prefetch = set() - while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch: - param = self._param_visited_order[idx] - if is_ddp_ignored(param): - idx -= 1 - continue - chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx]) - if chunk not in cur_chunks: - chunks_to_prefetch.add(chunk) - idx -= 1 - print(f"cur id {self._cur_param_idx}") - return list(chunks_to_prefetch) def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: non_prefetched_chunks = [] @@ -80,45 +43,25 @@ class GeminiZeROHook(ColoParamOpHook): non_prefetched_chunks.append(chunk) return non_prefetched_chunks - def pre_op(self, all_params): - # def find_idx(param): - # for i, p in enumerate(self._param_visited_order): - # if param is p: - # return i - # assert False - - # idxs = [find_idx(p) for p in all_params] - # max_id = min(idxs) - # idxs = [i - max_id for i in idxs] - # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}' - - # deal with current needed chunks - params = [p for p in all_params if not is_ddp_ignored(p)] + def pre_op(self, params): + params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) - chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks)) + # wait for prefetched chunks, filter those are not prefetched + chunks_fetch_sync = tuple(self.wait_chunks(all_chunks)) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync) - - # deal with chunks that are to be async fetched - chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync) - - print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}") - # deal with chunks that are to be fetched now - for chunk in chunks_need_to_fetch_sync: + self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0) + # fetch the rest chunks synchronously + for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) - - # deal with chunks that are to be pre fetched TODO @botbw: the order here matters? - for chunk in chunks_can_be_fetch_async: - if chunk in self._async_works: - continue + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch) + for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) if maybe_work is not None: - print(f"prefetch {chunk.count_id}") self._async_works[chunk] = maybe_work - # record cuda model data of the current OP + # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() def post_op(self, params): diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 150932e3d..0362d6523 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -6,7 +6,7 @@ import torch from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector, MemStats -from .placement_policy import PlacementPolicyFactory +from .placement_policy import PlacementPolicy, PlacementPolicyFactory class GeminiManager: @@ -91,13 +91,13 @@ class GeminiManager: self._warmup = False self.reset_attributes() - def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: + def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE start = time() - self._record_chunks_order(chunks) + self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start @@ -133,9 +133,9 @@ class GeminiManager: can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() return cuda_demand, can_evict_chunks - def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: + def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - if self._warmup and self._placement_policy.need_mem_stats: + if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) def sample_overall_data(self): @@ -156,6 +156,18 @@ class GeminiManager: return self._mem_stats_collector.cuda_margin_mem return None + @property + def compute_list(self) -> List[Tuple[Chunk, ...]]: + return self._compute_list + + @property + def compute_idx(self) -> int: + return self._compute_idx + + @property + def placement_policy(self) -> PlacementPolicy: + return self._placement_policy + @property def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 388999549..452687b7d 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -33,6 +33,10 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError + @abstractmethod + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + raise NotImplementedError + class StaticPlacementPolicy(PlacementPolicy): def __init__( @@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + prefetch = [] + for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)): + for chunk in self.chunk_manager.compute_list[i]: + if len(prefetch) >= max_prefetch: + break + if chunk not in prefetch: + prefetch.append(chunk) + if len(prefetch) >= max_prefetch: + break + return prefetch + class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True @@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") + def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + return [] # TODO @botbw: implement prefetching for auto + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { From 5bedea6e10db044c8395c8410d0fcab9f4fbb3d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 May 2024 05:20:00 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/gemini_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index b1f8ea24a..aeef14487 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -329,7 +329,7 @@ class GeminiPlugin(DPPluginBase): chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", enable_gradient_accumulation: bool = False, - max_prefetch:int = 0, + max_prefetch: int = 0, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement From 2e68eebdfe9a8f68c0cb9a260ce7f2de4e6b6e27 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 07:22:10 +0000 Subject: [PATCH 05/36] [chore] refactor & sync --- colossalai/zero/gemini/chunk/chunk.py | 1 + colossalai/zero/gemini/gemini_ddp.py | 3 +- colossalai/zero/gemini/gemini_hook.py | 55 ++++++++++--------- colossalai/zero/gemini/gemini_mgr.py | 27 +++++++-- colossalai/zero/gemini/placement_policy.py | 31 +++++++---- examples/language/gpt/gemini/commons/utils.py | 5 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- 7 files changed, 82 insertions(+), 46 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 299ea0518..81b6192ad 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -567,6 +567,7 @@ class Chunk: return self is __o def __repr__(self, detailed: bool = True): + return f"Chunk({self.count_id})" output = [ "Chunk Information:\n", "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index b75f69a3b..6bf0b4019 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -131,9 +131,10 @@ class GeminiDDP(ModelWrapper): offload_param_frac=offload_param_frac, warmup_non_model_data_ratio=warmup_non_model_data_ratio, steady_cuda_cap_ratio=steady_cuda_cap_ratio, + max_prefetch=max_prefetch ) self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch) + self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 82d890975..1d734bd83 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Dict, List +from typing import Dict, List, Iterable, Tuple import torch import torch.distributed as dist @@ -22,45 +22,55 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") +import os +rank = int(os.environ['RANK']) class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None: + def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() self._gemini_manager = gemini_manager self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - self._max_prefetch = max_prefetch - self._async_works: Dict[Chunk, dist.work] = {} - def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]: - non_prefetched_chunks = [] - for chunk in chunks: - if chunk in self._async_works: - print(f"prefetched {chunk.count_id}") - self._async_works[chunk].wait() - del self._async_works[chunk] - else: - non_prefetched_chunks.append(chunk) - return non_prefetched_chunks def pre_op(self, params): + # map params to chunks params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) + # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = tuple(self.wait_chunks(all_chunks)) + unique_chunks = set(all_chunks) + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) + + # transfer state for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() - self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0) - # fetch the rest chunks synchronously + + # evit chunks, aware of async fetched + self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0) + + # fetch the rest synchronously for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) - chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch) + + # get possible chunks to prefetch + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() + if rank == 0 and not self._gemini_manager.is_warmup(): + print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}") + print(f"{all_chunks=}") + print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") + print(f"{chunks_fetch_sync=}") + print(f"{chunks_fetch_async=}") + print(f"works={list(self._gemini_manager._async_works.keys())}") + + # prefetch for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) if maybe_work is not None: - self._async_works[chunk] = maybe_work - + self._gemini_manager.add_work(chunk, maybe_work) + if rank == 0 and not self._gemini_manager.is_warmup(): + print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}") # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() @@ -88,11 +98,6 @@ class GeminiZeROHook(ColoParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): - if training_phase == TrainingPhase.FORWARD: - self._cur_param_idx = 0 - else: - self._cur_param_idx = len(self._param_visited_order) - 1 - old_training_phase = self._training_phase try: self._training_phase = training_phase diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 0362d6523..6640bf03b 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -1,8 +1,9 @@ import functools from time import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Iterable import torch +import torch.distributed as dist from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector, MemStats @@ -41,9 +42,10 @@ class GeminiManager: self._mem_stats_collector = ( ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None ) - self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) + self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 + self._async_works: Dict[Chunk, dist.work] = {} self._h2d_volume = 0 self._d2h_volume = 0 @@ -98,11 +100,13 @@ class GeminiManager: # find stateful tensor in state COMPUTE start = time() self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) - cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) + cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks) + # don't evict chunks that are asynchronously fetched + can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works] self._layout_time += time() - start vol, evict_time = self._placement_policy.evict_tensors( - can_evict_chunks=hold_cuda_tensor_list, + can_evict_chunks=can_evict_chunks, cuda_demand=cuda_demand, warmup=self._warmup, compute_list=self._compute_list, @@ -114,6 +118,21 @@ class GeminiManager: # move COMPUTE tensors to CUDA self._h2d_volume += cuda_demand + def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]: + non_prefetched_chunks = [] + for chunk in chunks: + if chunk in self._async_works: + self._async_works[chunk].wait() + del self._async_works[chunk] + else: + non_prefetched_chunks.append(chunk) + return tuple(non_prefetched_chunks) + + def add_work(self, chunk: Chunk, work: dist.Work): + assert work is not None + assert chunk not in self._async_works + self._async_works[chunk] = work + @functools.lru_cache(maxsize=None) def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]): start = time() diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 452687b7d..aad97321c 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -13,15 +13,16 @@ from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector - class PlacementPolicy(ABC): need_mem_stats: bool = False def __init__( - self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs + self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs ) -> None: + self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector + self.max_prefetch = max_prefetch @abstractmethod def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: @@ -34,21 +35,25 @@ class PlacementPolicy(ABC): raise NotImplementedError @abstractmethod - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + def get_prefetch_chunks(self) -> List[Chunk]: raise NotImplementedError +import os +rank = int(os.environ["RANK"]) class StaticPlacementPolicy(PlacementPolicy): def __init__( self, + gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, shard_param_frac: float = 1.0, offload_optim_frac: float = 0.0, offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -99,15 +104,17 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: + def get_prefetch_chunks(self) -> List[Chunk]: + if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + return [] prefetch = [] - for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)): - for chunk in self.chunk_manager.compute_list[i]: - if len(prefetch) >= max_prefetch: + for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + for chunk in self.gemini_manager.compute_list[i]: + if len(prefetch) >= self.max_prefetch: break - if chunk not in prefetch: + if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= max_prefetch: + if len(prefetch) >= self.max_prefetch: break return prefetch @@ -117,13 +124,15 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, + gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, warmup_non_model_data_ratio: float = 0.8, steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 7ed5fdb92..03054c0a2 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True, + # record_shapes=True, + # profile_memory=True, + with_stack=True, ) else: return nullcontext(DummyProfiler()) diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 4911ff124..6db74231a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -129,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default + PROF_FLAG = True # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch() @@ -166,7 +166,7 @@ def main(): stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True ) elif args.distplan == "CAI_Gemini": - plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1) else: raise RuntimeError @@ -248,7 +248,7 @@ def main(): prof.step() tflops_list.sort() - median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1) logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() From 6bbe956316d62202f1b19ea1033419d6a3ee91ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 May 2024 07:26:19 +0000 Subject: [PATCH 06/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 20 ++++++++++--------- colossalai/zero/gemini/gemini_mgr.py | 4 ++-- colossalai/zero/gemini/placement_policy.py | 23 +++++++++++++++++----- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 6bf0b4019..3a0ae59fc 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -131,7 +131,7 @@ class GeminiDDP(ModelWrapper): offload_param_frac=offload_param_frac, warmup_non_model_data_ratio=warmup_non_model_data_ratio, steady_cuda_cap_ratio=steady_cuda_cap_ratio, - max_prefetch=max_prefetch + max_prefetch=max_prefetch, ) self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = GeminiZeROHook(self.gemini_manager) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 1d734bd83..e6b8cf8ef 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -1,10 +1,9 @@ from contextlib import contextmanager from enum import Enum from functools import partial -from typing import Dict, List, Iterable, Tuple +from typing import List import torch -import torch.distributed as dist from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook @@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.gemini_mgr import GeminiManager -from .chunk import Chunk - class TrainingPhase(Enum): FORWARD = 0 @@ -23,7 +20,9 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") import os -rank = int(os.environ['RANK']) + +rank = int(os.environ["RANK"]) + class GeminiZeROHook(ColoParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: @@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook): self._chunk_manager = gemini_manager.chunk_manager self._training_phase = TrainingPhase.FORWARD - def pre_op(self, params): # map params to chunks params = [p for p in params if not is_ddp_ignored(p)] all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - unique_chunks = set(all_chunks) + set(all_chunks) chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # transfer state @@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0) + self._gemini_manager.adjust_layout( + all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 + ) # fetch the rest synchronously for chunk in chunks_fetch_sync: @@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook): # get possible chunks to prefetch chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() if rank == 0 and not self._gemini_manager.is_warmup(): - print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}") + print( + f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}" + ) print(f"{all_chunks=}") print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") print(f"{chunks_fetch_sync=}") diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 6640bf03b..11bde789c 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -1,6 +1,6 @@ import functools from time import time -from typing import Dict, List, Optional, Tuple, Iterable +from typing import Dict, Iterable, List, Optional, Tuple import torch import torch.distributed as dist @@ -101,7 +101,7 @@ class GeminiManager: start = time() self._record_warmup_chunks_order(chunks, record_anyway=record_anyway) cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks) - # don't evict chunks that are asynchronously fetched + # don't evict chunks that are asynchronously fetched can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works] self._layout_time += time() - start diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index aad97321c..e5f61a033 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector + class PlacementPolicy(ABC): need_mem_stats: bool = False def __init__( - self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs + self, + gemini_manager: "GeminiManager", + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + max_prefetch: int = 0, + **kwargs, ) -> None: self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager @@ -38,13 +44,16 @@ class PlacementPolicy(ABC): def get_prefetch_chunks(self) -> List[Chunk]: raise NotImplementedError + import os + rank = int(os.environ["RANK"]) + class StaticPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: 'GeminiManager', + gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy): offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) + super().__init__( + gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch + ) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: 'GeminiManager', + gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy): steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) + super().__init__( + gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch + ) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() From 5470e5f94e302b1a60ee2c1add0caf5f0e879e42 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Thu, 16 May 2024 08:03:40 +0000 Subject: [PATCH 07/36] a commit for fake push test --- tests/test_zero/test_gemini/test_optim.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index a9366e7bc..1c914ca0e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -40,9 +40,7 @@ EXAMPLE_MODELS = [ ] # bfloat16 cannot represent them exactly -BF16_IGNORED_KEYS = [ - "masked_bias", -] +BF16_IGNORED_KEYS = ["masked_bias"] def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): From f45f8a2aa74e6734f0f7ad89729d7e00b5d3d985 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 16:12:53 +0800 Subject: [PATCH 08/36] [gemini] maxprefetch means maximum work to keep --- colossalai/zero/gemini/placement_policy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index e5f61a033..c0f92fa50 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -118,14 +118,15 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] + can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: - if len(prefetch) >= self.max_prefetch: + if len(prefetch) >= can_prefetch: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= self.max_prefetch: + if len(prefetch) >= can_prefetch: break return prefetch From 20701d45330df8bd2eed3a803f58b67a990e5ece Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 16:45:50 +0800 Subject: [PATCH 09/36] [chore] remove print --- colossalai/zero/gemini/gemini_hook.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index e6b8cf8ef..d1fd0867f 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -19,10 +19,6 @@ class TrainingPhase(Enum): logger = DistributedLogger("gemini_hook") -import os - -rank = int(os.environ["RANK"]) - class GeminiZeROHook(ColoParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: @@ -56,23 +52,13 @@ class GeminiZeROHook(ColoParamOpHook): # get possible chunks to prefetch chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() - if rank == 0 and not self._gemini_manager.is_warmup(): - print( - f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}" - ) - print(f"{all_chunks=}") - print(f"accessed_chunks={self._chunk_manager.accessed_chunks}") - print(f"{chunks_fetch_sync=}") - print(f"{chunks_fetch_async=}") - print(f"works={list(self._gemini_manager._async_works.keys())}") # prefetch for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) if maybe_work is not None: self._gemini_manager.add_work(chunk, maybe_work) - if rank == 0 and not self._gemini_manager.is_warmup(): - print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}") + # record cuda model data of the current OP, including memory for prefetched chunks self._gemini_manager.record_model_data_volume() From 6efbadba25d6192ed1b54f3d50491a8b91aacc77 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 16 May 2024 16:46:39 +0800 Subject: [PATCH 10/36] [chore] remove debugging info --- colossalai/zero/gemini/chunk/chunk.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 81b6192ad..299ea0518 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -567,7 +567,6 @@ class Chunk: return self is __o def __repr__(self, detailed: bool = True): - return f"Chunk({self.count_id})" output = [ "Chunk Information:\n", "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( From 013690a86b6906d0b3eaa11173a9510856be09b3 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Thu, 16 May 2024 09:57:51 +0000 Subject: [PATCH 11/36] remove set(all_chunks) --- colossalai/zero/gemini/gemini_hook.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index d1fd0867f..27a19c132 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -33,7 +33,6 @@ class GeminiZeROHook(ColoParamOpHook): all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - set(all_chunks) chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # transfer state From e57812c6727e325971cb0d8769c0789c088f62ae Mon Sep 17 00:00:00 2001 From: botbw Date: Fri, 17 May 2024 13:42:18 +0800 Subject: [PATCH 12/36] [chore] Update placement_policy.py --- colossalai/zero/gemini/placement_policy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c0f92fa50..e9e871b46 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -45,11 +45,6 @@ class PlacementPolicy(ABC): raise NotImplementedError -import os - -rank = int(os.environ["RANK"]) - - class StaticPlacementPolicy(PlacementPolicy): def __init__( self, From 3d625ca83656793d81eaece4c49967bd4fafcf7d Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Fri, 17 May 2024 10:55:28 +0000 Subject: [PATCH 13/36] add some todo Message --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 5 ++++- colossalai/zero/gemini/gemini_mgr.py | 3 ++- colossalai/zero/gemini/placement_policy.py | 12 ++++++++---- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index c7bdd5e1f..341790a72 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -83,7 +83,7 @@ class ChunkManager: if chunk_group: # the chunk group is not empty # close the last chunk - self.__close_one_chunk(chunk_group[-1]) + self.__close_one_chunk(chunk_group[-1]) # chunk[-1] 满了,所以关闭,不能再添加,然后同时scatter到ZeRO PG中 if tensor.numel() > chunk_size: chunk_size = tensor.numel() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 27a19c132..bf990d127 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -33,19 +33,22 @@ class GeminiZeROHook(ColoParamOpHook): all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # 当前要fetch的chunk # transfer state for p in params: + # TODO(haze188): check状态转换 self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched + # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously + # TODO(haze188): 1. 先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 11bde789c..2e96c22f3 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -125,7 +125,7 @@ class GeminiManager: self._async_works[chunk].wait() del self._async_works[chunk] else: - non_prefetched_chunks.append(chunk) + non_prefetched_chunks.append(chunk) # 没在之前prefetch过,现在要prefetch的chunk return tuple(non_prefetched_chunks) def add_work(self, chunk: Chunk, work: dist.Work): @@ -154,6 +154,7 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 + # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c0f92fa50..4c3d8dbe2 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -45,9 +45,9 @@ class PlacementPolicy(ABC): raise NotImplementedError -import os - -rank = int(os.environ["RANK"]) +# import torch.distributed as dist +# # rank = int(os.environ["RANK"]) +# rank = dist.get_rank() class StaticPlacementPolicy(PlacementPolicy): @@ -118,8 +118,10 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] + # 最多有多少个异步的work can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] + # static炸就炸了,dynamic可能需要我们要先分析当前运行时的内存情况,分配空间或者淘汰块 for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: if len(prefetch) >= can_prefetch: @@ -238,7 +240,9 @@ class AutoPlacementPolicy(PlacementPolicy): grads_device_map[p] = torch.device("cpu") def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: - return [] # TODO @botbw: implement prefetching for auto + # TODO @haze188 @botbw: implement prefetching for auto + + return [] class PlacementPolicyFactory: From 06a3a100b330d10b28615af285642c6667ba8c23 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Fri, 17 May 2024 10:57:49 +0000 Subject: [PATCH 14/36] remove unrelated code --- colossalai/zero/gemini/placement_policy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 4c3d8dbe2..c0d03ba3b 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -45,11 +45,6 @@ class PlacementPolicy(ABC): raise NotImplementedError -# import torch.distributed as dist -# # rank = int(os.environ["RANK"]) -# rank = dist.get_rank() - - class StaticPlacementPolicy(PlacementPolicy): def __init__( self, From a55a9e298bad86297eca89923d24d5db9b1f0aaf Mon Sep 17 00:00:00 2001 From: hxwang Date: Mon, 20 May 2024 02:21:17 +0000 Subject: [PATCH 15/36] [gemini] init auto policy prefetch --- colossalai/zero/gemini/placement_policy.py | 37 ++++++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index e9e871b46..a48f8d0d0 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -19,7 +19,7 @@ class PlacementPolicy(ABC): def __init__( self, - gemini_manager: "GeminiManager", + gemini_manager: "GeminiManager", # TODO @botbw: solve circular import chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -40,9 +40,8 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError - @abstractmethod def get_prefetch_chunks(self) -> List[Chunk]: - raise NotImplementedError + return [] # no prefetch by default class StaticPlacementPolicy(PlacementPolicy): @@ -116,12 +115,14 @@ class StaticPlacementPolicy(PlacementPolicy): can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + break_flag = False for chunk in self.gemini_manager.compute_list[i]: if len(prefetch) >= can_prefetch: + break_flag = True break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= can_prefetch: + if break_flag: break return prefetch @@ -232,9 +233,31 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: - return [] # TODO @botbw: implement prefetching for auto - + def get_prefetch_chunks(self) -> List[Chunk]: + if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + return [] + # modified from self.evict_tensors + cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(get_accelerator().get_current_device()) + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + + prefetch_chunk_memory = 0 + can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + prefetch = [] + for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + break_flag = False + for chunk in self.gemini_manager.compute_list[i]: + chunk: Chunk + if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: + break_flag = True + break + if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: + prefetch.append(chunk) + if break_flag: + break + return prefetch class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { From f1918e18a5051113290f9702a47e11266db492f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 May 2024 03:00:06 +0000 Subject: [PATCH 16/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/zero/gemini/placement_policy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index a48f8d0d0..cae5cc202 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -237,12 +237,14 @@ class AutoPlacementPolicy(PlacementPolicy): if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] # modified from self.evict_tensors - cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(get_accelerator().get_current_device()) + cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( + get_accelerator().get_current_device() + ) max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") used_cuda_model_data = self.chunk_manager.total_mem["cuda"] total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data - + prefetch_chunk_memory = 0 can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] @@ -259,6 +261,7 @@ class AutoPlacementPolicy(PlacementPolicy): break return prefetch + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { "auto": AutoPlacementPolicy, From d22bf30ca645aac20265373fe34db281db6abb2e Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 04:01:53 +0000 Subject: [PATCH 17/36] implement auto policy prefetch and modify a little origin code. --- colossalai/zero/gemini/placement_policy.py | 36 ++++++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c0d03ba3b..9803d7f6d 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -11,6 +11,7 @@ from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager +from .gemini_mgr import GeminiManager from .memory_tracer import ChunkMemStatsCollector @@ -123,8 +124,9 @@ class StaticPlacementPolicy(PlacementPolicy): break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: prefetch.append(chunk) - if len(prefetch) >= can_prefetch: - break + else: + continue + break return prefetch @@ -133,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: "GeminiManager", + gemini_manager: GeminiManager, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -234,10 +236,32 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") - def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]: - # TODO @haze188 @botbw: implement prefetching for auto + def get_prefetch_chunks(self) -> List[Chunk]: + if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + return [] + # modified from self.evict_tensors + cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( + get_accelerator().get_current_device() + ) + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] + total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period + avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data - return [] + prefetch_chunk_memory = 0 + can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + prefetch = [] + for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): + for chunk in self.gemini_manager.compute_list[i]: + chunk: Chunk + if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: + break + if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: + prefetch.append(chunk) + else: + continue + break + return prefetch class PlacementPolicyFactory: From df63db7e63d951017fd1fa797fbcaec259fba644 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:15:51 +0000 Subject: [PATCH 18/36] remote comments --- colossalai/zero/gemini/gemini_hook.py | 4 +- colossalai/zero/gemini/gemini_mgr.py | 1 - examples/language/gpt/gemini/demo.ipynb | 142 ++++++++++++++++++ examples/language/gpt/gemini/run_gemini.sh | 2 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 6 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf990d127..e691b423b 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -37,18 +37,16 @@ class GeminiZeROHook(ColoParamOpHook): # transfer state for p in params: - # TODO(haze188): check状态转换 self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 + # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously - # TODO(haze188): 1. 先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 2e96c22f3..85beafd32 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -154,7 +154,6 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb new file mode 100644 index 000000000..09953b3a9 --- /dev/null +++ b/examples/language/gpt/gemini/demo.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(in_features=10, out_features=5, bias=False) 50\n", + "Linear(in_features=5, out_features=10, bias=False) 50\n", + "Linear(in_features=10, out_features=10, bias=False) 100\n" + ] + } + ], + "source": [ + "class Toy(nn.Module):\n", + " \n", + " def __init__(self):\n", + " super(Toy, self).__init__()\n", + " self.fc1 = nn.Linear(10,5, bias=False)\n", + " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", + "\n", + "t = Toy()\n", + "for mod in t.modules():\n", + " for p in mod.parameters(recurse=False):\n", + " print(mod, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 10]) 50\n", + "torch.Size([10, 5]) 50\n", + "torch.Size([10, 10]) 100\n" + ] + } + ], + "source": [ + "for p in t.parameters():\n", + " print(p.shape, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'224'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conf_str = torch.__config__.parallel_info()\n", + "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", + "max_concurrency = inter_str.split(\"\\n\")[0]\n", + "max_concurrency" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0\n", + "0 1\n", + "0 2\n", + "1 0\n", + "1 1\n", + "1 2\n" + ] + } + ], + "source": [ + "for i in range(3):\n", + " for j in range(3):\n", + " print(i, j)\n", + " if i == 1 and j == 2:break\n", + " else:\n", + " continue\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "colossalai-py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index 5eaa4af4d..bffd26f59 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-10} +export TRAIN_STEP=${TRAIN_STEP:-2} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 6db74231a..667a0c77a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 + return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 + return torch.cuda.memory_allocated() / 1024**2 # 转换成MB def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): +def get_model_size(model: nn.Module): # 得到模型参数量 total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 1c914ca0e..4e1fb988b 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - {"placement_policy": "auto"}, + # {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From 5c6c5d6be316a4f4e867d0d8049b508e0d59ad6c Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:15:51 +0000 Subject: [PATCH 19/36] remove comments --- colossalai/zero/gemini/gemini_hook.py | 4 +- colossalai/zero/gemini/gemini_mgr.py | 1 - examples/language/gpt/gemini/demo.ipynb | 142 ++++++++++++++++++ examples/language/gpt/gemini/run_gemini.sh | 2 +- .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 6 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf990d127..e691b423b 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -37,18 +37,16 @@ class GeminiZeROHook(ColoParamOpHook): # transfer state for p in params: - # TODO(haze188): check状态转换 self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO(haze188): 可能我们prefetch的又被淘汰掉, check一下 + # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) # fetch the rest synchronously - # TODO(haze188): 1. 先prefetch还是先fetch(prefetch是异步,fetch是同步) for chunk in chunks_fetch_sync: self._chunk_manager.access_chunk(chunk) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 2e96c22f3..85beafd32 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -154,7 +154,6 @@ class GeminiManager: def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None: self._compute_idx += 1 - # TODO(haze188): _compute_list 记录块的访问顺序 if self._warmup and (self._placement_policy.need_mem_stats or record_anyway): self._compute_list.append(chunks) diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb new file mode 100644 index 000000000..09953b3a9 --- /dev/null +++ b/examples/language/gpt/gemini/demo.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(in_features=10, out_features=5, bias=False) 50\n", + "Linear(in_features=5, out_features=10, bias=False) 50\n", + "Linear(in_features=10, out_features=10, bias=False) 100\n" + ] + } + ], + "source": [ + "class Toy(nn.Module):\n", + " \n", + " def __init__(self):\n", + " super(Toy, self).__init__()\n", + " self.fc1 = nn.Linear(10,5, bias=False)\n", + " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", + "\n", + "t = Toy()\n", + "for mod in t.modules():\n", + " for p in mod.parameters(recurse=False):\n", + " print(mod, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([5, 10]) 50\n", + "torch.Size([10, 5]) 50\n", + "torch.Size([10, 10]) 100\n" + ] + } + ], + "source": [ + "for p in t.parameters():\n", + " print(p.shape, p.numel())" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'224'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conf_str = torch.__config__.parallel_info()\n", + "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", + "max_concurrency = inter_str.split(\"\\n\")[0]\n", + "max_concurrency" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0\n", + "0 1\n", + "0 2\n", + "1 0\n", + "1 1\n", + "1 2\n" + ] + } + ], + "source": [ + "for i in range(3):\n", + " for j in range(3):\n", + " print(i, j)\n", + " if i == 1 and j == 2:break\n", + " else:\n", + " continue\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "colossalai-py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index 5eaa4af4d..bffd26f59 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-10} +export TRAIN_STEP=${TRAIN_STEP:-2} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 6db74231a..667a0c77a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 + return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 + return torch.cuda.memory_allocated() / 1024**2 # 转换成MB def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): +def get_model_size(model: nn.Module): # 得到模型参数量 total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 1c914ca0e..4e1fb988b 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - {"placement_policy": "auto"}, + # {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From 1ec92d29af16fcfc1b641e61eded877c5680cc47 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:21:26 +0000 Subject: [PATCH 20/36] remove perf log, unrelated file and so on --- colossalai/zero/gemini/chunk/manager.py | 2 +- colossalai/zero/gemini/gemini_hook.py | 2 +- colossalai/zero/gemini/gemini_mgr.py | 2 +- colossalai/zero/gemini/placement_policy.py | 2 - examples/language/gpt/gemini/demo.ipynb | 142 ------------------ .../language/gpt/gemini/train_gpt_demo.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 7 files changed, 7 insertions(+), 151 deletions(-) delete mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 341790a72..c7bdd5e1f 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -83,7 +83,7 @@ class ChunkManager: if chunk_group: # the chunk group is not empty # close the last chunk - self.__close_one_chunk(chunk_group[-1]) # chunk[-1] 满了,所以关闭,不能再添加,然后同时scatter到ZeRO PG中 + self.__close_one_chunk(chunk_group[-1]) if tensor.numel() > chunk_size: chunk_size = tensor.numel() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index e691b423b..450cb3ad6 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -33,7 +33,7 @@ class GeminiZeROHook(ColoParamOpHook): all_chunks = self._chunk_manager.get_chunks(params) # wait for prefetched chunks, filter those are not prefetched - chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # 当前要fetch的chunk + chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks) # transfer state for p in params: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 85beafd32..11bde789c 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -125,7 +125,7 @@ class GeminiManager: self._async_works[chunk].wait() del self._async_works[chunk] else: - non_prefetched_chunks.append(chunk) # 没在之前prefetch过,现在要prefetch的chunk + non_prefetched_chunks.append(chunk) return tuple(non_prefetched_chunks) def add_work(self, chunk: Chunk, work: dist.Work): diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 9e9fb1f58..cfbf16d1b 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -113,10 +113,8 @@ class StaticPlacementPolicy(PlacementPolicy): def get_prefetch_chunks(self) -> List[Chunk]: if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list return [] - # 最多有多少个异步的work can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) prefetch = [] - # static炸就炸了,dynamic可能需要我们要先分析当前运行时的内存情况,分配空间或者淘汰块 for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): for chunk in self.gemini_manager.compute_list[i]: if len(prefetch) >= can_prefetch: diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb deleted file mode 100644 index 09953b3a9..000000000 --- a/examples/language/gpt/gemini/demo.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(in_features=10, out_features=5, bias=False) 50\n", - "Linear(in_features=5, out_features=10, bias=False) 50\n", - "Linear(in_features=10, out_features=10, bias=False) 100\n" - ] - } - ], - "source": [ - "class Toy(nn.Module):\n", - " \n", - " def __init__(self):\n", - " super(Toy, self).__init__()\n", - " self.fc1 = nn.Linear(10,5, bias=False)\n", - " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", - "\n", - "t = Toy()\n", - "for mod in t.modules():\n", - " for p in mod.parameters(recurse=False):\n", - " print(mod, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 10]) 50\n", - "torch.Size([10, 5]) 50\n", - "torch.Size([10, 10]) 100\n" - ] - } - ], - "source": [ - "for p in t.parameters():\n", - " print(p.shape, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'224'" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "conf_str = torch.__config__.parallel_info()\n", - "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", - "max_concurrency = inter_str.split(\"\\n\")[0]\n", - "max_concurrency" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0\n", - "0 1\n", - "0 2\n", - "1 0\n", - "1 1\n", - "1 2\n" - ] - } - ], - "source": [ - "for i in range(3):\n", - " for j in range(3):\n", - " print(i, j)\n", - " if i == 1 and j == 2:break\n", - " else:\n", - " continue\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "colossalai-py310", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 667a0c77a..6db74231a 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,18 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB + return psutil.Process().memory_info().rss / 1024**2 def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # 转换成MB + return torch.cuda.memory_allocated() / 1024**2 def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): # 得到模型参数量 +def get_model_size(model: nn.Module): total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 4e1fb988b..1c914ca0e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - # {"placement_policy": "auto"}, + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From a280517dd9618247deaea729b4f1aaddbc17995c Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 05:25:35 +0000 Subject: [PATCH 21/36] remove unrelated file --- examples/language/gpt/gemini/demo.ipynb | 142 ------------------ .../language/gpt/gemini/train_gpt_demo.py | 7 +- tests/test_zero/test_gemini/test_optim.py | 2 +- 3 files changed, 5 insertions(+), 146 deletions(-) delete mode 100644 examples/language/gpt/gemini/demo.ipynb diff --git a/examples/language/gpt/gemini/demo.ipynb b/examples/language/gpt/gemini/demo.ipynb deleted file mode 100644 index 09953b3a9..000000000 --- a/examples/language/gpt/gemini/demo.ipynb +++ /dev/null @@ -1,142 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(in_features=10, out_features=5, bias=False) 50\n", - "Linear(in_features=5, out_features=10, bias=False) 50\n", - "Linear(in_features=10, out_features=10, bias=False) 100\n" - ] - } - ], - "source": [ - "class Toy(nn.Module):\n", - " \n", - " def __init__(self):\n", - " super(Toy, self).__init__()\n", - " self.fc1 = nn.Linear(10,5, bias=False)\n", - " self.m3 = nn.Sequential(nn.Linear(5, 10, bias=False), nn.Linear(10,10, bias=False))\n", - "\n", - "t = Toy()\n", - "for mod in t.modules():\n", - " for p in mod.parameters(recurse=False):\n", - " print(mod, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 10]) 50\n", - "torch.Size([10, 5]) 50\n", - "torch.Size([10, 10]) 100\n" - ] - } - ], - "source": [ - "for p in t.parameters():\n", - " print(p.shape, p.numel())" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'224'" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "conf_str = torch.__config__.parallel_info()\n", - "inter_str = conf_str.split(\"hardware_concurrency() : \")[1]\n", - "max_concurrency = inter_str.split(\"\\n\")[0]\n", - "max_concurrency" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0\n", - "0 1\n", - "0 2\n", - "1 0\n", - "1 1\n", - "1 2\n" - ] - } - ], - "source": [ - "for i in range(3):\n", - " for j in range(3):\n", - " print(i, j)\n", - " if i == 1 and j == 2:break\n", - " else:\n", - " continue\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "colossalai-py310", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 667a0c77a..bf1be87ba 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,18 +66,19 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # 返回值是B,转换成MB + return psutil.Process().memory_info().rss / 1024**2 # MB unit def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # 转换成MB + return torch.cuda.memory_allocated() / 1024**2 # MB unit def get_mem_info(prefix=""): return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" -def get_model_size(model: nn.Module): # 得到模型参数量 +def get_model_size(model: nn.Module): + # get the number of parameter of the model total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 4e1fb988b..1c914ca0e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -26,7 +26,7 @@ PLACEMENT_CONFIGS = [ "offload_optim_frac": 1.0, "offload_param_frac": 1.0, }, # zero3-offload-all - # {"placement_policy": "auto"}, + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks From bfcb2d1ff8dee52746f9d7af76ffe3acf0312ea5 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 07:25:24 +0000 Subject: [PATCH 22/36] refactor the code structure to solve the circular import --- colossalai/zero/gemini/gemini_hook.py | 7 +++- colossalai/zero/gemini/gemini_mgr.py | 6 ++- colossalai/zero/gemini/placement_policy.py | 44 +++++++++++----------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 450cb3ad6..315730f7a 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -51,7 +51,12 @@ class GeminiZeROHook(ColoParamOpHook): self._chunk_manager.access_chunk(chunk) # get possible chunks to prefetch - chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks() + chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks( + is_warmup=self._gemini_manager.is_warmup(), + compute_list=self._gemini_manager.compute_list, + compute_idx=self._gemini_manager.compute_idx, + async_works=self._gemini_manager.async_works, + ) # prefetch for chunk in chunks_fetch_async: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 11bde789c..5b309c7a1 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -45,7 +45,7 @@ class GeminiManager: self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 - self._async_works: Dict[Chunk, dist.work] = {} + self._async_works: Dict[Chunk, dist.Work] = {} self._h2d_volume = 0 self._d2h_volume = 0 @@ -183,6 +183,10 @@ class GeminiManager: def compute_idx(self) -> int: return self._compute_idx + @property + def async_works(self) -> Dict[Chunk, dist.Work]: + return self._async_works + @property def placement_policy(self) -> PlacementPolicy: return self._placement_policy diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index cfbf16d1b..9b1d1a6ab 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -5,13 +5,13 @@ from time import time from typing import Dict, List, Optional, Tuple, Type import torch +import torch.distributed as dist from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager -from .gemini_mgr import GeminiManager from .memory_tracer import ChunkMemStatsCollector @@ -20,13 +20,11 @@ class PlacementPolicy(ABC): def __init__( self, - gemini_manager: "GeminiManager", # TODO @botbw: solve circular import chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, **kwargs, ) -> None: - self.gemini_manager = gemini_manager self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector self.max_prefetch = max_prefetch @@ -41,14 +39,15 @@ class PlacementPolicy(ABC): ) -> None: raise NotImplementedError - def get_prefetch_chunks(self) -> List[Chunk]: + def get_prefetch_chunks( + self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: return [] # no prefetch by default class StaticPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: "GeminiManager", chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -57,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy): offload_param_frac: float = 0.0, **kwargs, ) -> None: - super().__init__( - gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch - ) + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 @@ -110,13 +107,15 @@ class StaticPlacementPolicy(PlacementPolicy): self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac) self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac) - def get_prefetch_chunks(self) -> List[Chunk]: - if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + def get_prefetch_chunks( + self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: + if is_warmup: # no prefetch during warmup since we need compute_list return [] - can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + can_prefetch = self.max_prefetch - len(async_works) prefetch = [] - for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): - for chunk in self.gemini_manager.compute_list[i]: + for i in range(compute_idx + 1, len(compute_list)): + for chunk in compute_list[i]: if len(prefetch) >= can_prefetch: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: @@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy): def __init__( self, - gemini_manager: GeminiManager, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, @@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy): steady_cuda_cap_ratio: float = 0.9, **kwargs, ) -> None: - super().__init__( - gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch - ) + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() @@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy): else: grads_device_map[p] = torch.device("cpu") - def get_prefetch_chunks(self) -> List[Chunk]: - if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list + def get_prefetch_chunks( + self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work] + ) -> List[Chunk]: + if is_warmup: # no prefetch during warmup since we need compute_list return [] # modified from self.evict_tensors cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( @@ -246,14 +244,14 @@ class AutoPlacementPolicy(PlacementPolicy): avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data prefetch_chunk_memory = 0 - can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works) + can_prefetch = self.max_prefetch - len(async_works) prefetch = [] - for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)): - for chunk in self.gemini_manager.compute_list[i]: - chunk: Chunk + for i in range(compute_idx + 1, len(compute_list)): + for chunk in compute_list[i]: if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data: break if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks: + prefetch_chunk_memory += chunk.chunk_mem prefetch.append(chunk) else: continue From 90d8d0183c39832cc2a5951d4d4437a69e878a18 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 20 May 2024 07:28:20 +0000 Subject: [PATCH 23/36] remove personal comments --- colossalai/zero/gemini/gemini_hook.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 315730f7a..cab26c822 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -41,7 +41,6 @@ class GeminiZeROHook(ColoParamOpHook): self._gemini_manager.sample_overall_data() # evit chunks, aware of async fetched - # TODO: check if prefetched chunks will be evicted self._gemini_manager.adjust_layout( all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0 ) From 137a7c341b2cf893458f00e4db86d9b8f761f08d Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 21 May 2024 02:07:21 +0000 Subject: [PATCH 24/36] [chore] fix init error --- colossalai/zero/gemini/gemini_mgr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 5b309c7a1..332f86512 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -42,7 +42,7 @@ class GeminiManager: self._mem_stats_collector = ( ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None ) - self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs) + self._placement_policy = policy_cls(chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 self._async_works: Dict[Chunk, dist.Work] = {} From b3c0e6d87159fb1061c2b2940743bdde8a81e454 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 May 2024 02:09:14 +0000 Subject: [PATCH 25/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/zero/gemini/gemini_mgr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 332f86512..0a8e0ae4a 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -42,7 +42,9 @@ class GeminiManager: self._mem_stats_collector = ( ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None ) - self._placement_policy = policy_cls(chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs) + self._placement_policy = policy_cls( + chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs + ) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 self._async_works: Dict[Chunk, dist.Work] = {} From 13c06d36a3504af6e2f1d6f1ce08bf400ac11f1c Mon Sep 17 00:00:00 2001 From: botbw Date: Tue, 21 May 2024 14:21:58 +0800 Subject: [PATCH 26/36] [bug] fix early return (#5740) * [bug] fix silly bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] add test for prefetch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/zero/gemini/chunk/chunk.py | 5 +++-- colossalai/zero/gemini/gemini_hook.py | 4 ---- colossalai/zero/gemini/gemini_mgr.py | 8 ++++---- tests/test_zero/test_gemini/test_fwd_bwd.py | 10 +++++++++- tests/test_zero/test_gemini/test_grad_accum.py | 9 ++++++++- tests/test_zero/test_gemini/test_grad_clip.py | 4 +++- tests/test_zero/test_gemini/test_optim.py | 12 ++++++++++-- .../test_zero/test_gemini/test_zeroddp_state_dict.py | 12 ++++++++++-- .../test_gemini/test_zerooptim_state_dict.py | 5 +++-- 9 files changed, 50 insertions(+), 19 deletions(-) diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 299ea0518..c4a4f245a 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -361,10 +361,11 @@ class Chunk: """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None + maybe_work = None if not self.is_gathered: - return self.__gather(async_op=async_access) + maybe_work = self.__gather(async_op=async_access) self.__update_tensors_ptr() - return None + return maybe_work def release_chunk(self): """Release the usable chunk. It's an operation done in CUDA.""" diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index cab26c822..736238a09 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -5,7 +5,6 @@ from typing import List import torch -from colossalai.logging import DistributedLogger from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored from colossalai.zero.gemini import TensorState @@ -17,9 +16,6 @@ class TrainingPhase(Enum): BACKWARD = 1 -logger = DistributedLogger("gemini_hook") - - class GeminiZeROHook(ColoParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 0a8e0ae4a..83e475575 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -177,6 +177,10 @@ class GeminiManager: return self._mem_stats_collector.cuda_margin_mem return None + @property + def placement_policy(self) -> PlacementPolicy: + return self._placement_policy + @property def compute_list(self) -> List[Tuple[Chunk, ...]]: return self._compute_list @@ -189,10 +193,6 @@ class GeminiManager: def async_works(self) -> Dict[Chunk, dist.Work]: return self._async_works - @property - def placement_policy(self) -> PlacementPolicy: - return self._placement_policy - @property def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 570a0aa42..478ace3d4 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) +@parameterize("max_prefetch", [0, 1, 4]) def exam_gpt_fwd_bwd( placement_config, keep_gather, model_name: str, use_grad_checkpoint: bool = False, master_weights: bool = True, + max_prefetch: int = 0, ): init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd( config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gather model = GeminiDDP( - model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights + model, + config_dict, + init_device, + pin_memory=True, + **placement_config, + master_weights=master_weights, + max_prefetch=max_prefetch, ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index fd0e9fd7c..11d29c50f 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [False, True]) @parameterize("use_grad_checkpoint", [False, True]) +@parameterize("max_prefetch", [0, 1, 4]) def exam_gemini_grad_acc( - placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool + placement_config, + keep_gathered: bool, + model_name: str, + master_weights: bool, + use_grad_checkpoint: bool, + max_prefetch: int, ): init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -81,6 +87,7 @@ def exam_gemini_grad_acc( pin_memory=True, enable_gradient_accumulation=True, master_weights=master_weights, + max_prefetch=max_prefetch, **placement_config, ) optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 0a9bac092..ad6dc2f78 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [True, False]) -def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): +@parameterize("max_prefetch", [0, 1, 4]) +def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, max_prefetch: int): set_seed(1912) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): chunk_init_device=init_device, pin_memory=True, master_weights=master_weights, + max_prefetch=max_prefetch, **placement_config, ) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 1c914ca0e..eab55f190 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -71,7 +71,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("master_weights", [True, False]) -def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): +@parameterize("max_prefetch", [0, 1, 4]) +def exam_model_step( + placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, max_prefetch: int +): set_seed(42) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -94,7 +97,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = False model = GeminiDDP( - model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights + model, + config_dict, + **placement_config, + mixed_precision=mixed_precision, + master_weights=master_weights, + max_prefetch=max_prefetch, ) optimizer = HybridAdam(model.parameters(), lr=1e-3) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 23e2d8083..3cbd36917 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -28,7 +28,8 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"]) @parameterize("master_weights", [False, True]) -def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): +@parameterize("max_prefetch", [0, 1, 4]) +def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -44,7 +45,14 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) + model = GeminiDDP( + model, + config_dict, + **placement_config, + pin_memory=True, + master_weights=master_weights, + max_prefetch=max_prefetch, + ) model.train() zero_dict = model.state_dict(only_rank_0=False) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 8d70ae3b1..a721c96a1 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -20,7 +20,8 @@ PLACEMENT_CONFIGS = [ @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) -def exam_zero_optim_state_dict(placement_config, keep_gathered): +@parameterize("max_prefetch", [0, 1, 4]) +def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next( iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) @@ -35,7 +36,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch) optimizer = HybridAdam(model.parameters()) optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 From 63c057cd8e4974279f0e829231af42f8171d0a10 Mon Sep 17 00:00:00 2001 From: hxwang Date: Fri, 24 May 2024 03:59:36 +0000 Subject: [PATCH 27/36] [example] add profile util for llama --- examples/language/llama/benchmark.py | 55 ++++++++++++++-------- examples/language/performance_evaluator.py | 22 +++++++++ 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 5cc602181..106251776 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -1,11 +1,12 @@ import argparse import resource +import time from contextlib import nullcontext import torch from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel -from performance_evaluator import PerformanceEvaluator +from performance_evaluator import PerformanceEvaluator, get_profile_context from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM @@ -76,6 +77,7 @@ def main(): parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) + parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) args = parser.parse_args() colossalai.launch_from_torch() @@ -110,6 +112,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, + max_prefetch=10, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -246,25 +249,37 @@ def main(): f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) - if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: - data_iter = iter(dataloader) - for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): - performance_evaluator.on_step_start(step) - booster.execute_pipeline( - data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False - ) - optimizer.step() - optimizer.zero_grad() - performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) - else: - for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): - performance_evaluator.on_step_start(step) - outputs = model(**batch) - loss = outputs[0] - booster.backward(loss, optimizer) - optimizer.step() - optimizer.zero_grad() - performance_evaluator.on_step_end(**batch) + with get_profile_context( + args.profile, + 1, + len(dataloader) - 1, + save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + ) as prof: + if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=False, + ) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + prof.step() + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) + prof.step() performance_evaluator.on_fit_end() coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index c2169a730..0b147b7ea 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -4,6 +4,7 @@ from typing import Optional import torch import torch.distributed as dist from torch import Tensor +from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator @@ -27,6 +28,27 @@ def all_reduce_mean(x: float, world_size: int) -> float: return tensor.item() +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): + class DummyProfiler: + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + if enable_flag: + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + # record_shapes=True, + # profile_memory=True, + with_stack=True, + ) + else: + return DummyProfiler() + + class Timer: def __init__(self) -> None: self.start_time: Optional[float] = None From ca674549e0ea815f949d19654b110d7f065e3496 Mon Sep 17 00:00:00 2001 From: hxwang Date: Fri, 24 May 2024 06:09:36 +0000 Subject: [PATCH 28/36] [chore] remove unnecessary test & changes --- examples/language/gpt/gemini/run_gemini.sh | 2 +- examples/language/gpt/gemini/train_gpt_demo.py | 11 +++++------ tests/test_zero/test_gemini/test_fwd_bwd.py | 2 +- tests/test_zero/test_gemini/test_grad_accum.py | 2 +- tests/test_zero/test_gemini/test_optim.py | 13 ++++--------- .../test_gemini/test_zeroddp_state_dict.py | 12 ++---------- .../test_gemini/test_zerooptim_state_dict.py | 5 ++--- 7 files changed, 16 insertions(+), 31 deletions(-) diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index bffd26f59..5eaa4af4d 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-2} +export TRAIN_STEP=${TRAIN_STEP:-10} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index bf1be87ba..4911ff124 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # MB unit + return psutil.Process().memory_info().rss / 1024**2 def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # MB unit + return torch.cuda.memory_allocated() / 1024**2 def get_mem_info(prefix=""): @@ -78,7 +78,6 @@ def get_mem_info(prefix=""): def get_model_size(model: nn.Module): - # get the number of parameter of the model total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): @@ -130,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = True # The flag of profiling, False by default + PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch() @@ -167,7 +166,7 @@ def main(): stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True ) elif args.distplan == "CAI_Gemini": - plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1) + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) else: raise RuntimeError @@ -249,7 +248,7 @@ def main(): prof.step() tflops_list.sort() - median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1) + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 517031c83..4d3981329 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -40,7 +40,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) +@parameterize("max_prefetch", [0, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_gpt_fwd_bwd( placement_config, diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index c2c11a8f3..002741389 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -50,7 +50,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [False, True]) @parameterize("use_grad_checkpoint", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) +@parameterize("max_prefetch", [0, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_gemini_grad_acc( placement_config, diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index a0cbc7d60..c610259b2 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -40,7 +40,9 @@ EXAMPLE_MODELS = [ ] # bfloat16 cannot represent them exactly -BF16_IGNORED_KEYS = ["masked_bias"] +BF16_IGNORED_KEYS = [ + "masked_bias", +] def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): @@ -71,15 +73,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("master_weights", [True, False]) -@parameterize("max_prefetch", [0, 1, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_model_step( - placement_config, - model_name: str, - mixed_precision: torch.dtype, - master_weights: bool, - max_prefetch: int, - enable_async_reduce=True, + placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True ): set_seed(42) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -108,7 +104,6 @@ def exam_model_step( **placement_config, mixed_precision=mixed_precision, master_weights=master_weights, - max_prefetch=max_prefetch, enable_async_reduce=enable_async_reduce, ) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 3cbd36917..23e2d8083 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -28,8 +28,7 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"]) @parameterize("master_weights", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) -def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int): +def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -45,14 +44,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP( - model, - config_dict, - **placement_config, - pin_memory=True, - master_weights=master_weights, - max_prefetch=max_prefetch, - ) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index a721c96a1..8d70ae3b1 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -20,8 +20,7 @@ PLACEMENT_CONFIGS = [ @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) -@parameterize("max_prefetch", [0, 1, 4]) -def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch): +def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next( iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) @@ -36,7 +35,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch): config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) optimizer = HybridAdam(model.parameters()) optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 From fba04e857b57abc54ba4864cbfb3af0461e2c5e7 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Sat, 25 May 2024 14:55:09 +0000 Subject: [PATCH 29/36] [bugs] fix args.profile=False DummyProfiler errro --- examples/language/performance_evaluator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 0b147b7ea..99df8f1da 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -36,6 +36,12 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def step(self): self.step_number += 1 + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + if enable_flag: return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], From b9269d962dff742df667ae19000f63622b45f56b Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Sat, 25 May 2024 14:55:50 +0000 Subject: [PATCH 30/36] add args.prefetch_num for benchmark --- examples/language/llama/benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 712703b45..b71203518 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -79,7 +79,7 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) - + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args() colossalai.launch_from_torch() @@ -114,7 +114,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, - max_prefetch=10, + max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, ) elif args.plugin == "gemini_auto": @@ -125,6 +125,8 @@ def main(): tp_size=args.tp, extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), + max_prefetch=args.prefetch_num, + enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": From 4d097def9637a67629a988c269093c46ac3e7cbf Mon Sep 17 00:00:00 2001 From: Haze188 Date: Sat, 25 May 2024 23:00:13 +0800 Subject: [PATCH 31/36] [Gemini] add some code for reduce-scatter overlap, chunk prefetch in llama benchmark. (#5751) * [bugs] fix args.profile=False DummyProfiler errro * add args.prefetch_num for benchmark --- examples/language/llama/benchmark.py | 6 ++++-- examples/language/performance_evaluator.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 712703b45..b71203518 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -79,7 +79,7 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) - + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args() colossalai.launch_from_torch() @@ -114,7 +114,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, - max_prefetch=10, + max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, ) elif args.plugin == "gemini_auto": @@ -125,6 +125,8 @@ def main(): tp_size=args.tp, extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), + max_prefetch=args.prefetch_num, + enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp": diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 0b147b7ea..99df8f1da 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -36,6 +36,12 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def step(self): self.step_number += 1 + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + if enable_flag: return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], From 87665d79228df9e8e40363e731874939f3b66b2f Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Mon, 27 May 2024 06:03:53 +0000 Subject: [PATCH 32/36] correct argument help message --- examples/language/llama/benchmark.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index b71203518..8d4dae314 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -78,7 +78,9 @@ def main(): parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) - parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) + parser.add_argument( + "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False + ) parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args() From 936dd96dbb6a9a4c1d167d795d4494b45edb4f5a Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 28 May 2024 02:33:12 +0000 Subject: [PATCH 33/36] [bug] workaround for idx fix --- colossalai/zero/gemini/gemini_mgr.py | 1 + colossalai/zero/gemini/placement_policy.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index 83e475575..d6b539f55 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -57,6 +57,7 @@ class GeminiManager: self._comp_cuda_demand_time = 0 def reset_attributes(self): + assert self._compute_idx + 1 == len(self._compute_list) self._compute_idx = -1 self._h2d_volume = 0 self._d2h_volume = 0 diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 9b1d1a6ab..c26db00e0 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -145,6 +145,8 @@ class AutoPlacementPolicy(PlacementPolicy): self._warmup_non_model_data_ratio = warmup_non_model_data_ratio self._steady_cuda_cap_ratio = steady_cuda_cap_ratio + self.__avail_cuda_model_data_for_prefetch = None + def evict_tensors( self, can_evict_chunks: List[Chunk], @@ -204,6 +206,7 @@ class AutoPlacementPolicy(PlacementPolicy): f"Adjust layout failed! No enough CUDA memory! " f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" ) + self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data - freed_cuda_model_data return freed_cuda_model_data, time() - start @staticmethod @@ -234,14 +237,9 @@ class AutoPlacementPolicy(PlacementPolicy): ) -> List[Chunk]: if is_warmup: # no prefetch during warmup since we need compute_list return [] - # modified from self.evict_tensors - cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity( - get_accelerator().get_current_device() - ) - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") - used_cuda_model_data = self.chunk_manager.total_mem["cuda"] - total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period - avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data + + avail_cuda_model_data = self.__avail_cuda_model_data_for_prefetch + self.__avail_cuda_model_data_for_prefetch = None # incase of double use prefetch_chunk_memory = 0 can_prefetch = self.max_prefetch - len(async_works) From e5e3320948ce4b05cf8839e79e936c482b0326fb Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 28 May 2024 02:41:23 +0000 Subject: [PATCH 34/36] [bug] continue fix --- colossalai/zero/gemini/placement_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c26db00e0..178755d03 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -206,7 +206,7 @@ class AutoPlacementPolicy(PlacementPolicy): f"Adjust layout failed! No enough CUDA memory! " f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" ) - self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data - freed_cuda_model_data + self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data + freed_cuda_model_data return freed_cuda_model_data, time() - start @staticmethod From 8547562884c4bfec614059ec5ae775a326d78036 Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 28 May 2024 05:16:02 +0000 Subject: [PATCH 35/36] [chore] remove unnecessary assert since compute list might not be recorded --- colossalai/zero/gemini/gemini_mgr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index d6b539f55..83e475575 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -57,7 +57,6 @@ class GeminiManager: self._comp_cuda_demand_time = 0 def reset_attributes(self): - assert self._compute_idx + 1 == len(self._compute_list) self._compute_idx = -1 self._h2d_volume = 0 self._d2h_volume = 0 From 154720ba6ee10631added5486eabd1134d480c9e Mon Sep 17 00:00:00 2001 From: hxwang Date: Tue, 28 May 2024 12:41:42 +0000 Subject: [PATCH 36/36] [chore] refactor profiler utils --- .../gpt/gemini/commons/performance_evaluator.py | 1 + examples/language/gpt/gemini/commons/utils.py | 16 ---------------- examples/language/gpt/gemini/train_gpt_demo.py | 3 ++- examples/language/performance_evaluator.py | 4 ++-- 4 files changed, 5 insertions(+), 19 deletions(-) create mode 120000 examples/language/gpt/gemini/commons/performance_evaluator.py diff --git a/examples/language/gpt/gemini/commons/performance_evaluator.py b/examples/language/gpt/gemini/commons/performance_evaluator.py new file mode 120000 index 000000000..152602774 --- /dev/null +++ b/examples/language/gpt/gemini/commons/performance_evaluator.py @@ -0,0 +1 @@ +../../../performance_evaluator.py \ No newline at end of file diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 03054c0a2..ba80cc4a6 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -1,8 +1,6 @@ import time -from contextlib import nullcontext import torch -from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler class DummyProfiler: @@ -24,20 +22,6 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) -def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): - if enable_flag: - return profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), - on_trace_ready=tensorboard_trace_handler(save_dir), - # record_shapes=True, - # profile_memory=True, - with_stack=True, - ) - else: - return nullcontext(DummyProfiler()) - - def get_time_stamp(): cur_time = time.strftime("%d-%H:%M", time.localtime()) return cur_time diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 4911ff124..cb5d2c32c 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -8,7 +8,8 @@ import psutil import torch import torch.nn as nn from commons.model_zoo import model_builder -from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp +from commons.performance_evaluator import get_profile_context +from commons.utils import get_data, get_tflops, get_time_stamp from packaging import version import colossalai diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py index 99df8f1da..6b8daf37d 100644 --- a/examples/language/performance_evaluator.py +++ b/examples/language/performance_evaluator.py @@ -47,8 +47,8 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), on_trace_ready=tensorboard_trace_handler(save_dir), - # record_shapes=True, - # profile_memory=True, + record_shapes=True, + profile_memory=True, with_stack=True, ) else: