#!/usr/bin/env python # -*- encoding: utf-8 -*- import warnings from typing import Any, Optional, Tuple import torch import torch.distributed as dist from einops import rearrange from flash_attn.modules.mha import ( CrossAttention, FlashCrossAttention, FlashSelfAttention, SelfAttention, _update_kv_cache, ) from torch import Tensor, nn from torch.nn import Module from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import global_context as gpc from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding from internlm.model.linear import ( ColumnParallelLinearTorch, FSTPLinear, RowParallelLinearTorch, ) # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py class _SeqAllToAll(torch.autograd.Function): "sequence alltoall" @staticmethod def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx seq_world_size = dist.get_world_size(group) input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] # TODO Use all_to_all_single instead dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_idx).contiguous() @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py class DistributedAttention(torch.nn.Module): """Initialization. Arguments: local_attention (Module): local attention with q,k,v sequence_process_group (ProcessGroup): sequence parallel process group first_scatter_idx (int): scatter_idx for the first all2all comm first_gather_idx (int): gather_idx for the first all2all comm second_scatter_idx (int): scatter_idx for the second all2all comm second_gather_idx (int): gather_idx for the second all2all comm """ def __init__( self, local_attention: Module, sequence_process_group: dist.ProcessGroup, first_scatter_idx: int = 2, first_gather_idx: int = 0, second_scatter_idx: int = 0, second_gather_idx: int = 1, ) -> None: super().__init__() self.local_attn = local_attention self.spg = sequence_process_group self.first_scatter_idx = first_scatter_idx self.first_gather_idx = first_gather_idx self.second_scatter_idx = second_scatter_idx self.second_gather_idx = second_gather_idx def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor: """forward Arguments: query (Tensor): query input to the layer key (Tensor): key input to the layer value (Tensor): value input to the layer args: other args Returns: * output (Tensor): context output """ # Evaluation if qkv.ndim == 5: # in shape: [batch, seq/tp_size, 3, head, head_dim] qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1) # out shape : [batch, seq, head/tp_size, head_dim] context_layer = self.local_attn(qkv, **kwargs) # in shape: [batch, seq, head/tp_size, head_dim] output = _SeqAllToAll.apply( self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1 ) else: # training # in shape: [seq/tp_size, 3, head, head_dim] qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx) # out shape : [seq, head/tp_size, head_dim] context_layer = self.local_attn(qkv, **kwargs) # in shape: [seq, head/tp_size, head_dim] output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx) # out e.g., [s/p::h] return output class MHA(nn.Module): """ Multi-head self-attention and cross-attention. Args: embed_dim (int): The dimention of hidden state. num_heads (int): The number of attention heads. process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and output projection. True by default. dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. softmax_scale (float): The temperature to use for the softmax attention. causal (boolean): Whether to apply causal attention mask. False by default. layer_idx (int): The index of current layer. None by default. rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. False by default. sequence_parallel (boolean): If True, we're doing Tensor Parallel with sequence parallelism. An all_gather_raw of x will be done before doing the matmul. device (Optional[Union[str, torch.device]]): The device will be used. dtype (Optional[torch.dtype]): The type of data. use_flash_attn (bool): Whether to use flash-attn. True by default. """ def __init__( self, embed_dim: int, num_heads: int, process_group: Optional[torch.distributed.ProcessGroup], max_position_embeddings: int = 2048, dropout: float = 0.0, softmax_scale: float = None, causal: bool = False, layer_idx: int = None, use_dynamic_ntk_rope: bool = False, rotary_emb_dim: int = 0, rotary_emb_scale_base: int = 0, use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, tp_mode: str = "origin_tp", ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.causal = causal self.layer_idx = layer_idx self.max_position_embeddings = max_position_embeddings self.use_dynamic_ntk_rope = use_dynamic_ntk_rope self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads if self.rotary_emb_dim > 0: if self.use_dynamic_ntk_rope: self.rotary_emb = DynamicNTKScalingRotaryEmbedding( self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device, max_position_embeddings=max_position_embeddings, scaling_factor=1.0, # Currently do not support dynamic scaling. ) else: self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) # notice here should change bias=True Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear self.Wqkv = Wqkv_cls( embed_dim, 3 * embed_dim, process_group, bias=True, sequence_parallel=gpc.config.parallel.sequence_parallel, **factory_kwargs, ) # according to https://spaces.ac.cn/archives/9577 inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) if tp_mode == "fstp": self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) # output projection always have the bias (for now) out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear self.out_proj = out_proj_cls( embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.parallel.sequence_parallel, **factory_kwargs, ) # need to assign tp attribute so that internlm know it is tensor parallel module if gpc.get_world_size(ParallelMode.TENSOR) > 1: for name in ["out_proj", "Wqkv"]: for param in getattr(self, name).parameters(): setattr(param, IS_TENSOR_PARALLEL, True) def forward(self, x, seqlen=None, inference_params=None, **kwargs): if kwargs.get("indexes", None) is not None: return self._packed_forward(x=x, inference_params=inference_params, **kwargs) else: return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) def _forward(self, x, seqlen=None, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ qkv = self.Wqkv(x) if seqlen is None: qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) else: qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim) if inference_params is None: if self.rotary_emb_dim > 0: kwargs["inference_params"] = inference_params qkv = self.rotary_emb(qkv, **kwargs) if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: with torch.cuda.amp.autocast(dtype=torch.bfloat16): if qkv.dtype not in [torch.float16, torch.bfloat16]: qkv = qkv.to(torch.bfloat16) context = self.inner_attn(qkv).to(x.dtype) else: context = self.inner_attn(qkv) else: if self.use_dynamic_ntk_rope: q = qkv[:, :, 0] assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) if inference_params.sequence_len_offset != 0: # q shape: [bsz, 1, nheads, head_dim] # kv shape: [bsz, seqlen, 2, nheads, head_dim] bsz, seq_len, _, nheads, head_dim = kv.shape q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2) qkv = torch.cat([q, kv], dim=2) if self.rotary_emb_dim > 0: qkv = self.rotary_emb(qkv) q = qkv[:, [-1], 0] kv = qkv[:, :, 1:] else: if inference_params.sequence_len_offset > self.max_position_embeddings: warnings.warn( "Notice your prompt's length is longer than model's max_position_embeddings: " f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." ) if self.rotary_emb_dim > 0: kwargs["inference_params"] = inference_params qkv = self.rotary_emb(qkv, **kwargs) q = qkv[:, :, 0] kv = qkv[:, :, 1:] else: if self.rotary_emb_dim > 0: kwargs["inference_params"] = inference_params qkv = self.rotary_emb(qkv, **kwargs) q = qkv[:, :, 0] assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) # If we're processing the prompt, causal=None (use self.causal). # If we're decoding, then causal=False. causal = None if inference_params.sequence_len_offset == 0 else False context = self.inner_cross_attn(q, kv, causal=causal) if seqlen is None: context = rearrange(context, "b s h d -> b s (h d)") else: context = rearrange(context, "b s h d -> (b s) (h d)") out = self.out_proj(context) return out def _packed_forward(self, x, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ qkv = self.Wqkv(x) # total x hsz' qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d qkv = self.rotary_emb(qkv, **kwargs) kwargs.pop("indexes") if inference_params is None: if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: with torch.cuda.amp.autocast(dtype=torch.bfloat16): if qkv.dtype not in [torch.float16, torch.bfloat16]: qkv = qkv.to(torch.bfloat16) context = self.inner_attn(qkv, **kwargs).to(x.dtype) else: context = self.inner_attn(qkv, **kwargs) else: raise RuntimeError("Not support this right now") context = rearrange(context, "b h d -> b (h d)") # recover the shape out = self.out_proj(context) return out