From 10aa63f0e112eac79e5fcafc6dc75961c5b76403 Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Sat, 7 Oct 2023 14:03:47 +0800 Subject: [PATCH] support optimized sp --- configs/7B_sft.py | 6 +- internlm/model/linear.py | 219 ++++++++++++++++++++++++- internlm/model/modeling_internlm.py | 22 ++- internlm/model/multi_head_attention.py | 140 +++++++++++++++- train.py | 21 +-- 5 files changed, 378 insertions(+), 30 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 25a98bf..a23edce 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -146,10 +146,10 @@ pipeline parallel (dict): tensor parallel: tensor parallel size, usually the number of GPUs per node. """ parallel = dict( - zero1=8, - tensor=1, + zero1=-1, + tensor=2, pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + sequence_parallel=True, ) cudnn_deterministic = False diff --git a/internlm/model/linear.py b/internlm/model/linear.py index d18308a..5ee1af9 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -5,13 +5,32 @@ from typing import Optional import torch from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear -from flash_attn.utils.distributed import all_reduce, reduce_scatter +from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw +from torch import Tensor from torch import nn +from torch.cuda.amp import custom_bwd, custom_fwd from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.utils import Silu, fused_dense_func_torch +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.cuda.amp import custom_bwd, custom_fwd + +# import fused_dense_cuda # from apex +import fused_dense_lib as fused_dense_cuda + +from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd +from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw +from flash_attn.utils.distributed import reduce_scatter, all_reduce + class ScaleColumnParallelLinear(nn.Linear): """ @@ -200,3 +219,201 @@ class FeedForward(nn.Module): w2_o = self.w2(x) out = self.w3(Silu(w1_o, w2_o)) return out + +class FusedDenseFunc_fsdp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None): + + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + total_x = x + + # do all_gather for weight and bias before actual computation + total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + if bias is not None: + total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) + handle_bias.wait() + else: + total_bias = bias + + if torch.is_autocast_enabled(): + total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) + total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + handle_weight.wait() + total_weight = total_weight.contiguous() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *total_weight.shape) > 65535 * 32: + raise RuntimeError('fused_dense only supports matrix dims <= 2M') + output = F.linear(total_x, total_weight, total_bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + grad_input, = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + total_x = x + else: + weight, = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + + # do all-gather for weight before backward + weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + handle_weight.wait() + + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + # if process_group is not None: + # import pdb; pdb.set_trace() + # grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, async_op=True) + # grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True) + + else: + grad_input = None + # import pdb; pdb.set_trace() + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) + if grad_bias is not None: + grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) + handle_grad_bias.wait() + handle_grad_weight.wait() + + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + # if process_group is not None and ctx.needs_input_grad[0]: + # handle_grad_input.wait() + # import pdb; pdb.set_trace() + return grad_input, grad_weight, grad_bias, None, None, None + + +def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, + return_residual: bool = False, process_group = None): + dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] + or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + return FusedDenseFunc_fsdp.apply(x, weight, bias, return_residual, process_group) + else: + assert process_group is None + out = F.linear(x, weight, bias) + return out if not return_residual else (out, x) + +class FSDPLinear(ColumnParallelLinear): + + def forward(self, x): + return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group) + + +class FSDPScaleLinear(ScaleColumnParallelLinear): + + def forward(self, input): # pylint: disable=W0622 + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + if self.weight_scale != 1: + weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() + else: + weight = self.weight + return fsdp_fused_dense_func( + input, + weight, + self.bias, + process_group=self.process_group, + ) + + +class FSDPFeedForward(nn.Module): + """ + FeedForward. + + Args: + in_features (int): size of each input sample + hidden_features (int): size of hidden state of FFN + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int = None, + process_group: Optional[torch.distributed.ProcessGroup] = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + ): + super().__init__() + + hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) + + self.w1 = FSDPLinear( + in_features, + hidden_features, + process_group, + bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + self.w2 = FSDPLinear( + in_features, + hidden_features, + process_group, + bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + self.w3 = FSDPLinear( + hidden_features, + out_features, + process_group, + bias=bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + + def forward(self, x): + w1_o = self.w1(x) + w2_o = self.w2(x) + out = self.w3(Silu(w1_o, w2_o)) + return out diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 2856a78..8ac8c58 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -17,9 +17,11 @@ from internlm.model.linear import ( FeedForward, RewardModelLinear, ScaleColumnParallelLinear, + FSDPScaleLinear, + FSDPFeedForward, ) from internlm.model.multi_head_attention import MHA -from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm +from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward from internlm.solver.pipeline_utils import partition_uniform from internlm.utils.checkpoint import activation_checkpoint from internlm.utils.common import filter_kwargs @@ -107,7 +109,16 @@ class PackedFlashBaseLayer1D(nn.Module): self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) if use_swiglu: - self.mlp = FeedForward( + # self.mlp = FeedForward( + # hidden_size, + # int(hidden_size * mlp_ratio), + # out_features=hidden_size, + # process_group=gpc.get_group(ParallelMode.TENSOR), + # bias=False, + # device=device, + # dtype=dtype, + # ) + self.mlp = FSDPFeedForward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, @@ -293,7 +304,8 @@ class PackedFlashInternLm1D(nn.Module): if is_reward: head_cls = RewardModelLinear else: - head_cls = ScaleColumnParallelLinear + # head_cls = ScaleColumnParallelLinear + head_cls = FSDPScaleLinear if first: if embed_split_hidden: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -379,6 +391,9 @@ class PackedFlashInternLm1D(nn.Module): assert len(indexes) == 1 # The indexes are used to indicate the actual position IDs of each token in the packed input. indexes = indexes[0] + if gpc.config.parallel.sequence_parallel: + indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.blocks): @@ -394,6 +409,7 @@ class PackedFlashInternLm1D(nn.Module): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): hidden_states = self.head(hidden_states) + hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=0) if not self.parallel_output: hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index e4008e1..abb9f19 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -18,7 +18,114 @@ from torch import nn from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding -from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch +from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear + +import torch + +from typing import Any, Tuple +from torch import Tensor +from torch.nn import Module + +import torch.distributed as dist + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: + + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + seq_world_size = dist.get_world_size(group) + + input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + # TODO Use all_to_all_single instead + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_idx).contiguous() + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + + +class DistributedAttention(torch.nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + local_attention: Module, + sequence_process_group: dist.ProcessGroup, + scatter_idx: int = 2, + gather_idx: int = 0, + ) -> None: + + super(DistributedAttention, self).__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + + # def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor: + # """ forward + + # Arguments: + # query (Tensor): query input to the layer + # key (Tensor): key input to the layer + # value (Tensor): value input to the layer + # args: other args + + # Returns: + # * output (Tensor): context output + # """ + # # TODO Merge three alltoall calls into one + # #in shape : e.g., [s/p:h:] + # query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) + # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) + # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + # #out shape : e.g., [s:h/p:] + # context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) + + # output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + + # #out e.g., [s/p::h] + # return output + + def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor: + """ forward + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args + + Returns: + * output (Tensor): context output + """ + # TODO Merge three alltoall calls into one + #in shape : e.g., [s/p:h:] + qkv = _SeqAllToAll.apply(self.spg, qkv, self.scatter_idx, self.gather_idx) + # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) + # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + #out shape : e.g., [s:h/p:] + context_layer = self.local_attn(qkv, **kwargs) + + output = _SeqAllToAll.apply(self.spg, context_layer, 0, 2) + + #out e.g., [s/p::h] + return output class MHA(nn.Module): @@ -91,7 +198,16 @@ class MHA(nn.Module): self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) # notice here should change bias=True - self.Wqkv = ColumnParallelLinearTorch( + # self.Wqkv = ColumnParallelLinearTorch( + # embed_dim, + # 3 * embed_dim, + # process_group, + # bias=True, + # sequence_parallel=gpc.config.parallel.sequence_parallel, + # **factory_kwargs, + # ) # according to https://spaces.ac.cn/archives/9577 + + self.Wqkv = FSDPLinear( embed_dim, 3 * embed_dim, process_group, @@ -106,9 +222,19 @@ class MHA(nn.Module): self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) + + self.inner_attn_sp = DistributedAttention(self.inner_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0) + self.inner_cross_attn_sp = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0) # output projection always have the bias (for now) - self.out_proj = RowParallelLinearTorch( + # self.out_proj = RowParallelLinearTorch( + # embed_dim, + # embed_dim, + # process_group, + # sequence_parallel=gpc.config.parallel.sequence_parallel, + # **factory_kwargs, + # ) + self.out_proj = FSDPLinear( embed_dim, embed_dim, process_group, @@ -211,15 +337,17 @@ class MHA(nn.Module): qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d qkv = self.rotary_emb(qkv, **kwargs) kwargs.pop("indexes") - + if inference_params is None: if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: with torch.cuda.amp.autocast(dtype=torch.bfloat16): if qkv.dtype not in [torch.float16, torch.bfloat16]: qkv = qkv.to(torch.bfloat16) - context = self.inner_attn(qkv, **kwargs).to(x.dtype) + # context = self.inner_attn(qkv, **kwargs).to(x.dtype) + context = self.inner_attn_sp(qkv, **kwargs).to(x.dtype) else: - context = self.inner_attn(qkv, **kwargs) + # context = self.inner_attn(qkv, **kwargs) + context = self.inner_attn_sp(qkv, **kwargs) else: raise RuntimeError("Not support this right now") diff --git a/train.py b/train.py index 139bac1..9bc4bd7 100644 --- a/train.py +++ b/train.py @@ -110,7 +110,6 @@ def main(args): # initialize and resume train state train_state = TrainState(gpc.config, train_dl.batch_sampler) - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) ckpt_manager = CheckpointManager( @@ -170,6 +169,7 @@ def main(args): beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + # initialize simple memory profiler if args.profiling: @@ -219,21 +219,9 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - moe_loss = None - if hasattr(gpc.config.model, "num_experts"): - _, _, loss, moe_loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - else: - _, _, loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) + _, _, loss = trainer.execute_schedule( + batch, forward_only=False, return_loss=True, return_output_label=False + ) timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm) @@ -266,7 +254,6 @@ def main(args): trainer=trainer, start_time=start_time, loss=loss, - moe_loss=moe_loss, grad_norm=grad_norm_groups, metric=metric, update_panel=uniscale_logger is not None,