From 85ad917ae430c2e89cf4444221c2ced9223d3552 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 21:50:32 +0800 Subject: [PATCH] feat(model/overlap_handler.py): refactor overlap hook handle --- configs/7B_sft.py | 2 +- internlm/model/linear.py | 296 +++++------------- internlm/model/modeling_internlm.py | 11 +- internlm/model/multi_head_attention.py | 11 +- internlm/model/overlap_handler.py | 253 +++++++++++++++ internlm/model/utils.py | 98 +++--- .../solver/optimizer/hybrid_zero_optim.py | 12 +- internlm/train/training_internlm.py | 56 +--- train.py | 2 +- 9 files changed, 392 insertions(+), 349 deletions(-) create mode 100644 internlm/model/overlap_handler.py diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 09af7f4..c51c812 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -163,7 +163,7 @@ pipeline parallel (dict): """ parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=8, sp="megatron", intern_overlap=True), + tensor=dict(size=8, sp="intern", intern_overlap=True), pipeline=dict(size=1, interleaved_overlap=True), ) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 2bbb941..6cd3b9c 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -1,22 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Optional, Union +from typing import Optional import torch -import torch.nn.functional as F from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.embedding import Embedding1D from internlm.model.utils import ( Silu, - all_gather_raw, - all_gather_raw_memory_pool, fstp_fused_dense_func, fused_dense_func_torch, megatron_fused_dense_func_torch, @@ -25,20 +20,20 @@ from internlm.model.utils import ( class BaseScaleColumnParallelLinear(nn.Linear): """ - Base class for ScaleColumnParallelLinear. + Base class for ScaleColumnParallelLinear. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + If not, then the input is already gathered. + device (Optional[Union[str, torch.device]]): The device will be used. 
+ dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. """ def __init__( @@ -58,10 +53,12 @@ class BaseScaleColumnParallelLinear(nn.Linear): self.process_group = process_group self.weight_scale = weight_scale + class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): """ ScaleColumnParallelLinear in flash implementation. """ + def forward(self, input, gather_dim=0): # pylint: disable=W0622 # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. @@ -79,6 +76,7 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): gather_dim=gather_dim, ) + class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear): """ ScaleColumnParallelLinear in megatron implementation. @@ -101,6 +99,7 @@ class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear): gather_dim=gather_dim, ) + class RewardModelLinear(ScaleColumnParallelLinear): """ RewardModelLinear. @@ -164,6 +163,7 @@ class ColumnParallelLinearTorch(ColumnParallelLinear): gather_dim=gather_dim, ) + class MegatronColumnParallelLinearTorch(ColumnParallelLinear): def forward(self, x, gather_dim=0): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: @@ -178,6 +178,7 @@ class MegatronColumnParallelLinearTorch(ColumnParallelLinear): gather_dim=gather_dim, ) + class RowParallelLinearTorch(RowParallelLinear): def forward(self, x): """ @@ -188,6 +189,7 @@ class RowParallelLinearTorch(RowParallelLinear): reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) + class MegatronRowParallelLinearTorch(RowParallelLinear): def forward(self, x): """ @@ -225,8 +227,8 @@ class BaseFeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - colum_cls = None, - row_cls = None, + colum_cls=None, + row_cls=None, ): super().__init__() hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) @@ -265,6 +267,7 @@ class BaseFeedForward(nn.Module): out = self.w3(Silu(w1_o, w2_o)) return out + class FeedForward(BaseFeedForward): """ FeedForward in flash implementation. 
@@ -292,9 +295,19 @@ class FeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch) - + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + ColumnParallelLinearTorch, + RowParallelLinearTorch, + ) + class MegatronFeedForward(BaseFeedForward): """ @@ -323,19 +336,35 @@ class MegatronFeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch) + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + MegatronColumnParallelLinearTorch, + MegatronRowParallelLinearTorch, + ) + class FSTPLinear(ColumnParallelLinear): def forward(self, x): block_index = gpc.config.fstp_handler.module_to_index[self] - name_index = gpc.config.fstp_handler.module_name_index[self] - name = gpc.config.fstp_handler.module_name[name_index] return fstp_fused_dense_func( - x, self.weight, self.bias, process_group=self.process_group, - module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name + x, + self.weight, + self.bias, + process_group=self.process_group, + module=self, + handler=gpc.config.fstp_handler, + block_index=block_index, + module_name=self._fstp_name, ) + class FSTPFeedForward(BaseFeedForward): """ FeedForward in FSTP. @@ -363,8 +392,19 @@ class FSTPFeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, FSTPLinear, FSTPLinear) + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + FSTPLinear, + FSTPLinear, + ) + def get_mlp_cls(sp_mode: str): if sp_mode in ["none", "flash-attn"]: @@ -375,6 +415,7 @@ def get_mlp_cls(sp_mode: str): mlp_cls = FSTPFeedForward return mlp_cls + def get_linear_cls(sp_mode: str, parallel_mode: str): if parallel_mode == "column": if sp_mode in ["none", "flash-attn"]: @@ -383,7 +424,7 @@ def get_linear_cls(sp_mode: str, parallel_mode: str): cls = MegatronColumnParallelLinearTorch else: cls = FSTPLinear - elif parallel_mode == 'row': + elif parallel_mode == "row": if sp_mode in ["none", "flash-attn"]: cls = RowParallelLinearTorch elif sp_mode == "megatron": @@ -391,192 +432,3 @@ def get_linear_cls(sp_mode: str, parallel_mode: str): else: cls = FSTPLinear return cls - -class CoarseGrainedFSTPAllGatherSyncHandler: - """ - All-gather handler for overlapping the all-gather in adjcent FSTP block. 
- """ - - def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: - # import pdb; pdb.set_trace() - self.process_group = process_group - self.FSTP_blocks = [] - self.FSTP_outs = [] - self.FSTP_modules = [] - self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle - self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward - self.block_handles = dict() # key: transformer block; value: all-gather handles - self.module_to_index = dict() # key: FSTP module; value: transformer block index - self.block_to_index = dict() # key: transformer block; value: transformer block index - self.index_to_block = dict() # key: transformer block index; value: transformer block - self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules - self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name - self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} - self.head = [] - self.embedding = [] - - self.reduce_scatter_handlers = {} - self.all_reduce_handlers = {} - self.zero_const_pool = {} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _chunk_name, children in _chunk.named_children(): - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - index = 0 - self.block_module[idx] = {} - self.FSTP_blocks.append(block) - self.block_to_index[block] = idx - self.index_to_block[idx] = block - self.index_to_fsdp_modules[idx] = [] - for _sub_name, sub in block.named_children(): - sub_modules = list(sub.children()) - if len(sub_modules) > 0: - for name, child in sub.named_children(): - if name == "out_proj": - self.FSTP_outs.append(child) - self.module_to_index[child] = idx - if isinstance(child, FSTPLinear): - self.module_to_index[child] = idx - self.block_module[idx][index] = child - self.FSTP_modules.append(child) - self.index_to_fsdp_modules[idx].append(child) - self.module_name_index[child] = index - index = index + 1 - - _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") - if child.bias is not None: - setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - else: - continue - elif isinstance(children, ScaleColumnParallelLinear): - self.head.append(children) - elif isinstance(children, Embedding1D): - self.embedding.append(children) - - def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: - if size not in self.zero_const_pool: - self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() - - return self.zero_const_pool[size] - - def _all_gather_block_weight_memory_pool(self, block_index: int): - fsdp_modules = self.index_to_fsdp_modules[block_index] - for module in fsdp_modules: - module_index = self.module_name_index[module] - name = self.module_name[module_index] - weight_handle = all_gather_raw_memory_pool( - module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name - ) - self.FSTP_global_handle[module] = weight_handle - - def _register_sync_parameters_hook(self) -> None: - """ - register pre_forward_hook and pre_backward_hook for FSTP block. 
- - Notice that next block's all_gather op should be after current block's all_to_all op, so we - 1. register pre_forward_hook @out_proj module to prefetch for next block - 2. register pre_forward_hook @block module to wait handles for next block - 3. register pre_backward_hook @wqkv module to prefetch for next block - 4. register pre_backward_hook @block module to wait handles for next block - """ - - def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): - block_index = self.module_to_index[module] - # start the all-gather for next block - if block_index + 1 < gpc.config.NUM_LAYER: - self._all_gather_block_weight_memory_pool(block_index + 1) - - def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): - self._all_gather_block_weight_memory_pool(0) - - def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): - handle = self.FSTP_global_handle[module] - handle.wait() - - def _post_forward_hook_for_module(module: nn.Module, input, output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.FSTP_global_handle: - del self.FSTP_global_handle[module] - - def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): - first_module = self.block_module[gpc.config.NUM_LAYER - 1][4] - total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True) - self.FSTP_global_handle[first_module] = weight_handler - self.FSTP_global_weights[first_module] = total_weight - - def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output): - block_index = self.module_to_index[module] - name_index = self.module_name_index[module] - - if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - weight_handler = self.FSTP_global_handle[module] - weight_handler.wait() - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - next_name = self.module_name[name_index - 1] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, - self.process_group, - async_op=True, - block_index=block_index, - module_name=next_name, - ) - self.FSTP_global_handle[next_module] = weights_handler - elif name_index == 0: - handler = self.FSTP_global_handle[module] - handler.wait() - - if block_index - 1 >= 0: - next_module = self.block_module[block_index - 1][4] - name = self.module_name[4] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, - self.process_group, - async_op=True, - block_index=block_index - 1, - module_name=name, - ) - self.FSTP_global_handle[next_module] = weights_handler - else: - handler = self.FSTP_global_handle[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - name = self.module_name[name_index - 1] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name - ) - self.FSTP_global_handle[next_module] = weights_handler - - def _post_backward_hook_for_module(module, grad_input, grad_output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.FSTP_global_handle: - del self.FSTP_global_handle[module] - - for embedding in self.embedding: - embedding.register_forward_hook(_post_forward_hook_for_embedding) - - for head in self.head: - head.register_full_backward_hook(_post_backward_hook_for_head) - - for out_proj in self.FSTP_outs: - 
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - - for module in self.FSTP_modules: - module.register_forward_pre_hook(_pre_forward_hook_for_module) - module.register_forward_hook(_post_forward_hook_for_module) - module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool) - module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 3ed78d7..228e1e1 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -14,12 +14,9 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.embedding import Embedding1D from internlm.model.linear import ( - FeedForward, - MegatronFeedForward, - FSTPFeedForward, + MegatronScaleColumnParallelLinear, RewardModelLinear, ScaleColumnParallelLinear, - MegatronScaleColumnParallelLinear, get_mlp_cls, ) from internlm.model.multi_head_attention import MHA @@ -309,7 +306,11 @@ class PackedFlashInternLm1D(nn.Module): if is_reward: head_cls = RewardModelLinear else: - head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear + head_cls = ( + ScaleColumnParallelLinear + if self.sp_mode in ["flash-attn", "none", "intern"] + else MegatronScaleColumnParallelLinear + ) if first: if embed_split_hidden: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 8ba49ed..93dbf01 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -38,14 +38,7 @@ from torch.nn import Module from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding -from internlm.model.linear import ( - ColumnParallelLinearTorch, - FSTPLinear, - RowParallelLinearTorch, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - get_linear_cls, -) +from internlm.model.linear import get_linear_cls # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py @@ -227,7 +220,7 @@ class MHA(nn.Module): self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(sp_mode, 'row') + out_proj_cls = get_linear_cls(sp_mode, "row") self.out_proj = out_proj_cls( embed_dim, embed_dim, diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py new file mode 100644 index 0000000..cafb818 --- /dev/null +++ b/internlm/model/overlap_handler.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Any, Union + +import torch +from torch import nn + +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import NaiveAMPModel +from internlm.model.embedding import Embedding1D +from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear +from internlm.model.utils import all_gather_raw_memory_pool +from internlm.utils.common import get_current_device + + +class FSTPOverlapHandler: + """ + FSTP overlap handler for managing the all-gather and reduce_scatter overlapping. 
+ """ + + def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: + self.process_group = process_group + self.fstp_outs = [] + self.fstp_modules = [] + self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] + self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle + self.module_to_index = dict() # key: fstp module; value: transformer block index + self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules + self.head = [] + self.embedding = [] + + self.reduce_scatter_handlers = {} + self.zero_const_pool = {} + + # just want to share same for loop for ModuleList and Module + if not isinstance(model, nn.ModuleList): + model = [model] + + for _chunk in model: + if isinstance(_chunk, NaiveAMPModel): + _chunk = _chunk.model + + for _chunk_name, children in _chunk.named_children(): + if isinstance(children, ScaleColumnParallelLinear): + self.head.append(children) + elif isinstance(children, Embedding1D): + self.embedding.append(children) + elif isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + self.index_to_fstp_modules[idx] = [] + for _sub_name, sub in block.named_children(): + sub_modules = list(sub.children()) + if len(sub_modules) > 0: + for name, child in sub.named_children(): + if name == "out_proj": + self.fstp_outs.append(child) + self.module_to_index[child] = idx + if isinstance(child, FSTPLinear): + self.module_to_index[child] = idx + self.fstp_modules.append(child) + self.index_to_fstp_modules[idx].append(child) + + setattr(child, "_fstp_name", name) + + _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" + setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") + if child.bias is not None: + setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") + + self._initialize_memory_pool() + self._register_sync_parameters_hook() + + def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: + if size not in self.zero_const_pool: + self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() + + return self.zero_const_pool[size] + + def _initialize_memory_pool(self) -> None: + # allocate memory pool + hidden_size = gpc.config.HIDDEN_SIZE + mlp_ratio = gpc.config.MLP_RATIO + mlp_hidden_size = int(hidden_size * mlp_ratio) + mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) + self.all_gather_memory_pool = [] + self.reduce_scatter_memory_pool = {} + + for _ in range(2): + weight = {} + for name in self.module_name: + if name == "Wqkv": + weight[name] = torch.zeros( + (3 * hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "out_proj": + weight[name] = torch.zeros( + (hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "w1" or name == "w2": + weight[name] = torch.zeros( + (mlp_hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + else: + weight[name] = torch.zeros( + (hidden_size, mlp_hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + + self.all_gather_memory_pool.append(weight) # containing two groups of block weight + + def get_all_gather_memory(self, index, module_name): + return self.all_gather_memory_pool[index % 2][module_name] + + def get_reduce_scatter_memory(self, 
key): + return_idx = 0 + + # if key not in dict + if key not in self.reduce_scatter_memory_pool: + self.reduce_scatter_memory_pool[key] = {"data": [], "used": []} + + # if the data is empty + if len(self.reduce_scatter_memory_pool[key]["data"]) == 0: + self.reduce_scatter_memory_pool[key]["data"].append( + torch.zeros( + key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() + ).contiguous() + ) + self.reduce_scatter_memory_pool[key]["used"].append(True) + return_idx = 0 + return return_idx + else: # if not empty + for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]): + if used is False: + self.reduce_scatter_memory_pool[key]["used"][index] = True + return_idx = index + return return_idx + # if the memory pool is all used + length = len(self.reduce_scatter_memory_pool[key]["data"]) + self.reduce_scatter_memory_pool[key]["data"].append( + torch.zeros( + key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() + ).contiguous() + ) + self.reduce_scatter_memory_pool[key]["used"].append(True) + return_idx = length + return return_idx + + def release_reduce_scatter_memory(self, size, index): + self.reduce_scatter_memory_pool[size]["used"][index] = False + + def _all_gather_block_weight_memory_pool(self, block_index: int): + fstp_modules = self.index_to_fstp_modules[block_index] + for module in fstp_modules: + weight_handle = all_gather_raw_memory_pool( + module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(module, "_fstp_name"), + ) + self.fstp_global_handle[module] = weight_handle + + def _register_sync_parameters_hook(self) -> None: + """ + register forward hooks and backward hooks for fstp modules. + """ + + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): + self._all_gather_block_weight_memory_pool(0) + + def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): + block_index = self.module_to_index[module] + # start the all-gather for next block + if block_index + 1 < gpc.config.NUM_LAYER: + self._all_gather_block_weight_memory_pool(block_index + 1) + + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): + handle = self.fstp_global_handle[module] + handle.wait() + + def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): + if module in self.fstp_global_handle: + del self.fstp_global_handle[module] + + def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): + first_backward_module = self.fstp_modules[-1] + block_index = self.module_to_index[first_backward_module] + weight_handle = all_gather_raw_memory_pool( + first_backward_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(first_backward_module, "_fstp_name"), + ) + self.fstp_global_handle[first_backward_module] = weight_handle + + def _pre_backward_hook_for_module(module: nn.Module, grad_output): + # wait handle for current module + weight_handle = self.fstp_global_handle[module] + weight_handle.wait() + + # start the all-gather for next module + module_index = self.fstp_modules.index(module) + if module_index - 1 >= 0: + next_module = self.fstp_modules[module_index - 1] + block_index = self.module_to_index[next_module] + weight_handle = all_gather_raw_memory_pool( + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(next_module, "_fstp_name"), + ) + self.fstp_global_handle[next_module] = weight_handle + + 
def _post_backward_hook_for_module(module, grad_input, grad_output): + if module in self.fstp_global_handle: + del self.fstp_global_handle[module] + + # register forward hooks + # 1. register post_forward_hook @embedding module to prefetch for block 0 + # 2. register pre_forward_hook @out_proj module to prefetch for next block, + # notice that next block's all_gather op should be after current block's all_to_all op + # 3. register pre_forward_hook @fstp_module to wait handle for current module + # 4. register post_forward_hook @fstp_module to release resource + for embedding in self.embedding: + embedding.register_forward_hook(_post_forward_hook_for_embedding) + + for out_proj in self.fstp_outs: + out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) + + for module in self.fstp_modules: + module.register_forward_pre_hook(_pre_forward_hook_for_module) + module.register_forward_hook(_post_forward_hook_for_module) + + # register backward hooks + # 1. register post_backward_hook @head module to prefetch for the last block's last module + # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module + # 3. register post_backward_hook @fstp_module to release resource + for head in self.head: + head.register_full_backward_hook(_post_backward_hook_for_head) + + for module in self.fstp_modules: + module.register_full_backward_pre_hook(_pre_backward_hook_for_module) + module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index b1894e9..ccdca48 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -135,7 +135,7 @@ def all_gather_raw_memory_pool( module_name: str = None, ): handle = torch.distributed.all_gather_into_tensor( - gpc.config.block_memory[block_index % 2][module_name], + gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name), input_.contiguous(), group=process_group, async_op=async_op, @@ -166,8 +166,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) - index = check_reduce_scatter_memory_pool(size) - output = gpc.config.reduce_scatter_memory[size]["data"][index] + index = gpc.config.fstp_handler.get_reduce_scatter_memory(size) + output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] setattr(output, "index", index) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op @@ -269,11 +269,11 @@ class FusedDenseFunc(torch.autograd.Function): class MegatronFusedDenseFunc(torch.autograd.Function): - ''' + """ FusedDenseFunc for tensor parallel in megatron implementation. The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron, so that the all-gather in backward is ommited. 
- ''' + """ @staticmethod @custom_fwd @@ -355,9 +355,10 @@ class MegatronFusedDenseFunc(torch.autograd.Function): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFuncTorch(FusedDenseFunc): - '''FusedDenseFunc in flash implementation for supporting torch.float32''' + """FusedDenseFunc in flash implementation for supporting torch.float32""" @staticmethod @custom_bwd @@ -407,8 +408,9 @@ class FusedDenseFuncTorch(FusedDenseFunc): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + class MegatronFusedDenseFuncTorch(FusedDenseFunc): - '''FusedDenseFunc in megatron implementation for supporting torch.float32''' + """FusedDenseFunc in megatron implementation for supporting torch.float32""" @staticmethod @custom_bwd @@ -452,6 +454,7 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + class FSTPFusedDenseFunc(torch.autograd.Function): "FusedDenseFunc for FSTP, which is optimized based on flash implementation." @@ -485,7 +488,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): if world_size > 1: # do all_gather for weight and bias before actual computation if overlap_handler is not None: - total_weight = gpc.config.block_memory[block_index % 2][module_name] + total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -544,7 +547,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: if overlap_handler is not None: - total_weight = gpc.config.block_memory[block_index % 2][module_name] + total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -559,17 +562,39 @@ class FSTPFusedDenseFunc(torch.autograd.Function): ) if world_size > 1: if overlap_handler is not None: - grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True) + grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool( + grad_weight, process_group, async_op=True + ) assert hasattr(weight, "_fstp_reduce_scatter_str") - overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) - grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) + overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = ( + handle_grad_weight, + grad_weight_async, + ) + grad_weight = overlap_handler.get_zero_by_shape( + ( + grad_weight.shape[0] // torch.distributed.get_world_size(process_group), + *grad_weight.shape[1:], + ), + dtype=grad_weight.dtype, + device=grad_weight.device, + ) if grad_bias is not None: grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool( grad_bias, process_group, async_op=True ) assert hasattr(bias, "_fstp_reduce_scatter_str") - overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) - grad_bias = 
overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) + overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = ( + handle_grad_bias, + grad_bias_async, + ) + grad_bias = overlap_handler.get_zero_by_shape( + ( + grad_bias.shape[0] // torch.distributed.get_world_size(process_group), + *grad_bias.shape[1:], + ), + dtype=grad_bias.dtype, + device=grad_bias.device, + ) else: grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) if grad_bias is not None: @@ -613,6 +638,7 @@ def fused_dense_func_torch( else: return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + def megatron_fused_dense_func_torch( x: Tensor, weight: Tensor, @@ -626,9 +652,14 @@ def megatron_fused_dense_func_torch( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + return MegatronFusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim + ) else: - return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + return MegatronFusedDenseFuncTorch.apply( + x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim + ) + def fstp_fused_dense_func( x: Tensor, @@ -693,38 +724,3 @@ def Silu(w1_o, w2_o): Silu = torch.jit.script(Silu) - - -def check_reduce_scatter_memory_pool(key): - return_idx = 0 - - # if key not in dict - if key not in gpc.config.reduce_scatter_memory: - gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []} - - # if the data is empty - if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0: - gpc.config.reduce_scatter_memory[key]["data"].append( - torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() - ) - gpc.config.reduce_scatter_memory[key]["used"].append(True) - return_idx = 0 - return return_idx - else: # if not empty - for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]): - if used is False: - gpc.config.reduce_scatter_memory[key]["used"][index] = True - return_idx = index - return return_idx - # if the memory pool is all used - length = len(gpc.config.reduce_scatter_memory[key]["data"]) - gpc.config.reduce_scatter_memory[key]["data"].append( - torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() - ) - gpc.config.reduce_scatter_memory[key]["used"].append(True) - return_idx = length - return return_idx - - -def release_reduce_scatter_memory_pool(size, index): - gpc.config.reduce_scatter_memory[size]["used"][index] = False diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 0f536ec..e2ec7ef 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,7 +11,6 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import release_reduce_scatter_memory_pool from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ 
-41,6 +40,7 @@ from .utils import compute_norm inf = math.inf logger = get_logger(__file__) + class HybridZeroOptimizer(BaseOptimizer): """ Hybrid Zero Optimizer. @@ -65,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer): backoff_factor = grad_scal_cfg.backoff_factor hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - + self._fstp_handler = None if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: self._fstp_handler = gpc.config.fstp_handler @@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. - release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) + gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) @@ -635,9 +635,9 @@ class HybridZeroOptimizer(BaseOptimizer): timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - - res = self._step(closure=closure, norms=total_norms) - + + res = self._step(closure=closure, norms=total_norms) + return res def _step(self, closure=None, norms=None): diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 53996b3..cabb7eb 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -36,12 +36,12 @@ from internlm.data.packed_dataset import ( from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data from internlm.model.embedding import Embedding1D from internlm.model.linear import ( - CoarseGrainedFSTPAllGatherSyncHandler, FeedForward, RewardModelLinear, ScaleColumnParallelLinear, ) from internlm.model.multi_head_attention import MHA +from internlm.model.overlap_handler import FSTPOverlapHandler from internlm.model.utils import try_import_RMSNorm from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor.monitor import monitor_manager as mm @@ -109,60 +109,8 @@ def initialize_model(): model = wrap_FSDP_model(model) gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: - handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - handler._register_sync_parameters_hook() - gpc.config.fstp_handler = handler - - # allocate memory pool - block_memory = {} # containing two groups of block weight - hidden_size = gpc.config.HIDDEN_SIZE - mlp_ratio = gpc.config.MLP_RATIO - mlp_hidden_size = int(hidden_size * mlp_ratio) - mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) - world_size = gpc.get_world_size(ParallelMode.TENSOR) - size_key = [ - (3 * hidden_size // world_size, hidden_size), - (mlp_hidden_size // world_size, hidden_size), - (hidden_size // world_size, mlp_hidden_size), - (hidden_size // world_size, hidden_size), - ] - module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - for i in range(2): - weight = {} - for name in module_name: - if name == "Wqkv": - weight[name] = torch.zeros( - (3 * hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "out_proj": - weight[name] = torch.zeros( - (hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "w1" or name == "w2": - weight[name] = torch.zeros( - (mlp_hidden_size, hidden_size), - 
dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - else: - weight[name] = torch.zeros( - (hidden_size, mlp_hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - block_memory[i] = weight - reduce_scatter_memory = {} - for key in size_key: - reduce_scatter_memory[key] = {"data": [], "used": []} - - gpc.config.block_memory = block_memory - gpc.config.reduce_scatter_memory = reduce_scatter_memory + gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR)) return model diff --git a/train.py b/train.py index 02f2802..5066960 100644 --- a/train.py +++ b/train.py @@ -299,7 +299,7 @@ def main(args): if gpc.config.fstp_handler is not None: gpc.config.fstp_handler.zero_const_pool = {} - gpc.config.fstp_handler.reduce_scatter_memory = {} + gpc.config.fstp_handler.reduce_scatter_memory_pool = {} # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats()
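
For illustration (not part of the patch): the forward-path prefetch that FSTPOverlapHandler._register_sync_parameters_hook wires up can be reduced to a minimal, single-process sketch. All names below (ToyBlock, ToyOverlapHandler, DummyHandle) are illustrative stand-ins, and prints replace the real async torch.distributed.all_gather_into_tensor into the shared memory pool; the sketch only demonstrates the hook ordering the patch relies on.

# --- illustrative sketch, not part of the patch --------------------------------
# Hook wiring mirrored from FSTPOverlapHandler._register_sync_parameters_hook:
#   * embedding post-forward hook  -> prefetch weights of block 0
#   * out_proj pre-forward hook    -> prefetch weights of block i + 1
#   * per-module pre-forward hook  -> wait on that module's in-flight handle
#   * per-module post-forward hook -> drop the handle to release the slot
# DummyHandle stands in for the async work object returned by
# torch.distributed.all_gather_into_tensor(..., async_op=True).

import torch
from torch import nn


class DummyHandle:
    def __init__(self, name):
        self.name = name

    def wait(self):
        print(f"wait  : {self.name}")


class ToyBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.Wqkv = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.out_proj(self.Wqkv(x))


class ToyOverlapHandler:
    def __init__(self, embedding: nn.Module, blocks: nn.ModuleList):
        self.blocks = blocks
        self.handles = {}  # module -> in-flight "all-gather" handle

        embedding.register_forward_hook(self._post_forward_embedding)
        for idx, block in enumerate(blocks):
            # register the prefetch hook on out_proj before the generic wait hook,
            # matching the registration order in the patch
            block.out_proj.register_forward_pre_hook(self._make_prefetch_next(idx))
            for _, child in block.named_children():
                child.register_forward_pre_hook(self._pre_forward_wait)
                child.register_forward_hook(self._post_forward_release)

    def _prefetch(self, block_index):
        print(f"launch: all-gather weights of block {block_index}")
        for name, child in self.blocks[block_index].named_children():
            self.handles[child] = DummyHandle(f"block {block_index} {name}")

    def _post_forward_embedding(self, module, inputs, output):
        self._prefetch(0)

    def _make_prefetch_next(self, idx):
        def hook(module, inputs):
            if idx + 1 < len(self.blocks):
                self._prefetch(idx + 1)
        return hook

    def _pre_forward_wait(self, module, inputs):
        handle = self.handles.get(module)
        if handle is not None:
            handle.wait()

    def _post_forward_release(self, module, inputs, output):
        self.handles.pop(module, None)


embedding = nn.Embedding(16, 8)
blocks = nn.ModuleList([ToyBlock(), ToyBlock()])
handler = ToyOverlapHandler(embedding, blocks)

x = embedding(torch.arange(4))
for block in blocks:
    x = block(x)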
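
For illustration (not part of the patch): the reduce_scatter_memory_pool logic in get_reduce_scatter_memory / release_reduce_scatter_memory, extracted as a standalone class. Buffers are keyed by gradient-shard shape, marked "used" when handed out, and recycled once HybridZeroOptimizer has accumulated the shard and calls release. ReduceScatterMemoryPool and the demo at the bottom are illustrative names only.

# --- illustrative sketch, not part of the patch --------------------------------
import torch


class ReduceScatterMemoryPool:
    def __init__(self, dtype=torch.float16, device="cpu"):
        self.dtype = dtype
        self.device = device
        self.pool = {}  # key: shape tuple; value: {"data": [tensor, ...], "used": [bool, ...]}

    def get(self, size: tuple) -> int:
        bucket = self.pool.setdefault(size, {"data": [], "used": []})
        # reuse the first free buffer of this shape, if any
        for index, used in enumerate(bucket["used"]):
            if not used:
                bucket["used"][index] = True
                return index
        # otherwise grow the pool by one buffer of this shape
        bucket["data"].append(torch.zeros(size, dtype=self.dtype, device=self.device).contiguous())
        bucket["used"].append(True)
        return len(bucket["data"]) - 1

    def release(self, size: tuple, index: int) -> None:
        self.pool[size]["used"][index] = False


if __name__ == "__main__":
    pool = ReduceScatterMemoryPool()
    size = (1024, 4096)
    first = pool.get(size)       # allocates a new buffer
    pool.release(size, first)    # optimizer finished accumulating this shard
    second = pool.get(size)      # reuses the freed buffer instead of allocating
    assert first == second == 0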