From a4cec1715b165235e2ba7cef4efd72b1ee6ef041 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 5 Feb 2024 16:48:34 +0800
Subject: [PATCH] [llama] add flash attn patch for npu (#5362)

---
 .../utils/flash_attention_patch.py            | 486 +++++++++++-------
 1 file changed, 313 insertions(+), 173 deletions(-)

diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
index 1926ec78a..6c048c3b1 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -1,15 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import math
 from types import MethodType
 from typing import Optional, Tuple
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
-from flash_attn.ops.rms_norm import rms_norm
+from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaForCausalLM,
@@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
     repeat_kv,
 )
 
+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger()
 
+if get_accelerator().name == "cuda":
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+    from flash_attn.ops.rms_norm import rms_norm
 
-def _prepare_decoder_attention_mask(
-    self: LlamaModel,
-    attention_mask: torch.BoolTensor,
-    input_shape: torch.Size,
-    inputs_embeds: torch.Tensor,
-    past_key_values_length: int,
-) -> Optional[torch.Tensor]:
-    """
-    Decoder attetion mask
-    """
-    if past_key_values_length > 0 and attention_mask is not None:
-        attention_mask = torch.cat(
-            tensors=(
-                torch.full(
-                    size=(input_shape[0], past_key_values_length),
-                    fill_value=True,
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                ),
-                attention_mask,
-            ),
-            dim=-1,
-        )  # (bsz, past_key_values_length + q_len)
-    if attention_mask is not None and torch.all(attention_mask):
-        return None  # Faster
-    return attention_mask
-
-
+    def _prepare_decoder_attention_mask(
+        self: LlamaModel,
+        attention_mask: torch.BoolTensor,
+        input_shape: torch.Size,
+        inputs_embeds: torch.Tensor,
+        past_key_values_length: int,
+    ) -> Optional[torch.Tensor]:
+        """
+        Decoder attention mask
+        """
+        if past_key_values_length > 0 and attention_mask is not None:
+            attention_mask = torch.cat(
+                tensors=(
+                    torch.full(
+                        size=(input_shape[0], past_key_values_length),
+                        fill_value=True,
+                        dtype=attention_mask.dtype,
+                        device=attention_mask.device,
+                    ),
+                    attention_mask,
+                ),
+                dim=-1,
+            )  # (bsz, past_key_values_length + q_len)
+        if attention_mask is not None and torch.all(attention_mask):
+            return None  # Faster
+        return attention_mask
 
-def attention_forward(
-    self: LlamaAttention,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    """
-    Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
-    """
-    if output_attentions:
-        logger.warning(
-            "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
-            "return `None` instead."
-        )
-
-    bsz, q_len, _ = hidden_states.size()
-
-    if self.config.pretraining_tp > 1:
-        q_slicing, kv_slicing = (
-            dim // self.config.pretraining_tp
-            for dim in (
-                self.num_heads * self.head_dim,
-                self.num_key_value_heads * self.head_dim,
-            )
-        )  # `Tuple[int, int]`
-        q_slices, k_slices, v_slices = (
-            proj.weight.split(slicing, dim=0)
-            for proj, slicing in (
-                (self.q_proj, q_slicing),
-                (self.k_proj, kv_slicing),
-                (self.v_proj, kv_slicing),
-            )
-        )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
-        q, k, v = (
-            torch.cat(
-                [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], dim=-1,
-            )
-            for slices in (q_slices, k_slices, v_slices)
-        )
-        # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
-        # (bsz, q_len, num_heads * head_dim),
-        # (bsz, q_len, num_key_value_heads * head_dim),
-        # (bsz, q_len, num_key_value_heads * head_dim)
-    else:
-        q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
-        # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
-        # (bsz, q_len, num_heads * head_dim),
-        # (bsz, q_len, num_key_value_heads * head_dim),
-        # (bsz, q_len, num_key_value_heads * head_dim)
-
-    # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
-    # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
-    # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
-    q, k, v = (
-        states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
-        for states, num_heads in (
-            (q, self.num_heads),
-            (k, self.num_key_value_heads),
-            (v, self.num_key_value_heads),
-        )
-    )
-    kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
-    past_kv_len = 0
-    if past_key_value is not None:
-        # if `past_key_value` is not None, `kv_len` > `q_len`.
-        past_kv_len = past_key_value[0].shape[-2]
-        kv_len += past_kv_len
-
-    # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
-    cos, sin = self.rotary_emb(v, seq_len=kv_len)
-    # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
-    q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        k = torch.cat([past_key_value[0], k], dim=2)
-        v = torch.cat([past_key_value[1], v], dim=2)
-
-    past_key_value = (k, v) if use_cache else None
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
-    # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
-    v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
-    # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
-
-    key_padding_mask = attention_mask
-    # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
-    q, k, v = (states.transpose(1, 2) for states in (q, k, v))
-
-    if past_kv_len > 0:
-        q = torch.cat(
-            tensors=(
-                torch.full(
-                    size=(bsz, past_kv_len, self.num_heads, self.head_dim),
-                    fill_value=0.0,
-                    dtype=q.dtype,
-                    device=q.device,
-                ),
-                q,
-            ),
-            dim=1,
-        )  # (bsz, past_kv_len + q_len, num_heads, head_dim)
-
-    if key_padding_mask is None:
-        # (bsz, past_kv_len + q_len, num_heads, head_dim)
-        output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
-        output = rearrange(output, pattern="... h d -> ... (h d)")  # (bsz, past_kv_len + q_len, num_heads * head_dim)
-    else:
-        q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
-        kv, _, cu_kv_lens, max_kv_len = unpad_input(
-            hidden_states=torch.stack(tensors=(k, v), dim=2),
-            attention_mask=key_padding_mask,
-        )
-        output_unpad = flash_attn_varlen_kvpacked_func(
-            q=q,
-            kv=kv,
-            cu_seqlens_q=cu_q_lens,
-            cu_seqlens_k=cu_kv_lens,
-            max_seqlen_q=max_q_len,
-            max_seqlen_k=max_kv_len,
-            dropout_p=0.0,
-            softmax_scale=None,
-            causal=True,
-        )
-        output = pad_input(
-            hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
-            indices=indices,
-            batch=bsz,
-            seqlen=past_kv_len + q_len,
-        )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
-
-    if past_kv_len > 0:
-        # Strip off the zero query outputs.
-        output = output[:, past_kv_len:, ...]  # (bsz, q_len, num_heads * head_dim)
-    output = self.o_proj(output)  # (bsz, q_len, hidden_size)
-    return output, None, past_key_value
-
-
+    def attention_forward(
+        self: LlamaAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+        """
+        if output_attentions:
+            logger.warning(
+                "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
+                "returning `None` instead."
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            q_slicing, kv_slicing = (
+                dim // self.config.pretraining_tp
+                for dim in (
+                    self.num_heads * self.head_dim,
+                    self.num_key_value_heads * self.head_dim,
+                )
+            )  # `Tuple[int, int]`
+            q_slices, k_slices, v_slices = (
+                proj.weight.split(slicing, dim=0)
+                for proj, slicing in (
+                    (self.q_proj, q_slicing),
+                    (self.k_proj, kv_slicing),
+                    (self.v_proj, kv_slicing),
+                )
+            )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+            q, k, v = (
+                torch.cat(
+                    [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
+                    dim=-1,
+                )
+                for slices in (q_slices, k_slices, v_slices)
+            )
+            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+            # (bsz, q_len, num_heads * head_dim),
+            # (bsz, q_len, num_key_value_heads * head_dim),
+            # (bsz, q_len, num_key_value_heads * head_dim)
+        else:
+            q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
+            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+            # (bsz, q_len, num_heads * head_dim),
+            # (bsz, q_len, num_key_value_heads * head_dim),
+            # (bsz, q_len, num_key_value_heads * head_dim)
+
+        # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
+        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
+        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
+        q, k, v = (
+            states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
+            for states, num_heads in (
+                (q, self.num_heads),
+                (k, self.num_key_value_heads),
+                (v, self.num_key_value_heads),
+            )
+        )
+        kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
+        past_kv_len = 0
+        if past_key_value is not None:
+            # if `past_key_value` is not None, `kv_len` > `q_len`.
+            past_kv_len = past_key_value[0].shape[-2]
+            kv_len += past_kv_len
+
+        # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
+        cos, sin = self.rotary_emb(v, seq_len=kv_len)
+        # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
+        q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            k = torch.cat([past_key_value[0], k], dim=2)
+            v = torch.cat([past_key_value[1], v], dim=2)
+
+        past_key_value = (k, v) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
+        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+        v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
+        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+
+        key_padding_mask = attention_mask
+        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
+        q, k, v = (states.transpose(1, 2) for states in (q, k, v))
+
+        if past_kv_len > 0:
+            q = torch.cat(
+                tensors=(
+                    torch.full(
+                        size=(bsz, past_kv_len, self.num_heads, self.head_dim),
+                        fill_value=0.0,
+                        dtype=q.dtype,
+                        device=q.device,
+                    ),
+                    q,
+                ),
+                dim=1,
+            )  # (bsz, past_kv_len + q_len, num_heads, head_dim)
+
+        if key_padding_mask is None:
+            # (bsz, past_kv_len + q_len, num_heads, head_dim)
+            output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
+            output = rearrange(
+                output, pattern="... h d -> ... (h d)"
+            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
+        else:
+            q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
+            kv, _, cu_kv_lens, max_kv_len = unpad_input(
+                hidden_states=torch.stack(tensors=(k, v), dim=2),
+                attention_mask=key_padding_mask,
+            )
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q=q,
+                kv=kv,
+                cu_seqlens_q=cu_q_lens,
+                cu_seqlens_k=cu_kv_lens,
+                max_seqlen_q=max_q_len,
+                max_seqlen_k=max_kv_len,
+                dropout_p=0.0,
+                softmax_scale=None,
+                causal=True,
+            )
+            output = pad_input(
+                hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
+                indices=indices,
+                batch=bsz,
+                seqlen=past_kv_len + q_len,
+            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
+
+        if past_kv_len > 0:
+            # Strip off the zero query outputs.
+            output = output[:, past_kv_len:, ...]  # (bsz, q_len, num_heads * head_dim)
+        output = self.o_proj(output)  # (bsz, q_len, hidden_size)
+        return output, None, past_key_value
 
-def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
-    """
-    Formard function for RMS Norm
-    """
-    return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
-
-
+    def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Forward function for RMS Norm
+        """
+        return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
 
-def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
-    for name, module in model.named_modules():
-        if isinstance(module, LlamaAttention):
-            module.forward = MethodType(attention_forward, module)
-        if isinstance(module, LlamaModel):
-            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
-        if isinstance(module, LlamaRMSNorm):
-            module.forward = MethodType(rms_norm_forward, module)
+    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+        for name, module in model.named_modules():
+            if isinstance(module, LlamaAttention):
+                module.forward = MethodType(attention_forward, module)
+            if isinstance(module, LlamaModel):
+                module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+            if isinstance(module, LlamaRMSNorm):
+                module.forward = MethodType(rms_norm_forward, module)
+
+elif get_accelerator().name == "npu":
+    import torch_npu
+
+    class NPULlamaAttention(LlamaAttention):
+        use_flash: bool = True
+
+        def __init__(self, config: LlamaConfig):
+            super().__init__(config)
+            self.setup()
+
+        def setup(self):
+            self._softmax_scale = 1 / math.sqrt(self.head_dim)
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+            bsz, q_len, _ = hidden_states.size()
+
+            if self.config.pretraining_tp > 1:
+                key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+                query_slices = self.q_proj.weight.split(
+                    (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+                )
+                key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+                value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+                query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+                query_states = torch.cat(query_states, dim=-1)
+
+                key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+                key_states = torch.cat(key_states, dim=-1)
+
+                value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+                value_states = torch.cat(value_states, dim=-1)
+
+            else:
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key_states.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+            if past_key_value is not None:
+                # reuse k, v, self_attention
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+            past_key_value = (key_states, value_states) if use_cache else None
+
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+            if not self.use_flash:
+                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+                if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                        f" {attn_weights.size()}"
+                    )
+
+                if attention_mask is not None:
+                    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                        raise ValueError(
+                            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                        )
+                    attn_weights = attn_weights + attention_mask
+
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+                attn_output = torch.matmul(attn_weights, value_states)
+            else:
+                attn_output, *_ = torch_npu.npu_fusion_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    self.num_heads,
+                    "BNSD",
+                    atten_mask=attention_mask.bool(),
+                    scale=self._softmax_scale,
+                    padding_mask=None,
+                    pre_tockens=65535,
+                    next_tockens=0,
+                    keep_prob=1.0,
+                    inner_precise=0,
+                )
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+            if self.config.pretraining_tp > 1:
+                attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+                o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+                attn_output = sum(
+                    [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
+                )
+            else:
+                attn_output = self.o_proj(attn_output)
+
+            if not output_attentions:
+                attn_weights = None
+
+            return attn_output, attn_weights, past_key_value
+
+    class NPURMSNorm(LlamaRMSNorm):
+        def forward(self, hidden_states):
+            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
+
+    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+        for name, module in model.named_modules():
+            if isinstance(module, LlamaAttention):
+                module.__class__ = NPULlamaAttention
+                module.setup()
+            if isinstance(module, LlamaRMSNorm):
+                module.__class__ = NPURMSNorm
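Usage sketch: `replace_with_flash_attention` patches an already-built model in place and
returns nothing, so it must run after all `LlamaAttention` / `LlamaRMSNorm` submodules
exist. A minimal sketch, assuming the module path from the diff header is importable and
using an illustrative checkpoint path:

    import torch
    from transformers import LlamaForCausalLM

    from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention

    # Load the model first: the helper walks model.named_modules() and, depending on the
    # accelerator, rebinds forward methods via MethodType (CUDA) or reassigns the module
    # class to NPULlamaAttention / NPURMSNorm (NPU).
    model = LlamaForCausalLM.from_pretrained("path/to/llama-2-checkpoint", torch_dtype=torch.bfloat16)
    replace_with_flash_attention(model)

Note the design difference between the two branches: CUDA patches bound methods on each
instance, while NPU swaps `module.__class__` and then calls `setup()` so that
`_softmax_scale` also exists on instances constructed before the class swap.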