From 85ad917ae430c2e89cf4444221c2ced9223d3552 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 21:50:32 +0800 Subject: [PATCH 1/8] feat(model/overlap_handler.py): refactor overlap hook handle --- configs/7B_sft.py | 2 +- internlm/model/linear.py | 296 +++++------------- internlm/model/modeling_internlm.py | 11 +- internlm/model/multi_head_attention.py | 11 +- internlm/model/overlap_handler.py | 253 +++++++++++++++ internlm/model/utils.py | 98 +++--- .../solver/optimizer/hybrid_zero_optim.py | 12 +- internlm/train/training_internlm.py | 56 +--- train.py | 2 +- 9 files changed, 392 insertions(+), 349 deletions(-) create mode 100644 internlm/model/overlap_handler.py diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 09af7f4..c51c812 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -163,7 +163,7 @@ pipeline parallel (dict): """ parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=8, sp="megatron", intern_overlap=True), + tensor=dict(size=8, sp="intern", intern_overlap=True), pipeline=dict(size=1, interleaved_overlap=True), ) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 2bbb941..6cd3b9c 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -1,22 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Optional, Union +from typing import Optional import torch -import torch.nn.functional as F from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.embedding import Embedding1D from internlm.model.utils import ( Silu, - all_gather_raw, - all_gather_raw_memory_pool, fstp_fused_dense_func, fused_dense_func_torch, megatron_fused_dense_func_torch, @@ -25,20 +20,20 @@ from internlm.model.utils import ( class BaseScaleColumnParallelLinear(nn.Linear): """ - Base class for ScaleColumnParallelLinear. + Base class for ScaleColumnParallelLinear. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + If not, then the input is already gathered. + device (Optional[Union[str, torch.device]]): The device will be used. 
+ dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. """ def __init__( @@ -58,10 +53,12 @@ class BaseScaleColumnParallelLinear(nn.Linear): self.process_group = process_group self.weight_scale = weight_scale + class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): """ ScaleColumnParallelLinear in flash implementation. """ + def forward(self, input, gather_dim=0): # pylint: disable=W0622 # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. @@ -79,6 +76,7 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): gather_dim=gather_dim, ) + class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear): """ ScaleColumnParallelLinear in megatron implementation. @@ -101,6 +99,7 @@ class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear): gather_dim=gather_dim, ) + class RewardModelLinear(ScaleColumnParallelLinear): """ RewardModelLinear. @@ -164,6 +163,7 @@ class ColumnParallelLinearTorch(ColumnParallelLinear): gather_dim=gather_dim, ) + class MegatronColumnParallelLinearTorch(ColumnParallelLinear): def forward(self, x, gather_dim=0): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: @@ -178,6 +178,7 @@ class MegatronColumnParallelLinearTorch(ColumnParallelLinear): gather_dim=gather_dim, ) + class RowParallelLinearTorch(RowParallelLinear): def forward(self, x): """ @@ -188,6 +189,7 @@ class RowParallelLinearTorch(RowParallelLinear): reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) + class MegatronRowParallelLinearTorch(RowParallelLinear): def forward(self, x): """ @@ -225,8 +227,8 @@ class BaseFeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - colum_cls = None, - row_cls = None, + colum_cls=None, + row_cls=None, ): super().__init__() hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) @@ -265,6 +267,7 @@ class BaseFeedForward(nn.Module): out = self.w3(Silu(w1_o, w2_o)) return out + class FeedForward(BaseFeedForward): """ FeedForward in flash implementation. 
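For context on the BaseFeedForward refactor above: the column/row parallel linear classes are now injected through the colum_cls/row_cls constructor arguments, so FeedForward, MegatronFeedForward, and FSTPFeedForward (further below) differ only in which classes they pass. A minimal, framework-free sketch of that composition pattern, using plain nn.Linear stand-ins rather than the repository's parallel linear classes (the two-argument Silu is assumed to be silu(w1_o) * w2_o, as in internlm/model/utils.py):

import torch
from torch import nn
import torch.nn.functional as F


def silu_gate(w1_o: torch.Tensor, w2_o: torch.Tensor) -> torch.Tensor:
    # Assumed equivalent of the two-argument Silu used by BaseFeedForward: silu(w1_o) * w2_o.
    return F.silu(w1_o) * w2_o


class ToyBaseFeedForward(nn.Module):
    # Gated MLP whose projection classes are injected, mirroring colum_cls/row_cls.
    def __init__(self, in_features, hidden_features, out_features,
                 multiple_of=256, colum_cls=nn.Linear, row_cls=nn.Linear):
        super().__init__()
        # Round the hidden size up to a multiple_of boundary, as BaseFeedForward does.
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
        self.w1 = colum_cls(in_features, hidden_features, bias=False)  # column-parallel in the real code
        self.w2 = colum_cls(in_features, hidden_features, bias=False)  # column-parallel in the real code
        self.w3 = row_cls(hidden_features, out_features, bias=False)   # row-parallel in the real code

    def forward(self, x):
        return self.w3(silu_gate(self.w1(x), self.w2(x)))

In the actual classes, FeedForward passes ColumnParallelLinearTorch/RowParallelLinearTorch, MegatronFeedForward passes the Megatron variants, and FSTPFeedForward passes FSTPLinear for both projections, matching get_mlp_cls.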
@@ -292,9 +295,19 @@ class FeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch) - + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + ColumnParallelLinearTorch, + RowParallelLinearTorch, + ) + class MegatronFeedForward(BaseFeedForward): """ @@ -323,19 +336,35 @@ class MegatronFeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch) + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + MegatronColumnParallelLinearTorch, + MegatronRowParallelLinearTorch, + ) + class FSTPLinear(ColumnParallelLinear): def forward(self, x): block_index = gpc.config.fstp_handler.module_to_index[self] - name_index = gpc.config.fstp_handler.module_name_index[self] - name = gpc.config.fstp_handler.module_name[name_index] return fstp_fused_dense_func( - x, self.weight, self.bias, process_group=self.process_group, - module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name + x, + self.weight, + self.bias, + process_group=self.process_group, + module=self, + handler=gpc.config.fstp_handler, + block_index=block_index, + module_name=self._fstp_name, ) + class FSTPFeedForward(BaseFeedForward): """ FeedForward in FSTP. @@ -363,8 +392,19 @@ class FSTPFeedForward(BaseFeedForward): dtype: Optional[torch.dtype] = None, multiple_of: int = 256, ): - super().__init__(in_features, hidden_features, out_features, process_group, bias, device, - dtype, multiple_of, FSTPLinear, FSTPLinear) + super().__init__( + in_features, + hidden_features, + out_features, + process_group, + bias, + device, + dtype, + multiple_of, + FSTPLinear, + FSTPLinear, + ) + def get_mlp_cls(sp_mode: str): if sp_mode in ["none", "flash-attn"]: @@ -375,6 +415,7 @@ def get_mlp_cls(sp_mode: str): mlp_cls = FSTPFeedForward return mlp_cls + def get_linear_cls(sp_mode: str, parallel_mode: str): if parallel_mode == "column": if sp_mode in ["none", "flash-attn"]: @@ -383,7 +424,7 @@ def get_linear_cls(sp_mode: str, parallel_mode: str): cls = MegatronColumnParallelLinearTorch else: cls = FSTPLinear - elif parallel_mode == 'row': + elif parallel_mode == "row": if sp_mode in ["none", "flash-attn"]: cls = RowParallelLinearTorch elif sp_mode == "megatron": @@ -391,192 +432,3 @@ def get_linear_cls(sp_mode: str, parallel_mode: str): else: cls = FSTPLinear return cls - -class CoarseGrainedFSTPAllGatherSyncHandler: - """ - All-gather handler for overlapping the all-gather in adjcent FSTP block. 
- """ - - def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: - # import pdb; pdb.set_trace() - self.process_group = process_group - self.FSTP_blocks = [] - self.FSTP_outs = [] - self.FSTP_modules = [] - self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle - self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward - self.block_handles = dict() # key: transformer block; value: all-gather handles - self.module_to_index = dict() # key: FSTP module; value: transformer block index - self.block_to_index = dict() # key: transformer block; value: transformer block index - self.index_to_block = dict() # key: transformer block index; value: transformer block - self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules - self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name - self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} - self.head = [] - self.embedding = [] - - self.reduce_scatter_handlers = {} - self.all_reduce_handlers = {} - self.zero_const_pool = {} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _chunk_name, children in _chunk.named_children(): - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - index = 0 - self.block_module[idx] = {} - self.FSTP_blocks.append(block) - self.block_to_index[block] = idx - self.index_to_block[idx] = block - self.index_to_fsdp_modules[idx] = [] - for _sub_name, sub in block.named_children(): - sub_modules = list(sub.children()) - if len(sub_modules) > 0: - for name, child in sub.named_children(): - if name == "out_proj": - self.FSTP_outs.append(child) - self.module_to_index[child] = idx - if isinstance(child, FSTPLinear): - self.module_to_index[child] = idx - self.block_module[idx][index] = child - self.FSTP_modules.append(child) - self.index_to_fsdp_modules[idx].append(child) - self.module_name_index[child] = index - index = index + 1 - - _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") - if child.bias is not None: - setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - else: - continue - elif isinstance(children, ScaleColumnParallelLinear): - self.head.append(children) - elif isinstance(children, Embedding1D): - self.embedding.append(children) - - def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: - if size not in self.zero_const_pool: - self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() - - return self.zero_const_pool[size] - - def _all_gather_block_weight_memory_pool(self, block_index: int): - fsdp_modules = self.index_to_fsdp_modules[block_index] - for module in fsdp_modules: - module_index = self.module_name_index[module] - name = self.module_name[module_index] - weight_handle = all_gather_raw_memory_pool( - module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name - ) - self.FSTP_global_handle[module] = weight_handle - - def _register_sync_parameters_hook(self) -> None: - """ - register pre_forward_hook and pre_backward_hook for FSTP block. 
- - Notice that next block's all_gather op should be after current block's all_to_all op, so we - 1. register pre_forward_hook @out_proj module to prefetch for next block - 2. register pre_forward_hook @block module to wait handles for next block - 3. register pre_backward_hook @wqkv module to prefetch for next block - 4. register pre_backward_hook @block module to wait handles for next block - """ - - def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): - block_index = self.module_to_index[module] - # start the all-gather for next block - if block_index + 1 < gpc.config.NUM_LAYER: - self._all_gather_block_weight_memory_pool(block_index + 1) - - def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): - self._all_gather_block_weight_memory_pool(0) - - def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): - handle = self.FSTP_global_handle[module] - handle.wait() - - def _post_forward_hook_for_module(module: nn.Module, input, output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.FSTP_global_handle: - del self.FSTP_global_handle[module] - - def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): - first_module = self.block_module[gpc.config.NUM_LAYER - 1][4] - total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True) - self.FSTP_global_handle[first_module] = weight_handler - self.FSTP_global_weights[first_module] = total_weight - - def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output): - block_index = self.module_to_index[module] - name_index = self.module_name_index[module] - - if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - weight_handler = self.FSTP_global_handle[module] - weight_handler.wait() - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - next_name = self.module_name[name_index - 1] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, - self.process_group, - async_op=True, - block_index=block_index, - module_name=next_name, - ) - self.FSTP_global_handle[next_module] = weights_handler - elif name_index == 0: - handler = self.FSTP_global_handle[module] - handler.wait() - - if block_index - 1 >= 0: - next_module = self.block_module[block_index - 1][4] - name = self.module_name[4] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, - self.process_group, - async_op=True, - block_index=block_index - 1, - module_name=name, - ) - self.FSTP_global_handle[next_module] = weights_handler - else: - handler = self.FSTP_global_handle[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - name = self.module_name[name_index - 1] - weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name - ) - self.FSTP_global_handle[next_module] = weights_handler - - def _post_backward_hook_for_module(module, grad_input, grad_output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.FSTP_global_handle: - del self.FSTP_global_handle[module] - - for embedding in self.embedding: - embedding.register_forward_hook(_post_forward_hook_for_embedding) - - for head in self.head: - head.register_full_backward_hook(_post_backward_hook_for_head) - - for out_proj in self.FSTP_outs: - 
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - - for module in self.FSTP_modules: - module.register_forward_pre_hook(_pre_forward_hook_for_module) - module.register_forward_hook(_post_forward_hook_for_module) - module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool) - module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 3ed78d7..228e1e1 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -14,12 +14,9 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.embedding import Embedding1D from internlm.model.linear import ( - FeedForward, - MegatronFeedForward, - FSTPFeedForward, + MegatronScaleColumnParallelLinear, RewardModelLinear, ScaleColumnParallelLinear, - MegatronScaleColumnParallelLinear, get_mlp_cls, ) from internlm.model.multi_head_attention import MHA @@ -309,7 +306,11 @@ class PackedFlashInternLm1D(nn.Module): if is_reward: head_cls = RewardModelLinear else: - head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear + head_cls = ( + ScaleColumnParallelLinear + if self.sp_mode in ["flash-attn", "none", "intern"] + else MegatronScaleColumnParallelLinear + ) if first: if embed_split_hidden: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 8ba49ed..93dbf01 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -38,14 +38,7 @@ from torch.nn import Module from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding -from internlm.model.linear import ( - ColumnParallelLinearTorch, - FSTPLinear, - RowParallelLinearTorch, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - get_linear_cls, -) +from internlm.model.linear import get_linear_cls # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py @@ -227,7 +220,7 @@ class MHA(nn.Module): self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(sp_mode, 'row') + out_proj_cls = get_linear_cls(sp_mode, "row") self.out_proj = out_proj_cls( embed_dim, embed_dim, diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py new file mode 100644 index 0000000..cafb818 --- /dev/null +++ b/internlm/model/overlap_handler.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Any, Union + +import torch +from torch import nn + +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import NaiveAMPModel +from internlm.model.embedding import Embedding1D +from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear +from internlm.model.utils import all_gather_raw_memory_pool +from internlm.utils.common import get_current_device + + +class FSTPOverlapHandler: + """ + FSTP overlap handler for managing the all-gather and reduce_scatter overlapping. 
+ """ + + def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: + self.process_group = process_group + self.fstp_outs = [] + self.fstp_modules = [] + self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] + self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle + self.module_to_index = dict() # key: fstp module; value: transformer block index + self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules + self.head = [] + self.embedding = [] + + self.reduce_scatter_handlers = {} + self.zero_const_pool = {} + + # just want to share same for loop for ModuleList and Module + if not isinstance(model, nn.ModuleList): + model = [model] + + for _chunk in model: + if isinstance(_chunk, NaiveAMPModel): + _chunk = _chunk.model + + for _chunk_name, children in _chunk.named_children(): + if isinstance(children, ScaleColumnParallelLinear): + self.head.append(children) + elif isinstance(children, Embedding1D): + self.embedding.append(children) + elif isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + self.index_to_fstp_modules[idx] = [] + for _sub_name, sub in block.named_children(): + sub_modules = list(sub.children()) + if len(sub_modules) > 0: + for name, child in sub.named_children(): + if name == "out_proj": + self.fstp_outs.append(child) + self.module_to_index[child] = idx + if isinstance(child, FSTPLinear): + self.module_to_index[child] = idx + self.fstp_modules.append(child) + self.index_to_fstp_modules[idx].append(child) + + setattr(child, "_fstp_name", name) + + _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" + setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") + if child.bias is not None: + setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") + + self._initialize_memory_pool() + self._register_sync_parameters_hook() + + def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: + if size not in self.zero_const_pool: + self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() + + return self.zero_const_pool[size] + + def _initialize_memory_pool(self) -> None: + # allocate memory pool + hidden_size = gpc.config.HIDDEN_SIZE + mlp_ratio = gpc.config.MLP_RATIO + mlp_hidden_size = int(hidden_size * mlp_ratio) + mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) + self.all_gather_memory_pool = [] + self.reduce_scatter_memory_pool = {} + + for _ in range(2): + weight = {} + for name in self.module_name: + if name == "Wqkv": + weight[name] = torch.zeros( + (3 * hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "out_proj": + weight[name] = torch.zeros( + (hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "w1" or name == "w2": + weight[name] = torch.zeros( + (mlp_hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + else: + weight[name] = torch.zeros( + (hidden_size, mlp_hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + + self.all_gather_memory_pool.append(weight) # containing two groups of block weight + + def get_all_gather_memory(self, index, module_name): + return self.all_gather_memory_pool[index % 2][module_name] + + def get_reduce_scatter_memory(self, 
key): + return_idx = 0 + + # if key not in dict + if key not in self.reduce_scatter_memory_pool: + self.reduce_scatter_memory_pool[key] = {"data": [], "used": []} + + # if the data is empty + if len(self.reduce_scatter_memory_pool[key]["data"]) == 0: + self.reduce_scatter_memory_pool[key]["data"].append( + torch.zeros( + key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() + ).contiguous() + ) + self.reduce_scatter_memory_pool[key]["used"].append(True) + return_idx = 0 + return return_idx + else: # if not empty + for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]): + if used is False: + self.reduce_scatter_memory_pool[key]["used"][index] = True + return_idx = index + return return_idx + # if the memory pool is all used + length = len(self.reduce_scatter_memory_pool[key]["data"]) + self.reduce_scatter_memory_pool[key]["data"].append( + torch.zeros( + key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() + ).contiguous() + ) + self.reduce_scatter_memory_pool[key]["used"].append(True) + return_idx = length + return return_idx + + def release_reduce_scatter_memory(self, size, index): + self.reduce_scatter_memory_pool[size]["used"][index] = False + + def _all_gather_block_weight_memory_pool(self, block_index: int): + fstp_modules = self.index_to_fstp_modules[block_index] + for module in fstp_modules: + weight_handle = all_gather_raw_memory_pool( + module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(module, "_fstp_name"), + ) + self.fstp_global_handle[module] = weight_handle + + def _register_sync_parameters_hook(self) -> None: + """ + register forward hooks and backward hooks for fstp modules. + """ + + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): + self._all_gather_block_weight_memory_pool(0) + + def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): + block_index = self.module_to_index[module] + # start the all-gather for next block + if block_index + 1 < gpc.config.NUM_LAYER: + self._all_gather_block_weight_memory_pool(block_index + 1) + + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): + handle = self.fstp_global_handle[module] + handle.wait() + + def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): + if module in self.fstp_global_handle: + del self.fstp_global_handle[module] + + def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): + first_backward_module = self.fstp_modules[-1] + block_index = self.module_to_index[first_backward_module] + weight_handle = all_gather_raw_memory_pool( + first_backward_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(first_backward_module, "_fstp_name"), + ) + self.fstp_global_handle[first_backward_module] = weight_handle + + def _pre_backward_hook_for_module(module: nn.Module, grad_output): + # wait handle for current module + weight_handle = self.fstp_global_handle[module] + weight_handle.wait() + + # start the all-gather for next module + module_index = self.fstp_modules.index(module) + if module_index - 1 >= 0: + next_module = self.fstp_modules[module_index - 1] + block_index = self.module_to_index[next_module] + weight_handle = all_gather_raw_memory_pool( + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=getattr(next_module, "_fstp_name"), + ) + self.fstp_global_handle[next_module] = weight_handle + + 
def _post_backward_hook_for_module(module, grad_input, grad_output): + if module in self.fstp_global_handle: + del self.fstp_global_handle[module] + + # register forward hooks + # 1. register post_forward_hook @embedding module to prefetch for block 0 + # 2. register pre_forward_hook @out_proj module to prefetch for next block, + # notice that next block's all_gather op should be after current block's all_to_all op + # 3. register pre_forward_hook @fstp_module to wait handle for current module + # 4. register post_forward_hook @fstp_module to release resource + for embedding in self.embedding: + embedding.register_forward_hook(_post_forward_hook_for_embedding) + + for out_proj in self.fstp_outs: + out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) + + for module in self.fstp_modules: + module.register_forward_pre_hook(_pre_forward_hook_for_module) + module.register_forward_hook(_post_forward_hook_for_module) + + # register backward hooks + # 1. register post_backward_hook @head module to prefetch for the last block's last module + # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module + # 3. register post_backward_hook @fstp_module to release resource + for head in self.head: + head.register_full_backward_hook(_post_backward_hook_for_head) + + for module in self.fstp_modules: + module.register_full_backward_pre_hook(_pre_backward_hook_for_module) + module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index b1894e9..ccdca48 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -135,7 +135,7 @@ def all_gather_raw_memory_pool( module_name: str = None, ): handle = torch.distributed.all_gather_into_tensor( - gpc.config.block_memory[block_index % 2][module_name], + gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name), input_.contiguous(), group=process_group, async_op=async_op, @@ -166,8 +166,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) - index = check_reduce_scatter_memory_pool(size) - output = gpc.config.reduce_scatter_memory[size]["data"][index] + index = gpc.config.fstp_handler.get_reduce_scatter_memory(size) + output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] setattr(output, "index", index) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op @@ -269,11 +269,11 @@ class FusedDenseFunc(torch.autograd.Function): class MegatronFusedDenseFunc(torch.autograd.Function): - ''' + """ FusedDenseFunc for tensor parallel in megatron implementation. The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron, so that the all-gather in backward is ommited. 
- ''' + """ @staticmethod @custom_fwd @@ -355,9 +355,10 @@ class MegatronFusedDenseFunc(torch.autograd.Function): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py class FusedDenseFuncTorch(FusedDenseFunc): - '''FusedDenseFunc in flash implementation for supporting torch.float32''' + """FusedDenseFunc in flash implementation for supporting torch.float32""" @staticmethod @custom_bwd @@ -407,8 +408,9 @@ class FusedDenseFuncTorch(FusedDenseFunc): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + class MegatronFusedDenseFuncTorch(FusedDenseFunc): - '''FusedDenseFunc in megatron implementation for supporting torch.float32''' + """FusedDenseFunc in megatron implementation for supporting torch.float32""" @staticmethod @custom_bwd @@ -452,6 +454,7 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc): handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None, None + class FSTPFusedDenseFunc(torch.autograd.Function): "FusedDenseFunc for FSTP, which is optimized based on flash implementation." @@ -485,7 +488,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): if world_size > 1: # do all_gather for weight and bias before actual computation if overlap_handler is not None: - total_weight = gpc.config.block_memory[block_index % 2][module_name] + total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -544,7 +547,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: if overlap_handler is not None: - total_weight = gpc.config.block_memory[block_index % 2][module_name] + total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -559,17 +562,39 @@ class FSTPFusedDenseFunc(torch.autograd.Function): ) if world_size > 1: if overlap_handler is not None: - grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True) + grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool( + grad_weight, process_group, async_op=True + ) assert hasattr(weight, "_fstp_reduce_scatter_str") - overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) - grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) + overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = ( + handle_grad_weight, + grad_weight_async, + ) + grad_weight = overlap_handler.get_zero_by_shape( + ( + grad_weight.shape[0] // torch.distributed.get_world_size(process_group), + *grad_weight.shape[1:], + ), + dtype=grad_weight.dtype, + device=grad_weight.device, + ) if grad_bias is not None: grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool( grad_bias, process_group, async_op=True ) assert hasattr(bias, "_fstp_reduce_scatter_str") - overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) - grad_bias = 
overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) + overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = ( + handle_grad_bias, + grad_bias_async, + ) + grad_bias = overlap_handler.get_zero_by_shape( + ( + grad_bias.shape[0] // torch.distributed.get_world_size(process_group), + *grad_bias.shape[1:], + ), + dtype=grad_bias.dtype, + device=grad_bias.device, + ) else: grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) if grad_bias is not None: @@ -613,6 +638,7 @@ def fused_dense_func_torch( else: return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + def megatron_fused_dense_func_torch( x: Tensor, weight: Tensor, @@ -626,9 +652,14 @@ def megatron_fused_dense_func_torch( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + return MegatronFusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim + ) else: - return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) + return MegatronFusedDenseFuncTorch.apply( + x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim + ) + def fstp_fused_dense_func( x: Tensor, @@ -693,38 +724,3 @@ def Silu(w1_o, w2_o): Silu = torch.jit.script(Silu) - - -def check_reduce_scatter_memory_pool(key): - return_idx = 0 - - # if key not in dict - if key not in gpc.config.reduce_scatter_memory: - gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []} - - # if the data is empty - if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0: - gpc.config.reduce_scatter_memory[key]["data"].append( - torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() - ) - gpc.config.reduce_scatter_memory[key]["used"].append(True) - return_idx = 0 - return return_idx - else: # if not empty - for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]): - if used is False: - gpc.config.reduce_scatter_memory[key]["used"][index] = True - return_idx = index - return return_idx - # if the memory pool is all used - length = len(gpc.config.reduce_scatter_memory[key]["data"]) - gpc.config.reduce_scatter_memory[key]["data"].append( - torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() - ) - gpc.config.reduce_scatter_memory[key]["used"].append(True) - return_idx = length - return return_idx - - -def release_reduce_scatter_memory_pool(size, index): - gpc.config.reduce_scatter_memory[size]["used"][index] = False diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 0f536ec..e2ec7ef 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,7 +11,6 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import release_reduce_scatter_memory_pool from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ 
-41,6 +40,7 @@ from .utils import compute_norm inf = math.inf logger = get_logger(__file__) + class HybridZeroOptimizer(BaseOptimizer): """ Hybrid Zero Optimizer. @@ -65,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer): backoff_factor = grad_scal_cfg.backoff_factor hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - + self._fstp_handler = None if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: self._fstp_handler = gpc.config.fstp_handler @@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. - release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) + gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) @@ -635,9 +635,9 @@ class HybridZeroOptimizer(BaseOptimizer): timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - - res = self._step(closure=closure, norms=total_norms) - + + res = self._step(closure=closure, norms=total_norms) + return res def _step(self, closure=None, norms=None): diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 53996b3..cabb7eb 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -36,12 +36,12 @@ from internlm.data.packed_dataset import ( from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data from internlm.model.embedding import Embedding1D from internlm.model.linear import ( - CoarseGrainedFSTPAllGatherSyncHandler, FeedForward, RewardModelLinear, ScaleColumnParallelLinear, ) from internlm.model.multi_head_attention import MHA +from internlm.model.overlap_handler import FSTPOverlapHandler from internlm.model.utils import try_import_RMSNorm from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor.monitor import monitor_manager as mm @@ -109,60 +109,8 @@ def initialize_model(): model = wrap_FSDP_model(model) gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: - handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - handler._register_sync_parameters_hook() - gpc.config.fstp_handler = handler - - # allocate memory pool - block_memory = {} # containing two groups of block weight - hidden_size = gpc.config.HIDDEN_SIZE - mlp_ratio = gpc.config.MLP_RATIO - mlp_hidden_size = int(hidden_size * mlp_ratio) - mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) - world_size = gpc.get_world_size(ParallelMode.TENSOR) - size_key = [ - (3 * hidden_size // world_size, hidden_size), - (mlp_hidden_size // world_size, hidden_size), - (hidden_size // world_size, mlp_hidden_size), - (hidden_size // world_size, hidden_size), - ] - module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - for i in range(2): - weight = {} - for name in module_name: - if name == "Wqkv": - weight[name] = torch.zeros( - (3 * hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "out_proj": - weight[name] = torch.zeros( - (hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "w1" or name == "w2": - weight[name] = torch.zeros( - (mlp_hidden_size, hidden_size), - 
dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - else: - weight[name] = torch.zeros( - (hidden_size, mlp_hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - block_memory[i] = weight - reduce_scatter_memory = {} - for key in size_key: - reduce_scatter_memory[key] = {"data": [], "used": []} - - gpc.config.block_memory = block_memory - gpc.config.reduce_scatter_memory = reduce_scatter_memory + gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR)) return model diff --git a/train.py b/train.py index 02f2802..5066960 100644 --- a/train.py +++ b/train.py @@ -299,7 +299,7 @@ def main(args): if gpc.config.fstp_handler is not None: gpc.config.fstp_handler.zero_const_pool = {} - gpc.config.fstp_handler.reduce_scatter_memory = {} + gpc.config.fstp_handler.reduce_scatter_memory_pool = {} # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats() From b20f47a1fe5fb446f2d9df5a83b31cb6033579f0 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 12:02:32 +0800 Subject: [PATCH 2/8] feat(model/overlap_handler.py): move handler to gpc --- internlm/model/linear.py | 5 +--- internlm/model/overlap_handler.py | 16 ++++------ internlm/model/utils.py | 29 ++++++------------- .../solver/optimizer/hybrid_zero_optim.py | 4 +-- internlm/train/training_internlm.py | 4 +-- train.py | 6 ++-- 6 files changed, 23 insertions(+), 41 deletions(-) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 6cd3b9c..b92b2ee 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -352,16 +352,13 @@ class MegatronFeedForward(BaseFeedForward): class FSTPLinear(ColumnParallelLinear): def forward(self, x): - block_index = gpc.config.fstp_handler.module_to_index[self] return fstp_fused_dense_func( x, self.weight, self.bias, process_group=self.process_group, module=self, - handler=gpc.config.fstp_handler, - block_index=block_index, - module_name=self._fstp_name, + handler=gpc.fstp_handler, ) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index cafb818..b687723 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -116,8 +116,9 @@ class FSTPOverlapHandler: self.all_gather_memory_pool.append(weight) # containing two groups of block weight - def get_all_gather_memory(self, index, module_name): - return self.all_gather_memory_pool[index % 2][module_name] + def get_all_gather_memory(self, module): + block_index = self.module_to_index[module] + return self.all_gather_memory_pool[block_index % 2][module._fstp_name] def get_reduce_scatter_memory(self, key): return_idx = 0 @@ -163,8 +164,7 @@ class FSTPOverlapHandler: module.weight, self.process_group, async_op=True, - block_index=block_index, - module_name=getattr(module, "_fstp_name"), + module=module, ) self.fstp_global_handle[module] = weight_handle @@ -192,13 +192,11 @@ class FSTPOverlapHandler: def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): first_backward_module = self.fstp_modules[-1] - block_index = self.module_to_index[first_backward_module] weight_handle = all_gather_raw_memory_pool( first_backward_module.weight, self.process_group, async_op=True, - block_index=block_index, - module_name=getattr(first_backward_module, "_fstp_name"), + module=first_backward_module, ) self.fstp_global_handle[first_backward_module] = 
weight_handle @@ -211,13 +209,11 @@ class FSTPOverlapHandler: module_index = self.fstp_modules.index(module) if module_index - 1 >= 0: next_module = self.fstp_modules[module_index - 1] - block_index = self.module_to_index[next_module] weight_handle = all_gather_raw_memory_pool( next_module.weight, self.process_group, async_op=True, - block_index=block_index, - module_name=getattr(next_module, "_fstp_name"), + module=next_module, ) self.fstp_global_handle[next_module] = weight_handle diff --git a/internlm/model/utils.py b/internlm/model/utils.py index ccdca48..cdbed95 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -7,13 +7,12 @@ import fused_dense_lib as fused_dense_cuda import torch import torch.nn.functional as F from flash_attn.utils.distributed import all_reduce_raw -from torch import Tensor +from torch import Tensor, nn from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -131,11 +130,10 @@ def all_gather_raw_memory_pool( process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0, - block_index: int = None, - module_name: str = None, + module: nn.Module = None, ): handle = torch.distributed.all_gather_into_tensor( - gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name), + gpc.fstp_handler.get_all_gather_memory(module=module), input_.contiguous(), group=process_group, async_op=async_op, @@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) - index = gpc.config.fstp_handler.get_reduce_scatter_memory(size) - output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] + index = gpc.fstp_handler.get_reduce_scatter_memory(size) + output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] setattr(output, "index", index) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op @@ -469,16 +467,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function): process_group=None, module=None, overlap_handler=None, - block_index=None, - module_name=None, ): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group ctx.overlap_handler = overlap_handler ctx.module = module - ctx.block_index = block_index - ctx.module_name = module_name if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) @@ -488,7 +482,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): if world_size > 1: # do all_gather for weight and bias before actual computation if overlap_handler is not None: - total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) + total_weight = gpc.fstp_handler.get_all_gather_memory(module=module) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -531,8 +525,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): grad_input = grad_input.contiguous() process_group = ctx.process_group overlap_handler = ctx.overlap_handler - block_index = ctx.block_index - module_name = ctx.module_name + module = ctx.module if 
ctx.compute_weight_gradient: x, weight, bias = ctx.saved_tensors @@ -547,7 +540,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: if overlap_handler is not None: - total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) + total_weight = gpc.fstp_handler.get_all_gather_memory(module=module) else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -669,16 +662,12 @@ def fstp_fused_dense_func( process_group=None, module=None, handler=None, - block_index=None, - module_name=None, ): dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FSTPFusedDenseFunc.apply( - x, weight, bias, return_residual, process_group, module, handler, block_index, module_name - ) + return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler) else: assert process_group is None out = F.linear(x, weight, bias) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index e2ec7ef..08d9722 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -68,7 +68,7 @@ class HybridZeroOptimizer(BaseOptimizer): self._fstp_handler = None if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: - self._fstp_handler = gpc.config.fstp_handler + self._fstp_handler = gpc.fstp_handler # Zero related args reduce_bucket_size = zero_cfg.reduce_bucket_size @@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. 
- gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) + gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index cabb7eb..b05611b 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -108,9 +108,9 @@ def initialize_model(): # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) - gpc.config.fstp_handler = None + gpc.fstp_handler = None if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: - gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR)) + gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR)) return model diff --git a/train.py b/train.py index 5066960..96dc24d 100644 --- a/train.py +++ b/train.py @@ -297,9 +297,9 @@ def main(args): prof.step() - if gpc.config.fstp_handler is not None: - gpc.config.fstp_handler.zero_const_pool = {} - gpc.config.fstp_handler.reduce_scatter_memory_pool = {} + if gpc.fstp_handler is not None: + gpc.fstp_handler.zero_const_pool = {} + gpc.fstp_handler.reduce_scatter_memory_pool = {} # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats() From e7f9f1d20853e856f175d178bf94350871744b67 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 13:31:23 +0800 Subject: [PATCH 3/8] feat(model/overlap_handler.py): optimize reduce scatter mem pool --- internlm/model/overlap_handler.py | 35 ++++++++++--------- internlm/model/utils.py | 4 +-- .../solver/optimizer/hybrid_zero_optim.py | 2 +- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index b687723..b3c8b8b 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -125,37 +125,38 @@ class FSTPOverlapHandler: # if key not in dict if key not in self.reduce_scatter_memory_pool: - self.reduce_scatter_memory_pool[key] = {"data": [], "used": []} + self.reduce_scatter_memory_pool[key] = [] # if the data is empty - if len(self.reduce_scatter_memory_pool[key]["data"]) == 0: - self.reduce_scatter_memory_pool[key]["data"].append( + if len(self.reduce_scatter_memory_pool[key]) == 0: + self.reduce_scatter_memory_pool[key].append( torch.zeros( key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() ).contiguous() ) - self.reduce_scatter_memory_pool[key]["used"].append(True) - return_idx = 0 - return return_idx + setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False) + setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx) + return self.reduce_scatter_memory_pool[key][return_idx] else: # if not empty - for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]): - if used is False: - self.reduce_scatter_memory_pool[key]["used"][index] = True + for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]): + if mem_item.idle is True: + self.reduce_scatter_memory_pool[key][index].idle = False return_idx = index - return return_idx + return self.reduce_scatter_memory_pool[key][return_idx] # if the memory pool is all used - length = len(self.reduce_scatter_memory_pool[key]["data"]) - 
self.reduce_scatter_memory_pool[key]["data"].append( + cur_len = len(self.reduce_scatter_memory_pool[key]) + self.reduce_scatter_memory_pool[key].append( torch.zeros( key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device() ).contiguous() ) - self.reduce_scatter_memory_pool[key]["used"].append(True) - return_idx = length - return return_idx + setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False) + return_idx = cur_len + setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx) + return self.reduce_scatter_memory_pool[key][return_idx] - def release_reduce_scatter_memory(self, size, index): - self.reduce_scatter_memory_pool[size]["used"][index] = False + def release_reduce_scatter_memory(self, key, index): + self.reduce_scatter_memory_pool[key][index].idle = True def _all_gather_block_weight_memory_pool(self, block_index: int): fstp_modules = self.index_to_fstp_modules[block_index] diff --git a/internlm/model/utils.py b/internlm/model/utils.py index cdbed95..8070cbd 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) - index = gpc.fstp_handler.get_reduce_scatter_memory(size) - output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] - setattr(output, "index", index) + output = gpc.fstp_handler.get_reduce_scatter_memory(size) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 08d9722..0d0c8a3 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. 
- gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) + gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index) self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) From f6a5086fe4203727ed96ce4444493a080d91b74d Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Mon, 23 Oct 2023 14:51:27 +0800 Subject: [PATCH 4/8] support bias --- internlm/model/overlap_handler.py | 85 ++++++++++++++++++++----------- internlm/model/utils.py | 22 +++++++- 2 files changed, 75 insertions(+), 32 deletions(-) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index b3c8b8b..f7132c3 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -10,7 +10,7 @@ from internlm.core.context import global_context as gpc from internlm.core.naive_amp import NaiveAMPModel from internlm.model.embedding import Embedding1D from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear -from internlm.model.utils import all_gather_raw_memory_pool +from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool from internlm.utils.common import get_current_device @@ -25,6 +25,7 @@ class FSTPOverlapHandler: self.fstp_modules = [] self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle + self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle self.module_to_index = dict() # key: fstp module; value: transformer block index self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules self.head = [] @@ -76,49 +77,61 @@ class FSTPOverlapHandler: self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() return self.zero_const_pool[size] - - def _initialize_memory_pool(self) -> None: - # allocate memory pool + + def _initialize_module_shape(self): hidden_size = gpc.config.HIDDEN_SIZE mlp_ratio = gpc.config.MLP_RATIO mlp_hidden_size = int(hidden_size * mlp_ratio) mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) + + self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size) + self.module_shape["out_proj"] = (hidden_size, hidden_size) + self.module_shape["w1"] = (mlp_hidden_size, hidden_size) + self.module_shape["w2"] = (mlp_hidden_size, hidden_size) + self.module_shape["w3"] = (hidden_size, mlp_hidden_size) + + def _initialize_memory_pool(self) -> None: + # allocate memory pool self.all_gather_memory_pool = [] + self.all_gather_bias_memory_pool = [] self.reduce_scatter_memory_pool = {} + self.module_shape = {} + + self._initialize_module_shape() + dtype = gpc.config.model.get("dtype", torch.half) + device = get_current_device() for _ in range(2): weight = {} for name in self.module_name: - if name == "Wqkv": - weight[name] = torch.zeros( - (3 * hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "out_proj": - weight[name] = torch.zeros( - (hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - elif name == "w1" or name == "w2": - weight[name] = torch.zeros( - (mlp_hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - else: - weight[name] = torch.zeros( - (hidden_size, 
mlp_hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device(), - ).contiguous() - + weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous() self.all_gather_memory_pool.append(weight) # containing two groups of block weight def get_all_gather_memory(self, module): block_index = self.module_to_index[module] return self.all_gather_memory_pool[block_index % 2][module._fstp_name] + + def get_bias_memory(self, module: nn.Module): + block_index = self.module_to_index[module] + # if the bias memory pool is empty or module has been not allocated memory + # import pdb; pdb.set_trace() + if len(self.all_gather_bias_memory_pool) == 0: + for _ in range(2): + weight = {} + weight[module._fstp_name] = torch.zeros( + self.module_shape[module._fstp_name][0], + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device()).contiguous() + self.all_gather_bias_memory_pool.append(weight) + elif module._fstp_name not in self.all_gather_bias_memory_pool[0]: + for i in range(2): + self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros( + self.module_shape[module._fstp_name][0], + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device()).contiguous() + + return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name] + def get_reduce_scatter_memory(self, key): return_idx = 0 @@ -157,10 +170,19 @@ class FSTPOverlapHandler: def release_reduce_scatter_memory(self, key, index): self.reduce_scatter_memory_pool[key][index].idle = True - + def _all_gather_block_weight_memory_pool(self, block_index: int): fstp_modules = self.index_to_fstp_modules[block_index] for module in fstp_modules: + if module.bias is not None: + bias_handle = all_gather_raw_bias_memory_pool( + module.bias, + self.process_group, + async_op=True, + module=module, + ) + self.bias_global_handle[module] = bias_handle + weight_handle = all_gather_raw_memory_pool( module.weight, self.process_group, @@ -186,6 +208,9 @@ class FSTPOverlapHandler: def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): handle = self.fstp_global_handle[module] handle.wait() + if module.bias is not None: + bias_handle = self.bias_global_handle[module] + bias_handle.wait() def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): if module in self.fstp_global_handle: diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 8070cbd..8a1281e 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -140,6 +140,21 @@ def all_gather_raw_memory_pool( ) return handle +def all_gather_raw_bias_memory_pool( + input_: Tensor, + process_group: ProcessGroup, + async_op: bool = False, + gather_dim: int = 0, + module: nn.Module = None, +): + handle = torch.distributed.all_gather_into_tensor( + gpc.fstp_handler.get_bias_memory(module=module), + input_.contiguous(), + group=process_group, + async_op=async_op, + ) + return handle + def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): assert my_input.dtype == grad_output.dtype @@ -486,8 +501,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function): handle_weight.wait() # TODO memory pool for bias if bias is not None: - total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) - handle_bias.wait() + if overlap_handler is not None: + total_bias = gpc.fstp_handler.get_bias_memory(module=module) + else: + total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) + handle_bias.wait() else: total_bias = bias else: 
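Patch 4 above extends the weight memory pool to biases: each transformer block alternates between two pre-allocated buffer groups (block_index % 2), and a bias buffer is created lazily the first time a module that actually has a bias asks for one, so bias-free modules cost nothing. The sketch below is a minimal, single-process illustration of that double-buffering logic; the module_shape dict, plain CPU tensors, and the string/block-index arguments are stand-ins for the handler's real nn.Module bookkeeping and get_current_device(), not the production code.

import torch


class TwoSlotBiasPool:
    """Minimal sketch of the double-buffered bias pool used by FSTPOverlapHandler."""

    def __init__(self, module_shape, dtype=torch.float16, device="cpu"):
        self.module_shape = module_shape  # e.g. {"Wqkv": (3 * h, h), "out_proj": (h, h), ...}
        self.pool = []                    # two dicts: module name -> all-gather output buffer
        self.dtype = dtype
        self.device = device

    def get_bias_memory(self, name, block_index):
        out_features = self.module_shape[name][0]  # a bias is 1-D: one value per output feature
        if len(self.pool) == 0:
            # very first request: create both buffer groups with this module's entry
            for _ in range(2):
                self.pool.append({name: torch.zeros(out_features, dtype=self.dtype, device=self.device)})
        elif name not in self.pool[0]:
            # first request from this particular module: extend both groups lazily
            for group in self.pool:
                group[name] = torch.zeros(out_features, dtype=self.dtype, device=self.device)
        # even-numbered blocks use group 0, odd-numbered blocks use group 1
        return self.pool[block_index % 2][name]


if __name__ == "__main__":
    pool = TwoSlotBiasPool({"Wqkv": (12, 4), "out_proj": (4, 4)})
    b0 = pool.get_bias_memory("Wqkv", block_index=0)
    b1 = pool.get_bias_memory("Wqkv", block_index=1)
    b2 = pool.get_bias_memory("Wqkv", block_index=2)
    assert b0 is b2 and b0 is not b1  # blocks two apart share a buffer; neighbours never do

Two groups appear to be sufficient because _all_gather_block_weight_memory_pool only ever prefetches one block ahead of the block currently computing, so at most two blocks have live all-gather outputs at any time.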
From 0d693cf3a182b34cc9af7b6ef640f250ff7abbda Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 15:22:03 +0800 Subject: [PATCH 5/8] feat(model/overlap_handler.py): fix lint error --- internlm/model/moe.py | 1 - internlm/model/overlap_handler.py | 40 ++++++++++++++++++------------- internlm/model/utils.py | 1 + train.py | 3 +-- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/internlm/model/moe.py b/internlm/model/moe.py index 28e5ae6..0865097 100644 --- a/internlm/model/moe.py +++ b/internlm/model/moe.py @@ -53,7 +53,6 @@ class MoE(torch.nn.Module): device=None, dtype=None, ): - super().__init__() assert ( diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index f7132c3..3f7ee05 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -10,7 +10,10 @@ from internlm.core.context import global_context as gpc from internlm.core.naive_amp import NaiveAMPModel from internlm.model.embedding import Embedding1D from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear -from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool +from internlm.model.utils import ( + all_gather_raw_bias_memory_pool, + all_gather_raw_memory_pool, +) from internlm.utils.common import get_current_device @@ -25,7 +28,7 @@ class FSTPOverlapHandler: self.fstp_modules = [] self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle - self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle + self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle self.module_to_index = dict() # key: fstp module; value: transformer block index self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules self.head = [] @@ -77,13 +80,13 @@ class FSTPOverlapHandler: self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() return self.zero_const_pool[size] - + def _initialize_module_shape(self): hidden_size = gpc.config.HIDDEN_SIZE mlp_ratio = gpc.config.MLP_RATIO mlp_hidden_size = int(hidden_size * mlp_ratio) mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) - + self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size) self.module_shape["out_proj"] = (hidden_size, hidden_size) self.module_shape["w1"] = (mlp_hidden_size, hidden_size) @@ -96,7 +99,7 @@ class FSTPOverlapHandler: self.all_gather_bias_memory_pool = [] self.reduce_scatter_memory_pool = {} self.module_shape = {} - + self._initialize_module_shape() dtype = gpc.config.model.get("dtype", torch.half) device = get_current_device() @@ -107,10 +110,14 @@ class FSTPOverlapHandler: weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous() self.all_gather_memory_pool.append(weight) # containing two groups of block weight + def clear_memory_pool(self) -> None: + self.zero_const_pool = {} + self.reduce_scatter_memory_pool = {} + def get_all_gather_memory(self, module): block_index = self.module_to_index[module] return self.all_gather_memory_pool[block_index % 2][module._fstp_name] - + def get_bias_memory(self, module: nn.Module): block_index = self.module_to_index[module] # if the bias memory pool is empty or module has been not allocated memory @@ -119,19 +126,20 @@ class FSTPOverlapHandler: for _ in range(2): weight = {} weight[module._fstp_name] = 
torch.zeros( - self.module_shape[module._fstp_name][0], - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() + self.module_shape[module._fstp_name][0], + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() self.all_gather_bias_memory_pool.append(weight) elif module._fstp_name not in self.all_gather_bias_memory_pool[0]: for i in range(2): self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros( - self.module_shape[module._fstp_name][0], - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() - + self.module_shape[module._fstp_name][0], + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name] - def get_reduce_scatter_memory(self, key): return_idx = 0 @@ -170,7 +178,7 @@ class FSTPOverlapHandler: def release_reduce_scatter_memory(self, key, index): self.reduce_scatter_memory_pool[key][index].idle = True - + def _all_gather_block_weight_memory_pool(self, block_index: int): fstp_modules = self.index_to_fstp_modules[block_index] for module in fstp_modules: @@ -182,7 +190,7 @@ class FSTPOverlapHandler: module=module, ) self.bias_global_handle[module] = bias_handle - + weight_handle = all_gather_raw_memory_pool( module.weight, self.process_group, diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 8a1281e..42a8400 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -140,6 +140,7 @@ def all_gather_raw_memory_pool( ) return handle + def all_gather_raw_bias_memory_pool( input_: Tensor, process_group: ProcessGroup, diff --git a/train.py b/train.py index 96dc24d..b4f2a6d 100644 --- a/train.py +++ b/train.py @@ -298,8 +298,7 @@ def main(args): prof.step() if gpc.fstp_handler is not None: - gpc.fstp_handler.zero_const_pool = {} - gpc.fstp_handler.reduce_scatter_memory_pool = {} + gpc.fstp_handler.clear_memory_pool() # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats() From 03cc7f9b80bc94c4b3234da8d32674189c66aa5f Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 15:28:34 +0800 Subject: [PATCH 6/8] feat(model/overlap_handler.py): fix lint error --- internlm/model/overlap_handler.py | 14 +++++++------- internlm/model/utils.py | 7 ++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index 3f7ee05..6870fe6 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -204,27 +204,27 @@ class FSTPOverlapHandler: register forward hooks and backward hooks for fstp modules. 
""" - def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613 self._all_gather_block_weight_memory_pool(0) - def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): + def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): # pylint: disable=W0613 block_index = self.module_to_index[module] # start the all-gather for next block if block_index + 1 < gpc.config.NUM_LAYER: self._all_gather_block_weight_memory_pool(block_index + 1) - def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613 handle = self.fstp_global_handle[module] handle.wait() if module.bias is not None: bias_handle = self.bias_global_handle[module] bias_handle.wait() - def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): + def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613 if module in self.fstp_global_handle: del self.fstp_global_handle[module] - def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): + def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): # pylint: disable=W0613 first_backward_module = self.fstp_modules[-1] weight_handle = all_gather_raw_memory_pool( first_backward_module.weight, @@ -234,7 +234,7 @@ class FSTPOverlapHandler: ) self.fstp_global_handle[first_backward_module] = weight_handle - def _pre_backward_hook_for_module(module: nn.Module, grad_output): + def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613 # wait handle for current module weight_handle = self.fstp_global_handle[module] weight_handle.wait() @@ -251,7 +251,7 @@ class FSTPOverlapHandler: ) self.fstp_global_handle[next_module] = weight_handle - def _post_backward_hook_for_module(module, grad_input, grad_output): + def _post_backward_hook_for_module(module, grad_input, grad_output): # pylint: disable=W0613 if module in self.fstp_global_handle: del self.fstp_global_handle[module] diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 42a8400..982c0e0 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -129,7 +129,6 @@ def all_gather_raw_memory_pool( input_: Tensor, process_group: ProcessGroup, async_op: bool = False, - gather_dim: int = 0, module: nn.Module = None, ): handle = torch.distributed.all_gather_into_tensor( @@ -145,7 +144,6 @@ def all_gather_raw_bias_memory_pool( input_: Tensor, process_group: ProcessGroup, async_op: bool = False, - gather_dim: int = 0, module: nn.Module = None, ): handle = torch.distributed.all_gather_into_tensor( @@ -283,8 +281,8 @@ class FusedDenseFunc(torch.autograd.Function): class MegatronFusedDenseFunc(torch.autograd.Function): """ FusedDenseFunc for tensor parallel in megatron implementation. - The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron, - so that the all-gather in backward is ommited. + The diffenrence between the implementation of flash-attn and megatron is that the total_x could be + saved for backward in megatron, so that the all-gather in backward is ommited. 
""" @staticmethod @@ -433,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc): grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel - gather_dim = ctx.gather_dim if ctx.compute_weight_gradient: total_x, weight = ctx.saved_tensors else: From 9cf1ff0f6e8a3db1dd1e61fd7b91a056b13041ef Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 15:31:41 +0800 Subject: [PATCH 7/8] feat(solver/optimizer/hybrid_zero_optim.py): minor update --- internlm/solver/optimizer/hybrid_zero_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 0d0c8a3..d2c894c 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # release cuda memory. - gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index) + self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index) self._fstp_handler.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) From b2c1a70477bff8e266dcb3155c2f794dfd7cbf5f Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 23 Oct 2023 15:34:24 +0800 Subject: [PATCH 8/8] feat(train/training_internlm.py): fix lint error --- internlm/train/training_internlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index b05611b..5e874d3 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -50,7 +50,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer from internlm.solver.optimizer.utils import ParamBcastSyncHandler from internlm.train.utils import create_param_groups -from internlm.utils.common import DummyProfile, get_current_device +from internlm.utils.common import DummyProfile from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp