mirror of https://github.com/InternLM/InternLM
330 lines
15 KiB
Python
330 lines
15 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import warnings
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from einops import rearrange
|
|
from flash_attn.modules.mha import (
|
|
CrossAttention,
|
|
FlashCrossAttention,
|
|
FlashSelfAttention,
|
|
SelfAttention,
|
|
_update_kv_cache,
|
|
)
|
|
from torch import nn
|
|
|
|
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
|
from internlm.core.context import global_context as gpc
|
|
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
|
|
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
|
|
|
|
import torch
|
|
|
|
from typing import Any, Tuple
|
|
from torch import Tensor
|
|
from torch.nn import Module
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
class _SeqAllToAll(torch.autograd.Function):
    """Differentiable all-to-all exchange over a (sequence-parallel) process group.

    ``forward`` splits ``input`` into ``world_size`` equal chunks along ``scatter_idx``,
    exchanges them with ``dist.all_to_all`` (chunk ``i`` goes to rank ``i``), and
    concatenates the chunks received from every rank along ``gather_idx``.
    ``backward`` applies the inverse exchange (scatter and gather dimensions swapped)
    to the incoming gradient.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        """Scatter ``input`` along ``scatter_idx`` and gather the received chunks along ``gather_idx``.

        Args:
            ctx: autograd context; keeps ``group`` and both indices for ``backward``.
            group: process group the exchange runs over.
            input: local tensor. Its size along ``scatter_idx`` must be divisible by the
                group's world size. Receive buffers are shaped like the local chunks, so
                all ranks are assumed to pass identically shaped inputs.
            scatter_idx: dimension that is split and distributed across ranks.
            gather_idx: dimension along which the received chunks are concatenated.

        Returns:
            A contiguous tensor built from the ``world_size`` received chunks.

        Raises:
            ValueError: if ``input.size(scatter_idx)`` is not divisible by the world size.
        """
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        seq_world_size = dist.get_world_size(group)

        # torch.tensor_split would silently return uneven chunks here, while every receive
        # buffer below is shaped like chunk 0, so an uneven split cannot be exchanged correctly.
        scatter_len = input.size(scatter_idx)
        if scatter_len % seq_world_size != 0:
            raise ValueError(
                f"Dimension {scatter_idx} of size {scatter_len} is not divisible by "
                f"the sequence parallel world size {seq_world_size}."
            )

        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
        output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
        # TODO Use all_to_all_single instead
        dist.all_to_all(output_list, input_list, group=group)
        return torch.cat(output_list, dim=gather_idx).contiguous()

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        """Route the gradient back with the inverse all-to-all (scatter/gather swapped).

        Only ``input`` receives a gradient; ``group``, ``scatter_idx`` and ``gather_idx``
        are non-differentiable and get ``None``.
        """
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
|
|
|
|
|
|
class DistributedAttention(torch.nn.Module):
    """All-to-all based sequence-parallel wrapper around a local attention module.

    ``qkv`` is redistributed across ``sequence_process_group`` with one all-to-all.
    ``local_attention`` runs on the redistributed tensor, and a second all-to-all maps
    the output back. With the default indices and a packed ``[seq/P, 3, head, head_dim]``
    input (P = world size of the group), the local attention sees ``[seq, 3, head/P, head_dim]``.

    Arguments:
        local_attention (Module): local attention, called as ``local_attention(qkv, **kwargs)``.
        sequence_process_group (ProcessGroup): sequence parallel process group.
        first_scatter_idx (int): scatter_idx for the first all2all comm, counted for an
            input without a leading batch dimension.
        first_gather_idx (int): gather_idx for the first all2all comm (same convention).
        second_scatter_idx (int): scatter_idx for the second all2all comm (same convention).
        second_gather_idx (int): gather_idx for the second all2all comm (same convention).
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        first_scatter_idx: int = 2,
        first_gather_idx: int = 0,
        second_scatter_idx: int = 0,
        second_gather_idx: int = 1,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.first_scatter_idx = first_scatter_idx
        self.first_gather_idx = first_gather_idx
        self.second_scatter_idx = second_scatter_idx
        self.second_gather_idx = second_gather_idx

    def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
        """Run the local attention on all-to-all redistributed ``qkv``.

        Arguments:
            qkv (Tensor): packed query/key/value tensor. A 5-D input is treated as having
                a leading batch dimension, so every configured scatter/gather index is
                shifted by one. Inputs of any other rank use the indices as configured.
            kwargs: forwarded unchanged to ``local_attention``.

        Returns:
            * output (Tensor): context output after the second all-to-all.
        """
        # TODO Merge three alltoall calls into one
        # A leading batch dim (e.g. [batch, seq, 3, head, head_dim]) shifts all indices by one.
        offset = 1 if qkv.ndim == 5 else 0

        # e.g. packed in: [seq/P, 3, head, head_dim] -> out: [seq, 3, head/P, head_dim]
        qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + offset, self.first_gather_idx + offset)

        # e.g. packed out: [seq, head/P, head_dim]
        context_layer = self.local_attn(qkv, **kwargs)

        # e.g. packed in: [seq, head/P, head_dim] -> out: [seq/P, head, head_dim]
        output = _SeqAllToAll.apply(
            self.spg, context_layer, self.second_scatter_idx + offset, self.second_gather_idx + offset
        )
        return output
|
|
|
|
|
|
class MHA(nn.Module):
    """
    Multi-head self-attention and cross-attention.

    Args:
        embed_dim (int): The dimension of hidden state.
        num_heads (int): The number of attention heads.
        process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
        max_position_embeddings (int): Passed to the dynamic NTK rotary embedding and used to warn about
                                       over-long prompts during generation. 2048 by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply causal attention mask. False by default.
        layer_idx (int): The index of current layer. None by default; required for generation (KV cache).
        use_dynamic_ntk_rope (bool): Use DynamicNTKScalingRotaryEmbedding instead of RotaryEmbedding.
                                     False by default.
        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default (no rotary embedding).
        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
                                    XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        use_flash_attn (bool): Whether to use flash-attn. If False, the vanilla attention modules are used.
                               True by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        tp_mode (str): 'origin_tp' uses ColumnParallelLinearTorch / RowParallelLinearTorch; any other value
                       uses FSDPLinear, and 'fstp' additionally wraps the inner attention modules in
                       DistributedAttention. 'origin_tp' by default.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        max_position_embeddings: int = 2048,
        dropout: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        layer_idx: Optional[int] = None,
        use_dynamic_ntk_rope: bool = False,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        tp_mode: str = "origin_tp",
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.max_position_embeddings = max_position_embeddings
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        # self.rotary_emb only exists when rotary_emb_dim > 0, so every use must be guarded.
        if self.rotary_emb_dim > 0:
            if self.use_dynamic_ntk_rope:
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.rotary_emb_dim,
                    scale_base=rotary_emb_scale_base,
                    device=device,
                    max_position_embeddings=max_position_embeddings,
                    scaling_factor=1.0,  # Currently do not support dynamic scaling.
                )
            else:
                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # notice here should change bias=True
        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSDPLinear
        self.Wqkv = Wqkv_cls(
            embed_dim,
            3 * embed_dim,
            process_group,
            bias=True,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )  # according to https://spaces.ac.cn/archives/9577

        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        if tp_mode == "fstp":
            # Wrap the local attention with all-to-all exchanges over process_group.
            self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
            self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

        # output projection always have the bias (for now)
        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSDPLinear
        self.out_proj = out_proj_cls(
            embed_dim,
            embed_dim,
            process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )

        # need to assign tp attribute so that internlm know it is tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["out_proj", "Wqkv"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """Dispatch to ``_packed_forward`` when a non-None ``indexes`` kwarg is given, else to ``_forward``."""
        if kwargs.get("indexes", None) is not None:
            return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
        return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

    def _inner_self_attention(self, qkv, out_dtype, **kwargs):
        """Run ``self.inner_attn`` on ``qkv``.

        When the model dtype is fp32 and flash-attn is enabled in the global config, attention
        is computed under bf16 autocast and the result is cast back to ``out_dtype``.
        """
        if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
            # Presumably because flash-attn kernels only take fp16/bf16 inputs.
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                if qkv.dtype not in [torch.float16, torch.bfloat16]:
                    qkv = qkv.to(torch.bfloat16)
                return self.inner_attn(qkv, **kwargs).to(out_dtype)
        return self.inner_attn(qkv, **kwargs)

    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
            inference_params: None for full-sequence (training) attention. Otherwise the generation
                state: its ``sequence_len_offset`` distinguishes the prompt pass (0) from decoding
                steps, and it is passed to ``_update_kv_cache``.
            kwargs: forwarded to the rotary embedding.

        Returns:
            The output projection of the attention context, flattened the same way as ``x``.
        """
        qkv = self.Wqkv(x)
        if seqlen is None:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                kwargs["inference_params"] = inference_params
                qkv = self.rotary_emb(qkv, **kwargs)
            context = self._inner_self_attention(qkv, x.dtype)
        else:
            if self.use_dynamic_ntk_rope:
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                # k/v enter the cache before rotary; each decoding step re-rotates the whole cached sequence.
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                if inference_params.sequence_len_offset != 0:
                    # q shape: [bsz, 1, nheads, head_dim]
                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
                    # Left-pad q with zeros so q and the cached kv can be rotated together as one qkv tensor.
                    bsz, seq_len, _, nheads, head_dim = kv.shape
                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
                    qkv = torch.cat([q, kv], dim=2)
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv)
                    q = qkv[:, [-1], 0]
                    kv = qkv[:, :, 1:]
                else:
                    # Prompt pass: sequence_len_offset is 0 here, so the prompt length is qkv's
                    # sequence dimension (the old offset-based check could never trigger).
                    if qkv.shape[1] > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )
                    if self.rotary_emb_dim > 0:
                        kwargs["inference_params"] = inference_params
                        qkv = self.rotary_emb(qkv, **kwargs)
                    q = qkv[:, :, 0]
                    kv = qkv[:, :, 1:]
            else:
                if self.rotary_emb_dim > 0:
                    kwargs["inference_params"] = inference_params
                    qkv = self.rotary_emb(qkv, **kwargs)
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)

            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            # NOTE(review): with tp_mode == "fstp", inner_cross_attn is a DistributedAttention whose
            # forward takes a single qkv tensor, so this (q, kv) call is not supported in that mode.
            causal = None if inference_params.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        if seqlen is None:
            context = rearrange(context, "b s h d -> b s (h d)")
        else:
            context = rearrange(context, "b s h d -> (b s) (h d)")

        out = self.out_proj(context)
        return out

    def _packed_forward(self, x, inference_params=None, **kwargs):
        """
        Arguments:
            x: packed input of shape (total_tokens, hidden_dim), where hidden_dim = num heads * head dim.
            inference_params: must be None; generation is not supported on the packed path.
            kwargs: must contain ``indexes`` (consumed by the rotary embedding and then removed);
                the remaining entries are forwarded to the inner attention.

        Returns:
            The output projection of the attention context, shape (total_tokens, hidden_dim).

        Raises:
            RuntimeError: if ``inference_params`` is not None.
        """
        qkv = self.Wqkv(x)  # total x hsz'
        qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
        # Guarded like _forward: self.rotary_emb does not exist when rotary_emb_dim == 0.
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(qkv, **kwargs)
        kwargs.pop("indexes")

        if inference_params is not None:
            raise RuntimeError("Not support this right now")

        context = self._inner_self_attention(qkv, x.dtype, **kwargs)
        context = rearrange(context, "t h d -> t (h d)")  # recover the shape
        out = self.out_proj(context)
        return out
|