From 815a584930622d6c9c81508d41132a6413c86420 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 11:27:59 +0800 Subject: [PATCH 1/5] feat(model/linear.py): remove useless code --- internlm/model/linear.py | 307 +++---------------------- internlm/model/modeling_internlm.py | 3 - internlm/model/multi_head_attention.py | 1 - internlm/model/utils.py | 152 +++++++----- internlm/train/training_internlm.py | 58 +++-- train.py | 2 +- 6 files changed, 166 insertions(+), 357 deletions(-) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 4f05cd3..61a5cfc 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -177,7 +177,6 @@ class FeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - block_idx: int = 0, ): super().__init__() @@ -224,8 +223,14 @@ class FSTPLinear(ColumnParallelLinear): name_index = gpc.config.fstp_handler.module_name_index[self] name = gpc.config.fstp_handler.module_name[name_index] return fstp_fused_dense_func( - x, self.weight, self.bias, process_group=self.process_group, - module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name + x, + self.weight, + self.bias, + process_group=self.process_group, + module=self, + handler=gpc.config.fstp_handler, + block_index=block_index, + module_name=name, ) @@ -255,7 +260,6 @@ class FSTPFeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, - block_idx: int = 0, ): super().__init__() @@ -296,129 +300,6 @@ class FSTPFeedForward(nn.Module): return out -class FSTPAllGatherSyncHandler: - """ - All-gather handler for overlapping the all-gather in adjcent FSTP linear. 
- """ - - def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: - # import pdb; pdb.set_trace() - self.process_group = process_group - self.FSTP_modules = [] - self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] - self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward - self.module_handler = dict() # key: FSTP module; value: all-gather handler - self.module_block = dict() # key: FSTP module; value: transformer block index - self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} - self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name - - self.reduce_scatter_handlers = {} - self.all_reduce_handlers = {} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _chunk_name, children in _chunk.named_children(): - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - index = 0 - self.block_module[idx] = {} - for _sub_name, sub in block.named_children(): - sub_modules = list(sub.children()) - if len(sub_modules) > 0: - for name, child in sub.named_children(): - if isinstance(child, FSTPLinear): - - _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") - if child.bias is not None: - setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - - self.FSTP_modules.append(child) - self.module_block[child] = idx - self.block_module[idx][index] = child - self.module_name_index[child] = index - index = index + 1 - else: - continue - - def _register_sync_parameters_hook(self) -> None: - """ - register pre_forward_hook and pre_backward_hook for FSTPLinear. 
- """ - - def _pre_forward_hook(module: nn.Module, inputs: Any): - block_index = self.module_block[module] - name_index = self.module_name_index[module] - if name_index == 0: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler.wait() - self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - else: - handler = self.module_handler[module] - handler.wait() - if name_index != 4: - next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - - def _post_forward_hook(module: nn.Module, input, output): - if module in self.FSTP_global_weights: - del self.FSTP_global_weights[module] - if module in self.module_handler: - del self.module_handler[module] - - def _pre_backward_hook(module: nn.Module, grad_output): - block_index = self.module_block[module] - name_index = self.module_name_index[module] - if name_index == 4: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler.wait() - self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - else: - handler = self.module_handler[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.module_handler[next_module] = weights_handler - - def _post_backward_hook(module, grad_input, grad_output): - del self.FSTP_global_weights[module] - - for module in self.FSTP_modules: - # import pdb; pdb.set_trace() - module.register_forward_pre_hook(_pre_forward_hook) - module.register_forward_hook(_post_forward_hook) - # module.register_backward_pre_hook(_pre_backward_hook) - # module.register_backward_hook(_post_backward_hook) - module.register_full_backward_pre_hook(_pre_backward_hook) - module.register_full_backward_hook(_post_backward_hook) - - class CoarseGrainedFSTPAllGatherSyncHandler: """ All-gather handler for overlapping the all-gather in adjcent FSTP block. 
@@ -479,49 +360,33 @@ class CoarseGrainedFSTPAllGatherSyncHandler: self.index_to_fsdp_modules[idx].append(child) self.module_name_index[child] = index index = index + 1 - + _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight") if child.bias is not None: setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") - # _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}" - # setattr(child.weight, "_fstp_all_reduce_str", f"{_full_name}.weight") - # if child.bias is not None: - # setattr(child.bias, "_fstp_all_reduce_str", f"{_full_name}.bias") else: continue elif isinstance(children, ScaleColumnParallelLinear): self.head.append(children) elif isinstance(children, Embedding1D): self.embedding.append(children) - - def get_zero_by_shape(self, size:tuple, dtype, device) -> torch.Tensor: - if size not in self.zero_const_pool: + + def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor: + if size not in self.zero_const_pool: self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous() - + return self.zero_const_pool[size] - - def _all_gather_block_weight(self, block_index: int): - #block = self.index_to_block[block_index] - fsdp_modules = self.index_to_fsdp_modules[block_index] - # self.block_handles[block] = [] - for module in fsdp_modules: - total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - self.FSTP_global_weights[module] = total_weight - self.FSTP_global_handle[module] = weight_handle - # self.block_handles[block].append(weight_handle) - def _all_gather_block_weight_memory_pool(self, block_index: int): fsdp_modules = self.index_to_fsdp_modules[block_index] - # self.block_handles[block] = [] for module in fsdp_modules: module_index = self.module_name_index[module] name = self.module_name[module_index] - weight_handle = all_gather_raw_memory_pool(module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name) - # self.FSTP_global_weights[module] = total_weight + weight_handle = all_gather_raw_memory_pool( + module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name + ) self.FSTP_global_handle[module] = weight_handle - # self.block_handles[block].append(weight_handle) def _register_sync_parameters_hook(self) -> None: """ @@ -538,41 +403,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler: block_index = self.module_to_index[module] # start the all-gather for next block if block_index + 1 < gpc.config.NUM_LAYER: - # self._all_gather_block_weight(block_index + 1) self._all_gather_block_weight_memory_pool(block_index + 1) - def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): - block_index = self.block_to_index[block] - if block_index == 0: - # all gather weight for block 0 - fsdp_modules = self.index_to_fsdp_modules[block_index] - for module in fsdp_modules: - total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handle.wait() - self.FSTP_global_weights[module] = total_weight - else: - # wait handle for current block - handles = self.block_handles[block] - for handle in handles: - handle.wait() - - def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): - # self._all_gather_block_weight(0) + def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output): self._all_gather_block_weight_memory_pool(0) - - def _post_forward_hook_for_block(block: nn.Module, input, output): - 
block_index = self.block_to_index[block] - fsdp_modules = self.index_to_fsdp_modules[block_index] - if block in self.block_handles: - del self.block_handles[block] - for module in fsdp_modules: - del self.FSTP_global_weights[module] - - def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,): - block_index = self.module_to_index[module] - handler = self.FSTP_global_handle[module] - handler.wait() + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): + handle = self.FSTP_global_handle[module] + handle.wait() def _post_forward_hook_for_module(module: nn.Module, input, output): if module in self.FSTP_global_weights: @@ -580,67 +418,44 @@ class CoarseGrainedFSTPAllGatherSyncHandler: if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] - def _pre_backward_hook_for_block(block: nn.Module, grad_output): - # import pdb; pdb.set_trace() - block_index = self.block_to_index[block] - # if block_index == gpc.config.NUM_LAYER - 1: - # # all gather weight for the last block - # fsdp_modules = self.index_to_fsdp_modules[block_index] - # for module in fsdp_modules: - # total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) - # weight_handle.wait() - # self.FSTP_global_weights[module] = total_weight - # else: - # # wait handle for current block - # handles = self.block_handles[block] - # for handle in handles: - # handle.wait() - # if block_index == gpc.config.NUM_LAYER - 1: - # self._all_gather_block_weight(block_index) - # start the all-gather for next block - if block_index - 1 >= 0: - self._all_gather_block_weight(block_index - 1) - def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): first_module = self.block_module[gpc.config.NUM_LAYER - 1][4] total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True) self.FSTP_global_handle[first_module] = weight_handler self.FSTP_global_weights[first_module] = total_weight - def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): - block_index = self.block_to_index[block] - fsdp_modules = self.index_to_fsdp_modules[block_index] - if block in self.block_handles: - del self.block_handles[block] - for module in fsdp_modules: - del self.FSTP_global_weights[module] - def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output): block_index = self.module_to_index[module] name_index = self.module_name_index[module] - + if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) weight_handler = self.FSTP_global_handle[module] weight_handler.wait() - # self.FSTP_global_weights[module] = total_weight # start the all-gather for next module next_module = self.block_module[block_index][name_index - 1] next_name = self.module_name[name_index - 1] weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=next_name + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index, + module_name=next_name, ) self.FSTP_global_handle[next_module] = weights_handler elif name_index == 0: handler = self.FSTP_global_handle[module] handler.wait() - + if block_index - 1 >= 0: next_module = self.block_module[block_index - 1][4] name = self.module_name[4] weights_handler = all_gather_raw_memory_pool( - next_module.weight, self.process_group, async_op=True, block_index=block_index - 1, 
module_name=name, + next_module.weight, + self.process_group, + async_op=True, + block_index=block_index - 1, + module_name=name, ) self.FSTP_global_handle[next_module] = weights_handler else: @@ -653,76 +468,24 @@ class CoarseGrainedFSTPAllGatherSyncHandler: next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name ) self.FSTP_global_handle[next_module] = weights_handler - # if module in self.FSTP_global_handle: - # handler = self.FSTP_global_handle[module] - # handler.wait() - - def _pre_backward_hook_for_module(module: nn.Module, grad_output): - block_index = self.module_to_index[module] - name_index = self.module_name_index[module] - - if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler = self.FSTP_global_handle[module] - weight_handler.wait() - # self.FSTP_global_weights[module] = total_weight - - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - elif name_index == 0: - handler = self.FSTP_global_handle[module] - handler.wait() - - if block_index - 1 >= 0: - next_module = self.block_module[block_index - 1][4] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - else: - handler = self.FSTP_global_handle[module] - handler.wait() - if name_index != 0: - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - # if module in self.FSTP_global_handle: - # handler = self.FSTP_global_handle[module] - # handler.wait() def _post_backward_hook_for_module(module, grad_input, grad_output): if module in self.FSTP_global_weights: del self.FSTP_global_weights[module] if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] - + for embedding in self.embedding: - embedding.register_forward_hook(_pre_forward_hook_for_embedding) - + embedding.register_forward_hook(_post_forward_hook_for_embedding) + for head in self.head: head.register_full_backward_hook(_post_backward_hook_for_head) - # for block in self.FSTP_blocks: - # block.register_forward_pre_hook(_pre_forward_hook_for_block) - # block.register_forward_hook(_post_forward_hook_for_block) - # block.register_full_backward_pre_hook(_pre_backward_hook_for_block) - # block.register_full_backward_hook(_post_backward_hook_for_block) - for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - - # for wqkv in self.FSTP_wqkvs: - # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv) for module in self.FSTP_modules: module.register_forward_pre_hook(_pre_forward_hook_for_module) module.register_forward_hook(_post_forward_hook_for_module) - # module.register_full_backward_pre_hook(_pre_backward_hook_for_module) module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool) module.register_full_backward_hook(_post_backward_hook_for_module) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 
b004dff..0df2b60 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -78,7 +78,6 @@ class PackedFlashBaseLayer1D(nn.Module): use_swiglu: bool = True, use_flash_attn: bool = True, tp_mode: str = "origin_tp", - block_idx: int = 0, ): super().__init__() self.checkpoint = checkpoint @@ -104,7 +103,6 @@ class PackedFlashBaseLayer1D(nn.Module): device=device, dtype=dtype, tp_mode=tp_mode, - block_idx=block_idx, ) self.dropout1 = nn.Dropout(drop_rate) @@ -346,7 +344,6 @@ class PackedFlashInternLm1D(nn.Module): use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, tp_mode=self.tp_mode, - block_idx=lid, ) for lid in range(num_layers) ] diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 7a0f4ed..8dcd3f9 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -176,7 +176,6 @@ class MHA(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, tp_mode: str = "origin_tp", - block_idx: int = 0, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 2667efe..b9c7c03 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -6,15 +6,15 @@ from typing import Any, Optional, Union import fused_dense_lib as fused_dense_cuda import torch import torch.nn.functional as F -from flash_attn.utils.distributed import all_reduce_raw #, reduce_scatter_raw +from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.utils.logger import get_logger from internlm.utils.common import get_current_device +from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -125,9 +125,20 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = ) return output, handle -def all_gather_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0, block_index: int = None, module_name: str = None): + +def all_gather_raw_memory_pool( + input_: Tensor, + process_group: ProcessGroup, + async_op: bool = False, + gather_dim: int = 0, + block_index: int = None, + module_name: str = None, +): handle = torch.distributed.all_gather_into_tensor( - gpc.config.block_memory[block_index % 2][module_name], input_.contiguous(), group=process_group, async_op=async_op + gpc.config.block_memory[block_index % 2][module_name], + input_.contiguous(), + group=process_group, + async_op=async_op, ) return handle @@ -142,23 +153,25 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 - output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], - dtype=input_.dtype, device=input_.device).contiguous() - handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), - group=process_group, - async_op=async_op) + output = torch.empty( + input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device + ).contiguous() + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), 
group=process_group, async_op=async_op + ) return output, handle + def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 size = (input_.shape[0] // world_size, *input_.shape[1:]) index = check_reduce_scatter_memory_pool(size) - output = gpc.config.reduce_scatter_memory[size]['data'][index] + output = gpc.config.reduce_scatter_memory[size]["data"][index] setattr(output, "index", index) - handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), - group=process_group, - async_op=async_op) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) return output, handle @@ -313,7 +326,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None): + def forward( + ctx, + x, + weight, + bias, + return_residual=False, + process_group=None, + module=None, + all_gather_handler=None, + block_index=None, + module_name=None, + ): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group @@ -329,9 +353,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: # do all_gather for weight and bias before actual computation - if all_gather_handler is not None:# and module in all_gather_handler.FSTP_global_weights: - # total_weight = all_gather_handler.FSTP_global_weights[module] - total_weight = gpc.config.block_memory[block_index % 2][module_name] + if all_gather_handler is not None: # and module in all_gather_handler.FSTP_global_weights: + # total_weight = all_gather_handler.FSTP_global_weights[module] + total_weight = gpc.config.block_memory[block_index % 2][module_name] else: total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) handle_weight.wait() @@ -376,7 +400,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): module = ctx.module block_index = ctx.block_index module_name = ctx.module_name - + if ctx.compute_weight_gradient: x, weight, bias = ctx.saved_tensors total_x = x @@ -408,32 +432,43 @@ class FSTPFusedDenseFunc(torch.autograd.Function): ) if world_size > 1: if gpc.config.fstp_handler is not None: - # grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True) - # assert hasattr(weight, "_fstp_all_reduce_str") - # all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async) - # grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) - # if grad_bias is not None: - # grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True) - # assert hasattr(bias, "_fstp_all_reduce_str") - # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async) - # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) - - grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, 
process_group, async_op=True) + grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool( + grad_weight, process_group, async_op=True + ) assert hasattr(weight, "_fstp_reduce_scatter_str") - all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) - grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) + all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = ( + handle_grad_weight, + grad_weight_async, + ) + grad_weight = all_gather_handler.get_zero_by_shape( + ( + grad_weight.shape[0] // torch.distributed.get_world_size(process_group), + *grad_weight.shape[1:], + ), + dtype=grad_weight.dtype, + device=grad_weight.device, + ) if grad_bias is not None: - grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True) + grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool( + grad_bias, process_group, async_op=True + ) assert hasattr(bias, "_fstp_reduce_scatter_str") - all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) - grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) + all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = ( + handle_grad_bias, + grad_bias_async, + ) + grad_bias = all_gather_handler.get_zero_by_shape( + ( + grad_bias.shape[0] // torch.distributed.get_world_size(process_group), + *grad_bias.shape[1:], + ), + dtype=grad_bias.dtype, + device=grad_bias.device, + ) else: grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) if grad_bias is not None: grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) - # grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) - # if grad_bias is not None: - # grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) else: grad_weight = None grad_bias = grad_output if ctx.needs_input_grad[2] else None @@ -489,7 +524,9 @@ def fstp_fused_dense_func( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler, block_index, module_name) + return FSTPFusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, module, handler, block_index, module_name + ) else: assert process_group is None out = F.linear(x, weight, bias) @@ -536,36 +573,37 @@ def Silu(w1_o, w2_o): Silu = torch.jit.script(Silu) + def check_reduce_scatter_memory_pool(key): - return_idx = 0 - + # if key not in dict if key not in gpc.config.reduce_scatter_memory: - gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []} - + gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []} + # if the data is empty - if len(gpc.config.reduce_scatter_memory[key]['data']) == 0: - gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key, - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous()) - gpc.config.reduce_scatter_memory[key]['used'].append(True) + if 
len(gpc.config.reduce_scatter_memory[key]["data"]) == 0: + gpc.config.reduce_scatter_memory[key]["data"].append( + torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() + ) + gpc.config.reduce_scatter_memory[key]["used"].append(True) return_idx = 0 return return_idx - else: # if not empty - for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']): - if used == False: - gpc.config.reduce_scatter_memory[key]['used'][index] = True + else: # if not empty + for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]): + if used is False: + gpc.config.reduce_scatter_memory[key]["used"][index] = True return_idx = index return return_idx # if the memory pool is all used - length = len(gpc.config.reduce_scatter_memory[key]['data']) - gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key, - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous()) - gpc.config.reduce_scatter_memory[key]['used'].append(True) + length = len(gpc.config.reduce_scatter_memory[key]["data"]) + gpc.config.reduce_scatter_memory[key]["data"].append( + torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous() + ) + gpc.config.reduce_scatter_memory[key]["used"].append(True) return_idx = length return return_idx + def release_reduce_scatter_memory_pool(size, index): - gpc.config.reduce_scatter_memory[size]['used'][index] = False \ No newline at end of file + gpc.config.reduce_scatter_memory[size]["used"][index] = False diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 2816da0..5205ba5 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D from internlm.model.linear import ( CoarseGrainedFSTPAllGatherSyncHandler, FeedForward, - FSTPAllGatherSyncHandler, RewardModelLinear, ScaleColumnParallelLinear, ) @@ -108,7 +107,7 @@ def initialize_model(): # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) - + gpc.config.fstp_handler = None if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: @@ -116,40 +115,53 @@ def initialize_model(): # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() gpc.config.fstp_handler = handler - + # allocate memory pool - block_memory = {} # containing two groups of block weight + block_memory = {} # containing two groups of block weight hidden_size = gpc.config.HIDDEN_SIZE mlp_ratio = gpc.config.MLP_RATIO mlp_hidden_size = int(hidden_size * mlp_ratio) mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) world_size = gpc.get_world_size(ParallelMode.TENSOR) - size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)] - module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3'] + size_key = [ + (3 * hidden_size // world_size, hidden_size), + (mlp_hidden_size // world_size, hidden_size), + (hidden_size // world_size, mlp_hidden_size), + (hidden_size // world_size, hidden_size), + ] + module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] for i in range(2): weight = {} for name in module_name: - if name == 'Wqkv': - weight[name] = torch.zeros((3 * hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - 
device=get_current_device()).contiguous() - elif name == 'out_proj': - weight[name] = torch.zeros((hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() - elif name == 'w1' or name == 'w2': - weight[name] = torch.zeros((mlp_hidden_size, hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() + if name == "Wqkv": + weight[name] = torch.zeros( + (3 * hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "out_proj": + weight[name] = torch.zeros( + (hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() + elif name == "w1" or name == "w2": + weight[name] = torch.zeros( + (mlp_hidden_size, hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() else: - weight[name] = torch.zeros((hidden_size, mlp_hidden_size), - dtype=gpc.config.model.get("dtype", torch.half), - device=get_current_device()).contiguous() + weight[name] = torch.zeros( + (hidden_size, mlp_hidden_size), + dtype=gpc.config.model.get("dtype", torch.half), + device=get_current_device(), + ).contiguous() block_memory[i] = weight reduce_scatter_memory = {} for key in size_key: - reduce_scatter_memory[key] = {'data': [], 'used': []} - + reduce_scatter_memory[key] = {"data": [], "used": []} + gpc.config.block_memory = block_memory gpc.config.reduce_scatter_memory = reduce_scatter_memory diff --git a/train.py b/train.py index 41ab070..19a104b 100644 --- a/train.py +++ b/train.py @@ -296,7 +296,7 @@ def main(args): memory_profiler.step() prof.step() - + if gpc.config.fstp_handler is not None: gpc.config.fstp_handler.zero_const_pool = {} gpc.config.fstp_handler.reduce_scatter_memory = {} From 95488d8e8f1737947c4f9a00f888d9f57e6ea606 Mon Sep 17 00:00:00 2001 From: "chenxun.p" Date: Fri, 20 Oct 2023 15:58:06 +0800 Subject: [PATCH 2/5] update optimizer accumulate grad impl when fstp --- .../core/scheduler/no_pipeline_scheduler.py | 1 - .../solver/optimizer/hybrid_zero_optim.py | 133 +++++++----------- 2 files changed, 51 insertions(+), 83 deletions(-) diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index f0caf05..56661d8 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler): _output, _loss, _moe_loss = self._train_one_batch( _data, _label, engine, forward_only, return_loss, self._grad_accum_size ) - engine.optimizer.reset_reduce_bucket() if return_loss: loss += _loss diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 96a54c0..2c14c65 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import math +from typing import Optional, List from functools import partial import torch @@ -40,8 +41,20 @@ from .utils import compute_norm inf = math.inf logger = get_logger(__file__) + def print_memory(msg): - print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, 
flush=True) + print( + msg, + " rank = ", + gpc.get_global_rank(), + " memory allocated: ", + torch.cuda.memory_allocated() / 1024 / 1024 / 1024, + " reverved memory: ", + torch.cuda.memory_reserved() / 1024 / 1024 / 1024, + " max memory: ", + torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, + flush=True, + ) print("===========================================") @@ -69,7 +82,7 @@ class HybridZeroOptimizer(BaseOptimizer): backoff_factor = grad_scal_cfg.backoff_factor hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - + if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: self._fstp_handler = gpc.config.fstp_handler @@ -90,8 +103,8 @@ class HybridZeroOptimizer(BaseOptimizer): # it will not manage the tensors used by mixed precision training self._param_store = ParameterStore(ParallelMode.ZERO1) self._grad_store = GradientStore(ParallelMode.DATA) - self._bucket_store = [] - self._bucket_store_2 = [] + self._bucket_store: List[BucketStore] = [] + self._accum_grad_buckets: List[BucketStore] = [] self._bucket_in_progress = [] # fp16 and fp32 params for mixed precision training @@ -160,7 +173,7 @@ class HybridZeroOptimizer(BaseOptimizer): # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name self._broadcast_parallel_mode.append(zero_mode) self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"])) - self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"])) + self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"])) # assign parameters to ranks the params in the list are sorted params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group) @@ -306,9 +319,9 @@ class HybridZeroOptimizer(BaseOptimizer): param=param, reduce_rank=reduce_rank, ) - + reduce_scatter_checker = partial( - self._wait_reduce_scatter_and_accumulate_grad, + self._wait_reduce_scatter_and_accumulate_grads, param=param, reduce_rank=reduce_rank, ) @@ -317,7 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer): # NOT IMPORTANT BUT GOOD TO KNOW: # args here is not grad, but allow_unreacable and accumulate_grad def reduce_grad_hook(*args): # pylint: disable=W0613 - if gpc.config.fstp_handler is not None: + if self._fstp_handler is not None: reduce_scatter_checker() if self.skip_grad_reduce is False: @@ -341,84 +354,36 @@ class HybridZeroOptimizer(BaseOptimizer): group_id = getattr(param, "group_id") return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id]) - def reset_reduce_bucket(self) -> None: - for bucket in self._bucket_store_2: - for rank, params in bucket._params.items(): - for _param in params: - if not hasattr(_param, "_fstp_reduce_scatter_str"): - continue + def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None: + for _param in bucket.get_param(reduce_rank): + if not hasattr(_param, "_fstp_reduce_scatter_str"): + continue - key = getattr(_param, "_fstp_reduce_scatter_str") - comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key] - comm_handle.wait() - _param.grad.add_(_grad) - # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) - del self._fstp_handler.reduce_scatter_handlers[key] - self._fstp_handler.reduce_scatter_handlers[key] = None - assert key in self._fstp_handler.reduce_scatter_handlers - # if not hasattr(_param, "_fstp_all_reduce_str"): - # 
continue + # wait and accumulate gardient. + _key = getattr(_param, "_fstp_reduce_scatter_str") + _comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key] + _comm_handle.wait() + _param.grad.add_(_grad) - # key = getattr(_param, "_fstp_all_reduce_str") - # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key] - # comm_handle.wait() - # with torch.no_grad(): - # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0) - # _param.grad.add_(_grad) - # # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - # del self._fstp_handler.all_reduce_handlers[key] - # self._fstp_handler.all_reduce_handlers[key] = None - # assert key in self._fstp_handler.all_reduce_handlers + # release cuda memory. + self._fstp_handler.reduce_scatter_handlers[_key] = None + _grad = None - bucket.reset_by_rank(rank) - - def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): + bucket.reset_by_rank(reduce_rank) + + def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None): param_size = param.numel() + group_id = getattr(param, "group_id") + current_bucket = self._accum_grad_buckets[group_id] + # check if the bucket is full # if full, will reduce the grads already in the bucket # after reduction, the bucket will be empty - group_id = getattr(param, "group_id") - current_bucket = self._bucket_store_2[group_id] + if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size: + self._accum_grads_store_in_bucket(current_bucket, reduce_rank) - if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024: - # wait reduce scatter communication - params = current_bucket.get_param(reduce_rank) - for _param in params: - if not hasattr(_param, "_fstp_reduce_scatter_str"): - continue - - key = getattr(_param, "_fstp_reduce_scatter_str") - comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key] - comm_handle.wait() - _param.grad.add_(_grad) - # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) - del self._fstp_handler.reduce_scatter_handlers[key] - self._fstp_handler.reduce_scatter_handlers[key] = None - assert key in self._fstp_handler.reduce_scatter_handlers - - # if not hasattr(_param, "_fstp_all_reduce_str"): - # continue - - # key = getattr(_param, "_fstp_all_reduce_str") - # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key] - # comm_handle.wait() - # with torch.no_grad(): - # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0) - # _param.grad.add_(_grad) - # # self._fstp_handler.reduce_scatter_handlers[key] = None - # del _grad - # del self._fstp_handler.all_reduce_handlers[key] - # self._fstp_handler.all_reduce_handlers[key] = None - # assert key in self._fstp_handler.all_reduce_handlers - - current_bucket.reset_by_rank(reduce_rank) - + # otherwise, add the parameter into bucket. 
current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) current_bucket.add_param(param, reduce_rank) @@ -646,6 +611,10 @@ class HybridZeroOptimizer(BaseOptimizer): for group_id in range(self.num_param_groups): self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True) + # we need to accumulate gradients left in the accumulate gardient bucket + for group_id in range(self.num_param_groups): + self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None) + # compute norm for gradients in the before bucket groups_norms = [] for group_id in range(self.num_param_groups): @@ -685,16 +654,16 @@ class HybridZeroOptimizer(BaseOptimizer): timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - + print_memory("No 4") - + try: - res = self._step(closure=closure, norms=total_norms) + res = self._step(closure=closure, norms=total_norms) except torch.cuda.OutOfMemoryError as e: print(e, flush=True) print(torch.cuda.memory_summary(), flush=True) torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") - + return res def _step(self, closure=None, norms=None): @@ -822,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer): torch.cuda.synchronize() with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() - + timer("step").stop() # update gradients may not be needed here, because the sync_params function is used in initialization, From d91a5d9d9ec8c7b0444b533a6b44be4430c7c199 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 15:59:40 +0800 Subject: [PATCH 3/5] feat(initialize/launch.py): refactor config for fstp --- configs/7B_sft.py | 10 ++--- internlm/initialize/launch.py | 23 ++++++---- internlm/model/modeling_internlm.py | 14 +++--- internlm/model/multi_head_attention.py | 8 ++-- .../solver/optimizer/hybrid_zero_optim.py | 45 ++++++++++++------- internlm/train/training_internlm.py | 3 +- internlm/utils/evaluation.py | 4 +- 7 files changed, 63 insertions(+), 44 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 6ea8b96..c51c812 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -152,19 +152,19 @@ zero1 parallel (dict): 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. - 2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp', - the sequence_parallel should be True. + 2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'], + defaults to 'none', means the sequence parallel will be disabled. + 3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp, + defaults to False. pipeline parallel (dict): 1. size: int, the size of pipeline parallel. 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, defaults to False. -sequence parallel (bool): enable/disable sequence parallel, defaults to False. 
""" parallel = dict( zero1=dict(size=-1, fsdp=False), - tensor=dict(size=8, mode="fstp", overlap=True), + tensor=dict(size=8, sp="intern", intern_overlap=True), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=True, ) cudnn_deterministic = False diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 80611fe..0e74f76 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -306,15 +306,20 @@ def args_sanity_check(): ), "sequence parallel does not support use_flash_attn=False" if isinstance(gpc.config.parallel["tensor"], int): - gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp") - - if gpc.config.parallel["tensor"].get("mode", None) is None: - gpc.config.parallel["tensor"]["mode"] = "origin_tp" - - if gpc.config.parallel["tensor"].get("mode", None) == "fstp": - assert ( - gpc.config.parallel.sequence_parallel is True - ), "when the tp_mode is fstp, the sequence_parallel should be True." + gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False) + if gpc.config.parallel["tensor"].get("sp", None) is None: + gpc.config.parallel["tensor"]["sp"] = "none" + if gpc.config.parallel["tensor"].get("intern_overlap", None) is None: + gpc.config.parallel["tensor"]["intern_overlap"] = False + assert gpc.config.parallel["tensor"].get("sp", None) in [ + "none", + "megatron", + "flash-attn", + "intern", + ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported" + # adapt to old version's sequence parallel config + if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]: + gpc.config.parallel.sequence_parallel = True # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 0df2b60..9b6420d 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, - tp_mode: str = "origin_tp", + sp_mode: str = "none", ): super().__init__() self.checkpoint = checkpoint @@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_flash_attn=use_flash_attn, device=device, dtype=dtype, - tp_mode=tp_mode, + sp_mode=sp_mode, ) self.dropout1 = nn.Dropout(drop_rate) @@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module): self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) if use_swiglu: - mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward + mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward self.mlp = mlp_cls( hidden_size, int(hidden_size * mlp_ratio), @@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module): super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = gpc.config.parallel["tensor"]["mode"] + self.sp_mode = gpc.config.parallel["tensor"]["sp"] if is_reward: head_cls = RewardModelLinear @@ -343,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, - tp_mode=self.tp_mode, + sp_mode=self.sp_mode, ) for lid in range(num_layers) ] @@ -389,8 +389,8 @@ class PackedFlashInternLm1D(nn.Module): assert len(indexes) == 1 # The indexes are used to indicate the actual position 
IDs of each token in the packed input. indexes = indexes[0] - # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp": + # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension. + if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern": indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 8dcd3f9..cb0efb8 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -175,7 +175,7 @@ class MHA(nn.Module): use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - tp_mode: str = "origin_tp", + sp_mode: str = "none", ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -203,7 +203,7 @@ class MHA(nn.Module): self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) # notice here should change bias=True - Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch self.Wqkv = Wqkv_cls( embed_dim, 3 * embed_dim, @@ -219,12 +219,12 @@ class MHA(nn.Module): self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - if tp_mode == "fstp": + if sp_mode == "intern": self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) - out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear + out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch self.out_proj = out_proj_cls( embed_dim, embed_dim, diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 96a54c0..a4b3173 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -10,7 +10,10 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool +from internlm.model.utils import ( + release_reduce_scatter_memory_pool, + split_forward_gather_backward, +) from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ -40,8 +43,20 @@ from .utils import compute_norm inf = math.inf logger = get_logger(__file__) + def print_memory(msg): - print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True) + print( + msg, + " rank = ", + gpc.get_global_rank(), + " memory allocated: ", + torch.cuda.memory_allocated() / 1024 / 1024 / 1024, + " reverved memory: ", + torch.cuda.memory_reserved() / 1024 / 1024 / 1024, + " max memory: ", + torch.cuda.max_memory_allocated() / 1024 / 1024 / 
1024, + flush=True, + ) print("===========================================") @@ -69,8 +84,8 @@ class HybridZeroOptimizer(BaseOptimizer): backoff_factor = grad_scal_cfg.backoff_factor hysteresis = grad_scal_cfg.hysteresis max_scale = grad_scal_cfg.max_scale - - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + + if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: self._fstp_handler = gpc.config.fstp_handler # Zero related args @@ -306,7 +321,7 @@ class HybridZeroOptimizer(BaseOptimizer): param=param, reduce_rank=reduce_rank, ) - + reduce_scatter_checker = partial( self._wait_reduce_scatter_and_accumulate_grad, param=param, @@ -354,7 +369,7 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # self._fstp_handler.reduce_scatter_handlers[key] = None # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) + release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers @@ -374,7 +389,7 @@ class HybridZeroOptimizer(BaseOptimizer): # assert key in self._fstp_handler.all_reduce_handlers bucket.reset_by_rank(rank) - + def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None): param_size = param.numel() @@ -397,11 +412,11 @@ class HybridZeroOptimizer(BaseOptimizer): _param.grad.add_(_grad) # self._fstp_handler.reduce_scatter_handlers[key] = None # del _grad - release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index) + release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index) del self._fstp_handler.reduce_scatter_handlers[key] self._fstp_handler.reduce_scatter_handlers[key] = None assert key in self._fstp_handler.reduce_scatter_handlers - + # if not hasattr(_param, "_fstp_all_reduce_str"): # continue @@ -418,7 +433,7 @@ class HybridZeroOptimizer(BaseOptimizer): # assert key in self._fstp_handler.all_reduce_handlers current_bucket.reset_by_rank(reduce_rank) - + current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) current_bucket.add_param(param, reduce_rank) @@ -685,16 +700,16 @@ class HybridZeroOptimizer(BaseOptimizer): timer("sync_grad").start() self._sync_grad() timer("sync_grad").stop() - + print_memory("No 4") - + try: - res = self._step(closure=closure, norms=total_norms) + res = self._step(closure=closure, norms=total_norms) except torch.cuda.OutOfMemoryError as e: print(e, flush=True) print(torch.cuda.memory_summary(), flush=True) torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") - + return res def _step(self, closure=None, norms=None): @@ -822,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer): torch.cuda.synchronize() with torch.cuda.stream(self._comm_bcast_stream): self.broadcast_params() - + timer("step").stop() # update gradients may not be needed here, because the sync_params function is used in initialization, diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 5205ba5..53996b3 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -110,9 +110,8 @@ def initialize_model(): gpc.config.fstp_handler = None - if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: + if gpc.config.parallel["tensor"]["sp"] == "intern" 
and gpc.config.parallel["tensor"]["intern_overlap"] is True: handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() gpc.config.fstp_handler = handler diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 968a1db..f708fa7 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape def switch_sequence_parallel_mode(): prev_mode = gpc.config.parallel.sequence_parallel try: - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == "intern": gpc.config.parallel.sequence_parallel = True else: gpc.config.parallel.sequence_parallel = False @@ -106,7 +106,7 @@ def evaluate_on_val_dls( total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz - if gpc.config.parallel["tensor"]["mode"] == "fstp": + if gpc.config.parallel["tensor"]["sp"] == "intern": sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) tensor_shape = torch.Size( [ From eac382ad0a0ed6075b31fbdb8a56d42239fa9f4f Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 16:22:29 +0800 Subject: [PATCH 4/5] feat(optimizer/hybrid_zero_optim.py): fix lint error --- internlm/model/utils.py | 5 ++--- internlm/solver/optimizer/hybrid_zero_optim.py | 5 +---- internlm/solver/optimizer/store.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index b9c7c03..19531e4 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,12 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Any, Optional, Union +from typing import Optional import fused_dense_lib as fused_dense_cuda import torch import torch.nn.functional as F -from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw +from flash_attn.utils.distributed import all_reduce_raw from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup @@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function): grad_input = grad_input.contiguous() process_group = ctx.process_group all_gather_handler = ctx.all_gather_handler - module = ctx.module block_index = ctx.block_index module_name = ctx.module_name diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index d5fec31..cb8aa65 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,10 +11,7 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import ( - release_reduce_scatter_memory_pool, - split_forward_gather_backward, -) +from internlm.model.utils import release_reduce_scatter_memory_pool from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index 228045e..f486cce 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -45,7 +45,7 @@ class BucketStore(BaseStore): def num_elements_in_bucket(self, reduce_rank: int = 
None): return self._num_elements_in_bucket[reduce_rank] - + def num_params_in_bucket(self, reduce_rank: int = None): return len(self._params[reduce_rank]) From 2acf9b817f6888e73c3606ddc6549f8c95694b27 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 20 Oct 2023 16:25:08 +0800 Subject: [PATCH 5/5] feat(utils/gputest.py): fix lint error --- internlm/utils/gputest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 52d9638..bf4cf1c 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50): # # import time # # time.sleep(10) # print(e, "rank = ", gpc.get_global_rank(), flush=True) - # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") # do empty_cache after the bench torch.cuda.empty_cache()
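Taken together, patches 1 through 3 move gradient synchronization for the intern/FSTP mode onto asynchronous reduce-scatter: the backward of each FSTP linear launches a reduce-scatter into a pooled buffer and stores the handle keyed by the parameter, and the optimizer's accumulation bucket (_accum_grads_store_in_bucket together with _wait_reduce_scatter_and_accumulate_grads) later waits on those handles and adds the shards into the local gradients. A condensed sketch of that wait-and-accumulate pattern is given below; AccumGradBucket, async_reduce_scatter and _reduce_scatter_key are illustrative stand-ins for the BucketStore and handler machinery, and .grad is assumed to already hold shard-shaped gradients, as in the patch.

import torch
import torch.distributed as dist


def async_reduce_scatter(full_grad, group):
    # Start an async reduce-scatter of a locally computed full gradient; return (shard, handle).
    world_size = dist.get_world_size(group)
    full_grad = full_grad.contiguous()
    shard = torch.empty(full_grad.shape[0] // world_size, *full_grad.shape[1:],
                        dtype=full_grad.dtype, device=full_grad.device)
    handle = dist.reduce_scatter_tensor(shard, full_grad, group=group, async_op=True)
    return shard, handle


class AccumGradBucket:
    # Queue parameters whose reduce-scatter is still in flight; once enough
    # elements are queued, wait on the handles and accumulate the shards.

    def __init__(self, handlers, threshold=512 * 1024 * 1024):
        self.handlers = handlers      # key -> (comm handle, gradient shard)
        self.threshold = threshold    # stands in for self._reduce_bucket_size
        self.params, self.numel = [], 0

    def add(self, param):
        if self.numel >= self.threshold:
            self.flush()
        self.params.append(param)
        self.numel += param.numel()

    def flush(self):
        for p in self.params:
            key = getattr(p, "_reduce_scatter_key", None)
            if key is None:
                continue
            handle, grad_shard = self.handlers[key]
            handle.wait()                  # reduce-scatter finished
            p.grad.add_(grad_shard)        # p.grad is shard-shaped here
            self.handlers[key] = None      # allow the pooled buffer to be reused
        self.params, self.numel = [], 0

In the optimizer, add() plays the role of the per-parameter backward hook check, and a final flush() corresponds to draining _accum_grad_buckets before computing gradient norms and calling step().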