InternLM/internlm/model/multi_head_attention.py


#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import warnings
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
from einops import rearrange
from flash_attn.modules.mha import (
    CrossAttention,
    FlashCrossAttention,
    FlashSelfAttention,
    SelfAttention,
    _update_kv_cache,
)
from torch import Tensor, nn
from torch.nn import Module

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, FSDPLinear, RowParallelLinearTorch

class _SeqAllToAll(torch.autograd.Function):
    """All-to-all over the sequence-parallel process group: scatters the input along
    ``scatter_idx`` and gathers the received chunks along ``gather_idx``."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        seq_world_size = dist.get_world_size(group)

        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
        output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
        # TODO: use all_to_all_single instead
        dist.all_to_all(output_list, input_list, group=group)
        return torch.cat(output_list, dim=gather_idx).contiguous()

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # the backward of an all-to-all is another all-to-all with scatter/gather axes swapped
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
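

# Shape sketch (an illustrative note, not from the original file): assume a sequence-parallel
# world size P and a packed qkv tensor of shape [seq/P, 3, heads, head_dim]. Calling
# _SeqAllToAll.apply(group, qkv, scatter_idx=2, gather_idx=0) splits the head dimension into
# P chunks, exchanges them across ranks, and concatenates the received chunks along the
# sequence dimension, yielding [seq, 3, heads/P, head_dim]. The second all-to-all in
# DistributedAttention below applies the inverse mapping to the attention output,
# [seq, heads/P, head_dim] -> [seq/P, heads, head_dim].
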
class DistributedAttention(torch.nn.Module):
    """Distributed (sequence-parallel) attention wrapper.

    Arguments:
        local_attention (Module): local attention module operating on packed q, k, v
        sequence_process_group (ProcessGroup): sequence-parallel process group
        first_scatter_idx (int): scatter_idx for the first all-to-all comm
        first_gather_idx (int): gather_idx for the first all-to-all comm
        second_scatter_idx (int): scatter_idx for the second all-to-all comm
        second_gather_idx (int): gather_idx for the second all-to-all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        first_scatter_idx: int = 2,
        first_gather_idx: int = 0,
        second_scatter_idx: int = 0,
        second_gather_idx: int = 1,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.first_scatter_idx = first_scatter_idx
        self.first_gather_idx = first_gather_idx
        self.second_scatter_idx = second_scatter_idx
        self.second_gather_idx = second_gather_idx

    def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
        """Forward pass.

        Arguments:
            qkv (Tensor): packed query/key/value input to the layer
            kwargs: other args passed through to the local attention module

        Returns:
            output (Tensor): context output
        """
        # TODO: merge the all-to-all calls into one
        if qkv.ndim == 5:
            # batched input; shift scatter/gather indices by one to skip the leading batch dim
            # in shape: [batch, seq/tp_size, 3, head, head_dim]
            qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
            # out shape: [batch, seq, 3, head/tp_size, head_dim]
            context_layer = self.local_attn(qkv, **kwargs)
            # in shape: [batch, seq, head/tp_size, head_dim]
            output = _SeqAllToAll.apply(
                self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
            )
        else:
            # in shape: [seq/tp_size, 3, head, head_dim]
            qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
            # out shape: [seq, 3, head/tp_size, head_dim]
            context_layer = self.local_attn(qkv, **kwargs)
            # in shape: [seq, head/tp_size, head_dim]
            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)

        # out shape, e.g.: [seq/tp_size, head, head_dim]
        return output
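

# Usage sketch (hedged, illustrative only): MHA below wraps its local attention with this module
# when tp_mode == "fstp", roughly as
#     attn = DistributedAttention(FlashSelfAttention(causal=True), sequence_process_group=spg)
#     out = attn(qkv)  # qkv: [seq/P, 3, heads, head_dim] or [batch, seq/P, 3, heads, head_dim]
# where `spg` stands in for the sequence-parallel process group. The wrapper is transparent to the
# caller: it returns the context with the sequence dimension re-partitioned across ranks and the
# full set of heads restored locally.
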
class MHA(nn.Module):
    """
    Multi-head self-attention and cross-attention.

    Args:
        embed_dim (int): The dimension of the hidden state.
        num_heads (int): The number of attention heads.
        process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
        max_position_embeddings (int): The maximum sequence length the rotary embedding supports. 2048 by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply a causal attention mask. False by default.
        layer_idx (int): The index of the current layer. None by default.
        use_dynamic_ntk_rope (boolean): Whether to use Dynamic NTK scaling for the rotary embedding. False by default.
        rotary_emb_dim (int): The dimension of the rotary embedding. 0 by default.
        rotary_emb_scale_base (int): The scaling factor of the rotary embedding. If scale_base > 0, this implements
                                     XPos (Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        use_flash_attn (boolean): Whether to use flash attention. If False, the vanilla attention module is used.
                                  True by default.
        device (Optional[Union[str, torch.device]]): The device to be used.
        dtype (Optional[torch.dtype]): The data type.
        tp_mode (str): The tensor-parallel mode, either "origin_tp" or "fstp". "origin_tp" by default.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        max_position_embeddings: int = 2048,
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        use_dynamic_ntk_rope: bool = False,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        tp_mode: str = "origin_tp",
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.max_position_embeddings = max_position_embeddings
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            if self.use_dynamic_ntk_rope:
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.rotary_emb_dim,
                    scale_base=rotary_emb_scale_base,
                    device=device,
                    max_position_embeddings=max_position_embeddings,
                    scaling_factor=1.0,  # currently do not support dynamic scaling
                )
            else:
                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # note that bias must be True here
        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSDPLinear
        self.Wqkv = Wqkv_cls(
            embed_dim,
            3 * embed_dim,
            process_group,
            bias=True,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )  # according to https://spaces.ac.cn/archives/9577
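        # Worked example (illustrative numbers, not from the original file): with embed_dim=4096 and
        # num_heads=32 (head_dim=128), Wqkv maps (..., 4096) -> (..., 3 * 4096), and the fused output
        # is later rearranged into (..., 3, 32, 128), i.e. q, k and v stacked along the "three" axis
        # (ignoring the tensor-parallel partitioning of the output dimension).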
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        if tp_mode == "fstp":
            # in fstp mode, wrap the local attention with sequence-parallel all-to-all communication
            self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
            self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

        # output projection always has the bias (for now)
        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSDPLinear
        self.out_proj = out_proj_cls(
            embed_dim,
            embed_dim,
            process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )

        # need to assign the tp attribute so that internlm knows it is a tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["out_proj", "Wqkv"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        if kwargs.get("indexes", None) is not None:
            return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
        else:
            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is None:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                kwargs["inference_params"] = inference_params
                qkv = self.rotary_emb(qkv, **kwargs)
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                # flash attention kernels only support fp16/bf16, so cast under autocast and cast back
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv).to(x.dtype)
            else:
                context = self.inner_attn(qkv)
        else:
            if self.use_dynamic_ntk_rope:
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                if inference_params.sequence_len_offset != 0:
                    # q shape: [bsz, 1, nheads, head_dim]
                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
                    bsz, seq_len, _, nheads, head_dim = kv.shape
                    # left-pad q with zeros to the cached length so that the rotary embedding below is
                    # applied at the correct position offsets, then keep only the last query position
                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
                    qkv = torch.cat([q, kv], dim=2)
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv)
                    q = qkv[:, [-1], 0]
                    kv = qkv[:, :, 1:]
                else:
                    if inference_params.sequence_len_offset > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )
                    if self.rotary_emb_dim > 0:
                        kwargs["inference_params"] = inference_params
                        qkv = self.rotary_emb(qkv, **kwargs)
                    q = qkv[:, :, 0]
                    kv = qkv[:, :, 1:]
            else:
                if self.rotary_emb_dim > 0:
                    kwargs["inference_params"] = inference_params
                    qkv = self.rotary_emb(qkv, **kwargs)
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)

            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if inference_params.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        if seqlen is None:
            context = rearrange(context, "b s h d -> b s (h d)")
        else:
            context = rearrange(context, "b s h d -> (b s) (h d)")

        out = self.out_proj(context)
        return out
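
    # Decode-time sketch (illustrative, not from the original file): during generation with a kv cache,
    # x is typically [bsz, 1, hidden_dim] for the newly generated token, q becomes
    # [bsz, 1, nheads, head_dim], and _update_kv_cache returns the accumulated
    # [bsz, seqlen_so_far, 2, nheads, head_dim]; inner_cross_attn is then called with causal=False,
    # since a single query position may attend to every cached key.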

    def _packed_forward(self, x, inference_params=None, **kwargs):
        """
        Arguments:
            x: (total_tokens, hidden_dim), where hidden_dim = num heads * head dim. Sequences in the
                batch are packed (concatenated) along the first dimension; `kwargs["indexes"]` gives
                the per-token positions used by the rotary embedding.
        """
        qkv = self.Wqkv(x)  # total x hsz'
        qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
        qkv = self.rotary_emb(qkv, **kwargs)
        kwargs.pop("indexes")

        if inference_params is None:
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
            else:
                context = self.inner_attn(qkv, **kwargs)
        else:
            raise RuntimeError("Inference is not supported by the packed forward path yet")

        context = rearrange(context, "b h d -> b (h d)")  # recover the shape (total_tokens, hidden_dim)
        out = self.out_proj(context)
        return out
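
    # Packed-path note (an assumption about the call site, not stated in this file): the remaining
    # kwargs forwarded to inner_attn in _packed_forward are expected to describe the sequence
    # boundaries (e.g. flash-attn style cu_seqlens/max_seqlen) so that tokens from different packed
    # sequences do not attend to each other.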