diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 5bd2b73..8c224bf 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -279,12 +279,14 @@ def args_sanity_check():
         assert not (
             gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"
-
+
     if gpc.config.parallel["tensor"].get("mode", None) is None:
         gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-
-    if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
-        assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
+
+    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
+        assert (
+            gpc.config.parallel.sequence_parallel is True
+        ), "when the tp_mode is fstp, the sequence_parallel should be True."

     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
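Reviewer note: the sanity check above assumes `parallel.tensor` is a dict carrying a `mode` key and that `sequence_parallel` is enabled whenever the mode is "fstp". A minimal, hypothetical config fragment illustrating the expected shape (only `tensor["mode"]` and `sequence_parallel` are what this patch actually inspects; the remaining fields simply mirror the stock configs and are not prescribed by this diff):

```python
# Hypothetical config fragment, for illustration only: this patch only reads
# parallel["tensor"]["mode"] and parallel.sequence_parallel; the other fields
# mirror the standard InternLM configs and are not introduced by this diff.
parallel = dict(
    zero1=8,
    tensor=dict(size=8, mode="fstp"),  # mode falls back to "origin_tp" when omitted
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,  # args_sanity_check() asserts this is True for fstp
)
```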
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 4075e9e..8e23871 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -9,10 +9,9 @@
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
-
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fused_dense_func_torch, fsdp_fused_dense_func
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -124,7 +123,12 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
         # If not, then the input is already gathered.
         return fused_dense_func_torch(
-            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, gather_dim=gather_dim,
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+            gather_dim=gather_dim,
         )
@@ -204,31 +208,13 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out

-class FSDPLinear(ColumnParallelLinear):
-
+
+class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)


-class FSDPScaleLinear(ScaleColumnParallelLinear):
-
-    def forward(self, input):  # pylint: disable=W0622
-        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
-        # we do an all_gather of x before doing the matmul.
-        # If not, then the input is already gathered.
-        if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
-        else:
-            weight = self.weight
-        return fsdp_fused_dense_func(
-            input,
-            weight,
-            self.bias,
-            process_group=self.process_group,
-        )
-
-
-class FSDPFeedForward(nn.Module):
+class FSTPFeedForward(nn.Module):
     """
     FeedForward.

@@ -259,7 +245,7 @@ class FSDPFeedForward(nn.Module):
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = FSDPLinear(
+        self.w1 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -268,7 +254,7 @@ class FSDPFeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = FSDPLinear(
+        self.w2 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -277,7 +263,7 @@ class FSDPFeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w3 = FSDPLinear(
+        self.w3 = FSTPLinear(
             hidden_features,
             out_features,
             process_group,
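`FSTPFeedForward` keeps exactly the same SwiGLU dataflow as `FeedForward` (`out = self.w3(Silu(w1_o, w2_o))` in the context above); only the linear class being instantiated changes. A single-device sketch of that dataflow, with plain `nn.Linear` standing in for the parallel linear classes:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwigluFFNSketch(nn.Module):
    """Single-device sketch of the w1/w2/w3 wiring shared by FeedForward and FSTPFeedForward.

    Plain nn.Linear stands in for ColumnParallelLinearTorch / FSTPLinear; the point is only
    the dataflow out = w3(silu(w1(x)) * w2(x)), i.e. `self.w3(Silu(w1_o, w2_o))` above.
    """

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=False)
        self.w2 = nn.Linear(in_features, hidden_features, bias=False)
        self.w3 = nn.Linear(hidden_features, in_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
```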
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 56a8efa..b8d7e60 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -15,13 +15,16 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
+    FSTPFeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSDPScaleLinear,
-    FSDPFeedForward,
 )
 from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward
+from internlm.model.utils import (
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+    try_import_RMSNorm,
+)
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -74,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -111,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
+            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -173,7 +176,6 @@ class PackedFlashBaseLayer1D(nn.Module):
                     else:
                         normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

-
     def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
         if self.checkpoint and self.training:
             return activation_checkpoint(
@@ -341,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode = self.tp_mode,
+                    tp_mode=self.tp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -388,9 +390,9 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
             # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
-
+
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

         for _, block in enumerate(self.blocks):
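The `indexes` handling above is needed because, under the fstp mode, every block consumes sequence-sharded activations, so the per-token position ids have to be sharded the same way. A forward-path sketch of what `split_forward_gather_backward` is relied on to do here (the real helper is an autograd function that additionally all-gathers the gradient across the group in backward):

```python
import torch
import torch.distributed as dist


def local_sequence_chunk(tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = 0) -> torch.Tensor:
    """Forward-path sketch of split_forward_gather_backward: keep only this rank's
    contiguous slice along `dim` (assumes the size along `dim` divides evenly)."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return tensor
    rank = dist.get_rank(group)
    return tensor.chunk(world_size, dim=dim)[rank].contiguous()
```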
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 8f7a064..287a0e2 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -2,9 +2,10 @@
 # -*- encoding: utf-8 -*-

 import warnings
-from typing import Optional
+from typing import Any, Optional, Tuple

 import torch
+import torch.distributed as dist
 from einops import rearrange
 from flash_attn.modules.mha import (
     CrossAttention,
@@ -13,26 +14,25 @@ from flash_attn.modules.mha import (
     SelfAttention,
     _update_kv_cache,
 )
-from torch import nn
+from torch import Tensor, nn
+from torch.nn import Module

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
-from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
-
-import torch
-
-from typing import Any, Tuple
-from torch import Tensor
-from torch.nn import Module
-
-import torch.distributed as dist
+from internlm.model.linear import (
+    ColumnParallelLinearTorch,
+    FSTPLinear,
+    RowParallelLinearTorch,
+)


+# adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class _SeqAllToAll(torch.autograd.Function):
+    "sequence alltoall"

     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
+    def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:

         ctx.group = group
         ctx.scatter_idx = scatter_idx
@@ -40,7 +40,7 @@ class _SeqAllToAll(torch.autograd.Function):

         seq_world_size = dist.get_world_size(group)

-        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
+        input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
         output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
         # TODO Use all_to_all_single instead
         dist.all_to_all(output_list, input_list, group=group)
@@ -51,6 +51,7 @@
         return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)


+# adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class DistributedAttention(torch.nn.Module):
     """Initialization.

@@ -73,16 +74,16 @@ class DistributedAttention(torch.nn.Module):
         second_gather_idx: int = 1,
     ) -> None:

-        super(DistributedAttention, self).__init__()
+        super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
         self.first_scatter_idx = first_scatter_idx
         self.first_gather_idx = first_gather_idx
         self.second_scatter_idx = second_scatter_idx
         self.second_gather_idx = second_gather_idx
-
+
     def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
-        """ forward
+        """forward

         Arguments:
             query (Tensor): query input to the layer
@@ -93,24 +94,25 @@ class DistributedAttention(torch.nn.Module):
         Returns:
             * output (Tensor): context output
         """
-        # TODO Merge three alltoall calls into one
+        # Evaluation
         if qkv.ndim == 5:
-            # in shape: [seq/tp_size, 3, head, head_dim]
+            # in shape: [batch, seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [batch, seq, head/tp_size, head_dim]
             context_layer = self.local_attn(qkv, **kwargs)
-            # in shape: [seq, head/tp_size, head_dim]
-            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
-        else:
-
+            # in shape: [batch, seq, head/tp_size, head_dim]
+            output = _SeqAllToAll.apply(
+                self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
+            )
+        else:  # training
             # in shape: [seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [seq, head/tp_size, head_dim]
             context_layer = self.local_attn(qkv, **kwargs)
             # in shape: [seq, head/tp_size, head_dim]
             output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
-        #out e.g., [s/p::h]
+        # out e.g., [s/p::h]

         return output
@@ -157,7 +159,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -185,7 +187,7 @@ class MHA(nn.Module):
                 self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -201,12 +203,12 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == 'fstp':
+        if tp_mode == "fstp":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
@@ -214,7 +216,6 @@ class MHA(nn.Module):
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
-
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -311,7 +312,6 @@ class MHA(nn.Module):
         qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
         qkv = self.rotary_emb(qkv, **kwargs)
         kwargs.pop("indexes")
-
         if inference_params is None:
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
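`DistributedAttention` is the DeepSpeed-Ulysses-style wrapper: before attention it all-to-alls the packed qkv so that each rank goes from holding a sequence shard of all heads to holding the full sequence of a head shard, runs the unmodified local attention, and then applies the inverse all-to-all. A forward-only sketch of the first redistribution on the training path, assuming the constructor defaults scatter over the head dimension and gather over the sequence dimension (the exact indices are set in `__init__` and are not fully visible in this hunk):

```python
import torch
import torch.distributed as dist


def seq_to_head_redistribute(qkv: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """Sketch of the first all-to-all in DistributedAttention (training path).

    In : [seq_len / P, 3, num_heads, head_dim]  -- each rank holds a sequence shard of all heads
    Out: [seq_len, 3, num_heads / P, head_dim]  -- each rank holds the full sequence of a head shard
    where P is the size of the sequence process group; num_heads must be divisible by P.
    """
    world_size = dist.get_world_size(group)
    # split the head dimension into one chunk per destination rank (scatter side)
    input_list = [t.contiguous() for t in torch.tensor_split(qkv, world_size, dim=2)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    # concatenate the received sequence shards along the sequence dimension (gather side)
    return torch.cat(output_list, dim=0).contiguous()
```

The second `_SeqAllToAll.apply` after the local attention performs the inverse mapping, so the output layout matches what `out_proj` expects.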
@@ -171,7 +168,7 @@ class FusedDenseFunc(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -184,12 +181,12 @@ class FusedDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
         gather_dim = ctx.gather_dim
-
+
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             if process_group is not None and sequence_parallel:
@@ -197,7 +194,7 @@ class FusedDenseFunc(torch.autograd.Function):
             else:
                 total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -206,8 +203,7 @@ class FusedDenseFunc(torch.autograd.Function):
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
                 reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -282,7 +278,8 @@ class FusedDenseFuncTorch(FusedDenseFunc):
         return grad_input, grad_weight, grad_bias, None, None, None, None


-class FSDPFusedDenseFunc(torch.autograd.Function):
+class FSTPFusedDenseFunc(torch.autograd.Function):
+    "FSTP fused dense function"

     @staticmethod
     @custom_fwd
@@ -295,7 +292,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         total_x = x.contiguous()
-
+
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
@@ -313,13 +310,13 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
         if torch.is_autocast_enabled():
             total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
             total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
-
+
         total_weight = total_weight.contiguous()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -332,19 +329,19 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
-
+
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
@@ -352,13 +349,12 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
             handle_weight.wait()
         else:
             total_weight = weight
-
+
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, total_weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, total_weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         else:
             grad_input = None
@@ -372,7 +368,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
             if world_size > 1:
                 grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 if grad_bias is not None:
-                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                     handle_grad_bias.wait()
                 handle_grad_weight.wait()
             else:
@@ -399,12 +395,14 @@ def fused_dense_func_torch(
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)


-def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                          return_residual: bool = False, process_group = None):
-    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
-                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+def fstp_fused_dense_func(
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
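The FSTP path inverts the usual tensor-parallel layout: activations stay sequence-sharded on every rank, while each rank stores only a 1/P slice of each weight, all-gathers it right before the matmul, and reduce-scatters the weight/bias gradients in backward (the same idea ZeRO-3/FSDP applies across data-parallel ranks, applied here to the tensor-parallel group). A forward-only sketch under the ColumnParallelLinear convention of sharding the output dimension; the real `FSTPFusedDenseFunc` uses the async `all_gather_raw`/`reduce_scatter_raw` helpers and the fused wgrad kernel, which are omitted here:

```python
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F


def fstp_linear_forward_sketch(
    x: torch.Tensor,                     # [seq_len / P, ..., in_features]: sequence-sharded input stays local
    weight_shard: torch.Tensor,          # [out_features / P, in_features]: this rank's slice of the weight
    bias_shard: Optional[torch.Tensor],  # [out_features / P] or None
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """Forward-only sketch of FSTPFusedDenseFunc: gather the sharded weight (and bias),
    then run a plain linear on the local sequence shard. Backward (not shown) would
    reduce-scatter the weight/bias gradients back to their owning ranks."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return F.linear(x, weight_shard, bias_shard)

    total_weight = torch.empty(
        (weight_shard.shape[0] * world_size, weight_shard.shape[1]),
        dtype=weight_shard.dtype,
        device=weight_shard.device,
    )
    dist.all_gather_into_tensor(total_weight, weight_shard.contiguous(), group=group)

    total_bias = None
    if bias_shard is not None:
        total_bias = torch.empty(
            bias_shard.shape[0] * world_size, dtype=bias_shard.dtype, device=bias_shard.device
        )
        dist.all_gather_into_tensor(total_bias, bias_shard.contiguous(), group=group)

    return F.linear(x, total_weight, total_bias)
```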
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 148d19d..968a1db 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == 'fstp':
+        if gpc.config.parallel["tensor"]["mode"] == "fstp":
             gpc.config.parallel.sequence_parallel = True
         else:
             gpc.config.parallel.sequence_parallel = False
@@ -106,10 +106,14 @@ def evaluate_on_val_dls(
                         total_val_bsz = len(batch[1])
                         assert total_val_bsz % data_cfg.micro_bsz == 0
                         num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                        if gpc.config.parallel['tensor']['mode'] == 'fstp':
+                        if gpc.config.parallel["tensor"]["mode"] == "fstp":
                             sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                             tensor_shape = torch.Size(
-                                [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
+                                [
+                                    data_cfg.micro_bsz,
+                                    batch[0]["input_ids"].shape[1] // sequence_world_size,
+                                    gpc.config.HIDDEN_SIZE,
+                                ]
                             )
                         else:
                             tensor_shape = torch.Size(