From 74754397df336db3c9fd03fb297792f8c4b546d8 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 13 Nov 2023 21:09:59 +0800 Subject: [PATCH] feat(model/overlap_handler.py): add memory_pool switch and refactor overlap handler --- configs/7B_sft.py | 2 +- internlm/model/overlap_handler.py | 193 ++++++++++-------- internlm/model/utils.py | 23 ++- .../solver/optimizer/hybrid_zero_optim.py | 4 +- train.py | 2 +- 5 files changed, 128 insertions(+), 96 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index e85d2df..63fa67e 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -163,7 +163,7 @@ pipeline parallel (dict): """ parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=4, sp="intern", intern_overlap=True), + tensor=dict(size=4, sp="intern", intern_overlap=True, memory_pool=True), pipeline=dict(size=1, interleaved_overlap=True), ) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index e3198bb..cb00d22 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -13,6 +13,7 @@ from internlm.core.scheduler import SchedulerHook from internlm.model.embedding import Embedding1D from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear from internlm.model.utils import ( + all_gather_raw, all_gather_raw_bias_memory_pool, all_gather_raw_memory_pool, ) @@ -29,14 +30,17 @@ class FSTPOverlapHandler: self.fstp_outs = [] self.fstp_modules = [] self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle + self.weight_global_handle = dict() # key: fstp module; value: module global all-gather op handle self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle + self.weight_global_output = dict() # key: fstp module; value: module global weight after all-gather op + self.bias_global_output = dict() # key: fstp module; value: module bias global weight after all-gather op self.module_to_index = dict() # key: fstp module; value: transformer block index self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules self.last_block = None self.head = [] self.embedding = [] self.model_checkpoint = gpc.config.model.checkpoint + self.enable_memory_pool = gpc.config.parallel["tensor"].get("memory_pool", False) self.is_forward = True self.reduce_scatter_handlers = {} @@ -60,34 +64,36 @@ class FSTPOverlapHandler: for idx, block in enumerate(children): self.index_to_fstp_modules[idx] = [] for _sub_name, sub in block.named_children(): - sub_modules = list(sub.children()) - if len(sub_modules) > 0: - for name, child in sub.named_children(): - if name == "out_proj": - self.fstp_outs.append(child) - self.module_to_index[child] = idx - if isinstance(child, FSTPLinear): - self.module_to_index[child] = idx - self.fstp_modules.append(child) - self.index_to_fstp_modules[idx].append(child) + for name, child in sub.named_children(): + if name == "out_proj": + self.fstp_outs.append(child) + self.module_to_index[child] = idx + if isinstance(child, FSTPLinear): + self.module_to_index[child] = idx + self.fstp_modules.append(child) + self.index_to_fstp_modules[idx].append(child) - setattr(child, "_fstp_name", name) + setattr(child, "_fstp_name", name) - _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") - if child.bias is not None: - setattr(child.bias, 
"_fstp_reduce_scatter_str", f"{_full_name}.bias") + _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" + setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") + if child.bias is not None: + setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") self.num_blocks = len(self.index_to_fstp_modules) - self._initialize_memory_pool() + if self.enable_memory_pool: + self._initialize_memory_pool() self._register_sync_parameters_hook() def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: - if size not in self.zero_const_pool: - self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() + if self.enable_memory_pool: + if size not in self.zero_const_pool: + self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() - return self.zero_const_pool[size] + return self.zero_const_pool[size] + else: + return torch.zeros(*size, dtype=dtype, device=device).contiguous() def set_forward_mode(self, flag): self.is_forward = flag @@ -122,14 +128,20 @@ class FSTPOverlapHandler: self.all_gather_memory_pool.append(weight) # containing two groups of block weight def clear_memory_pool(self) -> None: + assert self.enable_memory_pool + self.zero_const_pool = {} self.reduce_scatter_memory_pool = {} - def get_all_gather_memory(self, module): + def _get_weight_from_memory_pool(self, module): + assert self.enable_memory_pool + block_index = self.module_to_index[module] return self.all_gather_memory_pool[block_index % 2][module._fstp_name] - def get_bias_memory(self, module: nn.Module): + def _get_bias_from_memory_pool(self, module: nn.Module): + assert self.enable_memory_pool + block_index = self.module_to_index[module] # if the bias memory pool is empty or module has been not allocated memory if len(self.all_gather_bias_memory_pool) == 0: @@ -151,7 +163,21 @@ class FSTPOverlapHandler: return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name] + def get_weight_all_gather(self, module): + if self.enable_memory_pool: + return self._get_weight_from_memory_pool(module) + else: + return self.weight_global_output[module] + + def get_bias_all_gather(self, module): + if self.enable_memory_pool: + return self._get_bias_from_memory_pool(module) + else: + return self.bias_global_output[module] + def get_reduce_scatter_memory(self, key): + assert self.enable_memory_pool + # if key not in dict if key not in self.reduce_scatter_memory_pool: self.reduce_scatter_memory_pool[key] = [] @@ -171,11 +197,11 @@ class FSTPOverlapHandler: return self.reduce_scatter_memory_pool[key][cur_len] def release_reduce_scatter_memory(self, key, index): + assert self.enable_memory_pool self.reduce_scatter_memory_pool[key][index].idle = True - def _all_gather_block_weight_memory_pool(self, block_index: int): - fstp_modules = self.index_to_fstp_modules[block_index] - for module in fstp_modules: + def _all_gather_module_weight(self, module): + if self.enable_memory_pool: if module.bias is not None: bias_handle = all_gather_raw_bias_memory_pool( module.bias, @@ -191,103 +217,102 @@ class FSTPOverlapHandler: async_op=True, module=module, ) - self.fstp_global_handle[module] = weight_handle + self.weight_global_handle[module] = weight_handle + else: + if module.bias is not None: + bias_output, bias_handle = all_gather_raw( + module.bias, + self.process_group, + async_op=True, + ) + self.bias_global_handle[module] = bias_handle + self.bias_global_output[module] = bias_output + + weight_output, weight_handle = all_gather_raw( + 
module.weight, + self.process_group, + async_op=True, + ) + self.weight_global_handle[module] = weight_handle + self.weight_global_output[module] = weight_output + + def _all_gather_block_weight(self, block_index: int): + fstp_modules = self.index_to_fstp_modules[block_index] + for module in fstp_modules: + self._all_gather_module_weight(module) def _register_sync_parameters_hook(self) -> None: """ register forward hooks and backward hooks for fstp modules. """ + def _wait_handle(module): + handle = self.weight_global_handle[module] + handle.wait() + if module.bias is not None: + bias_handle = self.bias_global_handle[module] + bias_handle.wait() + + def _clear_handle(module): + if module in self.weight_global_handle: + del self.weight_global_handle[module] + if module in self.bias_global_handle: + del self.bias_global_handle[module] + # if module in self.weight_global_output: + # del self.weight_global_output[module] + # if module in self.bias_global_output: + # del self.bias_global_output[module] + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613 - self._all_gather_block_weight_memory_pool(0) + self._all_gather_block_weight(0) def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): # pylint: disable=W0613 block_index = self.module_to_index[module] if self.model_checkpoint and self.is_forward is False: if block_index - 1 >= 0: - self._all_gather_block_weight_memory_pool(block_index - 1) + self._all_gather_block_weight(block_index - 1) else: # start the all-gather for next block if block_index + 1 < self.num_blocks: - self._all_gather_block_weight_memory_pool(block_index + 1) + self._all_gather_block_weight(block_index + 1) def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613 - if module in self.fstp_global_handle: - handle = self.fstp_global_handle[module] - handle.wait() - if module.bias is not None: - bias_handle = self.bias_global_handle[module] - bias_handle.wait() - else: - weight_handle = all_gather_raw_memory_pool( - module.weight, - self.process_group, - async_op=True, - module=module, - ) - self.fstp_global_handle[module] = weight_handle - weight_handle.wait() + if module not in self.weight_global_handle: + self._all_gather_module_weight(module) + + _wait_handle(module) def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): # pylint: disable=W0613 fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1] if module in fstp_modules: - weight_handle = all_gather_raw_memory_pool( - module.weight, - self.process_group, - async_op=True, - module=module, - ) - self.fstp_global_handle[module] = weight_handle - weight_handle.wait() + self._all_gather_module_weight(module) + _wait_handle(module) def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613 - if module in self.fstp_global_handle: - del self.fstp_global_handle[module] + _clear_handle(module) def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): # pylint: disable=W0613 - first_backward_module = self.fstp_modules[-1] - weight_handle = all_gather_raw_memory_pool( - first_backward_module.weight, - self.process_group, - async_op=True, - module=first_backward_module, - ) - self.fstp_global_handle[first_backward_module] = weight_handle + self._all_gather_module_weight(self.fstp_modules[-1]) def _pre_backward_hook_for_head(module: nn.Module, grad_output): if self.is_forward is False: - self._all_gather_block_weight_memory_pool(self.num_blocks - 1) + 
self._all_gather_block_weight(self.num_blocks - 1) def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613 # wait handle for current module - if module in self.fstp_global_handle: - weight_handle = self.fstp_global_handle[module] - weight_handle.wait() - else: - weight_handle = all_gather_raw_memory_pool( - module.weight, - self.process_group, - async_op=True, - module=module, - ) - self.fstp_global_handle[module] = weight_handle - weight_handle.wait() + if module not in self.weight_global_handle: + self._all_gather_module_weight(module) + + _wait_handle(module) # start the all-gather for next module module_index = self.fstp_modules.index(module) if module_index - 1 >= 0: next_module = self.fstp_modules[module_index - 1] - weight_handle = all_gather_raw_memory_pool( - next_module.weight, - self.process_group, - async_op=True, - module=next_module, - ) - self.fstp_global_handle[next_module] = weight_handle + self._all_gather_module_weight(next_module) def _post_backward_hook_for_module(module, grad_input, grad_output): # pylint: disable=W0613 - if module in self.fstp_global_handle: - del self.fstp_global_handle[module] + _clear_handle(module) # register forward hooks # 1. register post_forward_hook @embedding module to prefetch for block 0 diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 556752a..45d2f51 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -132,7 +132,7 @@ def all_gather_raw_memory_pool( module: nn.Module = None, ): handle = torch.distributed.all_gather_into_tensor( - gpc.fstp_handler.get_all_gather_memory(module=module), + gpc.fstp_handler.get_weight_all_gather(module=module), input_.contiguous(), group=process_group, async_op=async_op, @@ -147,7 +147,7 @@ def all_gather_raw_bias_memory_pool( module: nn.Module = None, ): handle = torch.distributed.all_gather_into_tensor( - gpc.fstp_handler.get_bias_memory(module=module), + gpc.fstp_handler.get_bias_all_gather(module=module), input_.contiguous(), group=process_group, async_op=async_op, @@ -177,8 +177,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 - size = (input_.shape[0] // world_size, *input_.shape[1:]) - output = gpc.fstp_handler.get_reduce_scatter_memory(size) + if gpc.fstp_handler.enable_memory_pool: + size = (input_.shape[0] // world_size, *input_.shape[1:]) + output = gpc.fstp_handler.get_reduce_scatter_memory(size) + else: + output = torch.empty( + input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device + ).contiguous() handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) @@ -493,14 +498,14 @@ class FSTPFusedDenseFunc(torch.autograd.Function): if world_size > 1: # do all_gather for weight and bias before actual computation if overlap_handler is not None: - total_weight = gpc.fstp_handler.get_all_gather_memory(module=module) + total_weight = gpc.fstp_handler.get_weight_all_gather(module=module) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() - # TODO memory pool for bias + if bias is not None: if overlap_handler is not None: - total_bias = gpc.fstp_handler.get_bias_memory(module=module) + total_bias = 
gpc.fstp_handler.get_bias_all_gather(module=module) else: total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) handle_bias.wait() @@ -554,7 +559,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: if overlap_handler is not None: - total_weight = gpc.fstp_handler.get_all_gather_memory(module=module) + total_weight = gpc.fstp_handler.get_weight_all_gather(module=module) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -655,7 +660,7 @@ class FSTPFusedDenseFuncTorch(FSTPFusedDenseFunc): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: if overlap_handler is not None: - total_weight = gpc.fstp_handler.get_all_gather_memory(module=module) + total_weight = gpc.fstp_handler.get_weight_all_gather(module=module) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index b033539..3092a62 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -389,7 +389,9 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. - self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index) + if self._fstp_handler.enable_memory_pool: + self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index) + _grad = None self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) diff --git a/train.py b/train.py index 644bbeb..5ea91e8 100644 --- a/train.py +++ b/train.py @@ -324,7 +324,7 @@ def main(args): if batch_count % 2 == 0: prof.step() - if gpc.fstp_handler is not None: + if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool: gpc.fstp_handler.clear_memory_pool() # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats()
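
Note for readers: the sketch below is plain Python added for illustration only; it is not part of the patch and not InternLM code. It shows the dispatch pattern the new memory_pool switch introduces: with parallel["tensor"]["memory_pool"] = True the handler writes all-gather results into pre-allocated buffers that are reused across transformer blocks, while with the flag off it falls back to allocating a fresh output tensor per all-gather, as all_gather_raw does. ToyOverlapHandler and fake_all_gather are hypothetical stand-ins (fake_all_gather mimics torch.distributed.all_gather_into_tensor without requiring an initialized process group).

# Minimal sketch of the memory_pool dispatch introduced by this patch.
# ToyOverlapHandler and fake_all_gather are illustrative names, not InternLM APIs.
import torch


def fake_all_gather(tensor: torch.Tensor, world_size: int, out: torch.Tensor = None) -> torch.Tensor:
    """Stand-in for torch.distributed.all_gather_into_tensor: tiles the local
    shard world_size times, optionally into a caller-provided buffer."""
    gathered = tensor.repeat(world_size, *([1] * (tensor.dim() - 1)))
    if out is not None:
        out.copy_(gathered)
        return out
    return gathered


class ToyOverlapHandler:
    def __init__(self, modules: dict, world_size: int, enable_memory_pool: bool):
        self.world_size = world_size
        self.enable_memory_pool = enable_memory_pool
        self.weight_global_output = {}  # per-call outputs, used when the pool is disabled
        self.memory_pool = {}           # name -> reusable buffer, used when the pool is enabled
        if enable_memory_pool:
            # pre-allocate one full-size buffer per module, reused every step
            for name, module in modules.items():
                shape = (module.weight.shape[0] * world_size, *module.weight.shape[1:])
                self.memory_pool[name] = torch.empty(shape, dtype=module.weight.dtype)

    def all_gather_module_weight(self, name: str, module: torch.nn.Module) -> None:
        if self.enable_memory_pool:
            # write into the pre-allocated buffer; no per-step allocation
            fake_all_gather(module.weight.data, self.world_size, out=self.memory_pool[name])
        else:
            # allocate a fresh output tensor for this call
            self.weight_global_output[name] = fake_all_gather(module.weight.data, self.world_size)

    def get_weight_all_gather(self, name: str) -> torch.Tensor:
        return self.memory_pool[name] if self.enable_memory_pool else self.weight_global_output[name]


if __name__ == "__main__":
    modules = {"w1": torch.nn.Linear(8, 4, bias=False)}
    for flag in (True, False):
        handler = ToyOverlapHandler(modules, world_size=2, enable_memory_pool=flag)
        handler.all_gather_module_weight("w1", modules["w1"])
        print(flag, handler.get_weight_all_gather("w1").shape)  # torch.Size([8, 8]) in both modes

With either setting the gathered tensor has the same shape; only the allocation strategy differs, which is why FSTPFusedDenseFunc in internlm/model/utils.py can call get_weight_all_gather / get_bias_all_gather unconditionally whenever the overlap handler is present.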