From 189a313da6a6b6710f07f7e5e13cacb56eeb7256 Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Mon, 9 Oct 2023 17:26:20 +0800
Subject: [PATCH] support fstp and refactor code

---
 configs/7B_sft.py                          |  10 +--
 internlm/core/context/parallel_context.py  |   3 +-
 internlm/initialize/launch.py              |   6 ++
 internlm/model/linear.py                   |  91 +++++++------
 internlm/model/modeling_internlm.py        |  29 +++---
 internlm/model/multi_head_attention.py     | 104 ++++++++--------
 internlm/utils/evaluation.py               |   5 +-
 7 files changed, 104 insertions(+), 144 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 5e3e0c9..6758167 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 4
 VOCAB_SIZE = 103168

 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -55,7 +55,7 @@ data = dict(
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
-    valid_every=1000,
+    valid_every=10,
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",
@@ -64,7 +64,7 @@ data = dict(
     min_length=50,
     # train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
-    empty_cache_and_diag_interval=10,
+    empty_cache_and_diag_interval=100,
     diag_outlier_ratio=1.1,
 )

@@ -135,7 +135,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
@@ -155,7 +155,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=-1,
-    tensor=2,
+    tensor=dict(size=2, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 7f3e415..da6a0d7 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -568,7 +568,8 @@ class ParallelContext(metaclass=SingletonMeta):
         # during model construction), this is because the random state will be different in different tensor parallel
         # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
         # additional random operations during the RowParallelLinear module building process.
-        set_mode(ParallelMode.DUMMY)
+        # set_mode(ParallelMode.DUMMY)
+        set_mode(ParallelMode.TENSOR)

         seeds = get_seeds()
         seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 660cc55..895779e 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -279,6 +279,12 @@ def args_sanity_check():
         assert not (
             gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"
+
+    if gpc.config.parallel["tensor"].get("mode", None) is None:
+        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
+
+    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
+        assert gpc.config.parallel.sequence_parallel is True, "sequence_parallel must be True when tp_mode is fstp"

     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 5ea0e80..60a3d27 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -4,44 +4,20 @@
 from typing import Optional

 import torch
+import torch.nn.functional as F
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw
 from torch import Tensor
 from torch import nn
 from torch.cuda.amp import custom_bwd, custom_fwd
+# import fused_dense_cuda  # from apex
+import fused_dense_lib as fused_dense_cuda
+
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.utils import Silu, fused_dense_func_torch
-from typing import Optional
-from functools import partial
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-# import fused_dense_cuda  # from apex
-import fused_dense_lib as fused_dense_cuda
-
-from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
-# reduce_scatter_raw
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
-
-def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
-    world_size = torch.distributed.get_world_size(process_group)
-    assert input_.shape[0] % world_size == 0
-    output = torch.empty(
-        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
-    )
-    handle = torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), op=op, group=process_group, async_op=async_op
-    )
-    return output, handle


 class ScaleColumnParallelLinear(nn.Linear):
     """
@@ -231,7 +207,7 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out

-class FusedDenseFunc_fsdp(torch.autograd.Function):
+class FSDPFusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
@@ -243,21 +219,26 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
-        x = x.contiguous()
-        total_x = x
+        total_x = x.contiguous()

-        # do all_gather for weight and bias before actual computation
-        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-        if bias is not None:
-            total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-            handle_bias.wait()
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            # do all_gather for weight and bias before actual computation
+            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            if bias is not None:
+                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                handle_bias.wait()
+            else:
+                total_bias = bias
+            handle_weight.wait()
         else:
+            total_weight = weight
             total_bias = bias

         if torch.is_autocast_enabled():
             total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
             total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
-        handle_weight.wait()
+        total_weight = total_weight.contiguous()

         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
@@ -289,9 +270,13 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-        # do all-gather for weight before backward
-        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-        handle_weight.wait()
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            # do all-gather for weight before backward
+            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            handle_weight.wait()
+        else:
+            total_weight = weight

         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
@@ -300,32 +285,24 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
-            # if process_group is not None:
-            #     import pdb; pdb.set_trace()
-            #     grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, async_op=True)
-            #     grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True)
-
         else:
             grad_input = None
-        # import pdb; pdb.set_trace()
+
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
-            grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
-            if grad_bias is not None:
-                grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                handle_grad_bias.wait()
-            handle_grad_weight.wait()
-
+            if world_size > 1:
+                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                if grad_bias is not None:
+                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    handle_grad_bias.wait()
+                handle_grad_weight.wait()
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
-        # if process_group is not None and ctx.needs_input_grad[0]:
-        #     handle_grad_input.wait()
-        #     import pdb; pdb.set_trace()
         return grad_input, grad_weight, grad_bias, None, None, None
@@ -334,7 +311,7 @@ def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = No
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FusedDenseFunc_fsdp.apply(x, weight, bias, return_residual, process_group)
+        return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -426,5 +403,5 @@ class FSDPFeedForward(nn.Module):
     def forward(self, x):
         w1_o = self.w1(x)
         w2_o = self.w2(x)
-        out = self.w3(Silu(w1_o, w2_o))
+        out = self.w3(F.silu(w1_o) * w2_o)
         return out
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 8ac8c58..47d706f 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -74,6 +74,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        tp_mode: str = 'origin_tp',
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -98,6 +99,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
+            tp_mode=tp_mode,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -109,16 +111,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            # self.mlp = FeedForward(
-            #     hidden_size,
-            #     int(hidden_size * mlp_ratio),
-            #     out_features=hidden_size,
-            #     process_group=gpc.get_group(ParallelMode.TENSOR),
-            #     bias=False,
-            #     device=device,
-            #     dtype=dtype,
-            # )
-            self.mlp = FSDPFeedForward(
+            mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
+            self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
                 out_features=hidden_size,
@@ -179,6 +173,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 else:
                     normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

+
     def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
         if self.checkpoint and self.training:
             return activation_checkpoint(
@@ -300,12 +295,12 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()

         checkpoint_layer_num = int(num_layers * checkpoint)
+        self.tp_mode = gpc.config.parallel["tensor"]["mode"]

         if is_reward:
             head_cls = RewardModelLinear
         else:
-            # head_cls = ScaleColumnParallelLinear
-            head_cls = FSDPScaleLinear
+            head_cls = ScaleColumnParallelLinear
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -346,6 +341,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
+                    tp_mode=self.tp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -391,7 +387,8 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-            if gpc.config.parallel.sequence_parallel:
+            # if the tensor parallel mode is 'fstp', the indexes should also be split along the sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
@@ -408,8 +405,12 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
+            # if hidden_states.ndim == 3:
+            #     import pdb; pdb.set_trace()
+            #     hidden_states = self.head(hidden_states, dim=1)
+            # else:
+            #     hidden_states = self.head(hidden_states)
             hidden_states = self.head(hidden_states)
-            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=0)

         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index e6d0a29..8f7a064 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -57,49 +57,29 @@ class DistributedAttention(torch.nn.Module):
     Arguments:
         local_attention (Module): local attention with q,k,v
         sequence_process_group (ProcessGroup): sequence parallel process group
-        scatter_idx (int): scatter_idx for all2all comm
-        gather_idx (int): gather_idx for all2all comm
+        first_scatter_idx (int): scatter_idx for the first all2all comm
+        first_gather_idx (int): gather_idx for the first all2all comm
+        second_scatter_idx (int): scatter_idx for the second all2all comm
+        second_gather_idx (int): gather_idx for the second all2all comm
     """

     def __init__(
         self,
         local_attention: Module,
         sequence_process_group: dist.ProcessGroup,
-        scatter_idx: int = 2,
-        gather_idx: int = 0,
+        first_scatter_idx: int = 2,
+        first_gather_idx: int = 0,
+        second_scatter_idx: int = 0,
+        second_gather_idx: int = 1,
     ) -> None:
         super(DistributedAttention, self).__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
-        self.scatter_idx = scatter_idx
-        self.gather_idx = gather_idx
-
-    # def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
-    #     """ forward
-
-    #     Arguments:
-    #         query (Tensor): query input to the layer
-    #         key (Tensor): key input to the layer
-    #         value (Tensor): value input to the layer
-    #         args: other args
-
-    #     Returns:
-    #         * output (Tensor): context output
-    #     """
-    #     # TODO Merge three alltoall calls into one
-    #     #in shape : e.g., [s/p:h:]
-    #     query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-    #     key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-    #     value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-    #     #out shape : e.g., [s:h/p:]
-    #     context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
-
-    #     output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
-
-    #     #out e.g., [s/p::h]
-    #     return output
+        self.first_scatter_idx = first_scatter_idx
+        self.first_gather_idx = first_gather_idx
+        self.second_scatter_idx = second_scatter_idx
+        self.second_gather_idx = second_gather_idx

     def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
         """ forward
@@ -114,15 +94,21 @@ class DistributedAttention(torch.nn.Module):
             * output (Tensor): context output
         """
         # TODO Merge three alltoall calls into one
-        #in shape : e.g., [s/p:h:]
-        qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
-        # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-        # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-        #out shape : e.g., [s:h/p:]
-        context_layer = self.local_attn(qkv, **kwargs)
-
-        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)
+        if qkv.ndim == 5:
+            # in shape: [seq/tp_size, 3, head, head_dim]
+            qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
+            # out shape : [seq, head/tp_size, head_dim]
+            context_layer = self.local_attn(qkv, **kwargs)
+            # in shape: [seq, head/tp_size, head_dim]
+            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
+        else:
+            # in shape: [seq/tp_size, 3, head, head_dim]
+            qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
+            # out shape : [seq, head/tp_size, head_dim]
+            context_layer = self.local_attn(qkv, **kwargs)
+            # in shape: [seq, head/tp_size, head_dim]
+            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)

         #out e.g., [s/p::h]
         return output
@@ -171,6 +157,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        tp_mode: str = 'origin_tp',
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -198,16 +185,8 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        # self.Wqkv = ColumnParallelLinearTorch(
-        #     embed_dim,
-        #     3 * embed_dim,
-        #     process_group,
-        #     bias=True,
-        #     sequence_parallel=gpc.config.parallel.sequence_parallel,
-        #     **factory_kwargs,
-        # )  # according to https://spaces.ac.cn/archives/9577
-
-        self.Wqkv = FSDPLinear(
+        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
             process_group,
@@ -222,25 +201,20 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-
-        self.inner_attn_sp = DistributedAttention(self.inner_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
-        self.inner_cross_attn_sp = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
+        if tp_mode == 'fstp':
+            self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
+            self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        # self.out_proj = RowParallelLinearTorch(
-        #     embed_dim,
-        #     embed_dim,
-        #     process_group,
-        #     sequence_parallel=gpc.config.parallel.sequence_parallel,
-        #     **factory_kwargs,
-        # )
-        self.out_proj = FSDPLinear(
+        out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
             process_group,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
+
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -343,11 +317,9 @@ class MHA(nn.Module):
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
-                    # context = self.inner_attn(qkv, **kwargs).to(x.dtype)
-                    context = self.inner_attn_sp(qkv, **kwargs).to(x.dtype)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
             else:
-                # context = self.inner_attn(qkv, **kwargs)
-                context = self.inner_attn_sp(qkv, **kwargs)
+                context = self.inner_attn(qkv, **kwargs)

         else:
             raise RuntimeError("Not support this right now")
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 6a55fa5..2a11a47 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -54,7 +54,10 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        gpc.config.parallel.sequence_parallel = False
+        if gpc.config.parallel["tensor"]["mode"] == 'fstp':
+            gpc.config.parallel.sequence_parallel = True
+        else:
+            gpc.config.parallel.sequence_parallel = False
         yield
     finally:
         gpc.config.parallel.sequence_parallel = prev_mode
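
Note for reviewers: the sketch below only mirrors the communication pattern that FSDPFusedDenseFunc implements in this patch (all-gather the sharded weight before the local matmul in forward and backward, reduce-scatter the weight gradient back to each rank), written against plain torch.distributed collectives instead of the flash_attn helpers and fused_dense_lib kernels the patch uses. It is an illustration, not the patch's implementation: the name ShardedLinearSketch is hypothetical, bias and autocast handling are omitted, and it assumes torch.distributed is already initialized with `group` set to the tensor-parallel process group.

import torch
import torch.distributed as dist


class ShardedLinearSketch(torch.autograd.Function):
    """Illustrative only: `weight_shard` is the local slice of the weight, sharded along dim 0."""

    @staticmethod
    def forward(ctx, x, weight_shard, group):
        world_size = dist.get_world_size(group)
        # Recover the full weight from its shards before the local matmul.
        full_weight = torch.empty(
            weight_shard.shape[0] * world_size, *weight_shard.shape[1:],
            dtype=weight_shard.dtype, device=weight_shard.device,
        )
        dist.all_gather_into_tensor(full_weight, weight_shard.contiguous(), group=group)
        ctx.save_for_backward(x, weight_shard)
        ctx.group = group
        return x @ full_weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight_shard = ctx.saved_tensors
        group = ctx.group
        world_size = dist.get_world_size(group)
        # The full weight is needed again to form the input gradient.
        full_weight = torch.empty(
            weight_shard.shape[0] * world_size, *weight_shard.shape[1:],
            dtype=weight_shard.dtype, device=weight_shard.device,
        )
        dist.all_gather_into_tensor(full_weight, weight_shard.contiguous(), group=group)
        grad_input = grad_output @ full_weight
        # Compute the full weight gradient, then reduce-scatter it so each rank keeps only its shard.
        grad_full_weight = grad_output.reshape(-1, grad_output.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
        grad_weight_shard = torch.empty_like(weight_shard)
        dist.reduce_scatter_tensor(grad_weight_shard, grad_full_weight.contiguous(), group=group)
        return grad_input, grad_weight_shard, None

The FSDPFusedDenseFunc added above additionally overlaps the all-gathers with computation via async_op=True, handles bias and autocast, and skips the collectives entirely when the tensor-parallel world size is 1.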