#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
)

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

if get_accelerator().name == "cuda":
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
    from flash_attn.ops.rms_norm import rms_norm

    def _prepare_decoder_attention_mask(
        self: LlamaModel,
        attention_mask: torch.BoolTensor,
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> Optional[torch.Tensor]:
        """
        Decoder attention mask
        """
        if past_key_values_length > 0 and attention_mask is not None:
            attention_mask = torch.cat(
                tensors=(
                    torch.full(
                        size=(input_shape[0], past_key_values_length),
                        fill_value=True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ),
                dim=-1,
            )  # (bsz, past_key_values_length + q_len)
        if attention_mask is not None and torch.all(attention_mask):
            return None  # Faster
        return attention_mask

    def attention_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
        """
        if output_attentions:
            logger.warning(
                "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
                "return `None` instead."
            )
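        # NOTE: the steps below mirror the eager `LlamaAttention.forward`:
        # q/k/v projections (optionally sliced when `pretraining_tp > 1`), rotary
        # position embedding, KV-cache concatenation and `repeat_kv`, then a
        # FlashAttention kernel (dense `flash_attn_func`, or the varlen kvpacked
        # variant when a key padding mask is present).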
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            q_slicing, kv_slicing = (
                dim // self.config.pretraining_tp
                for dim in (
                    self.num_heads * self.head_dim,
                    self.num_key_value_heads * self.head_dim,
                )
            )  # `Tuple[int, int]`
            q_slices, k_slices, v_slices = (
                proj.weight.split(slicing, dim=0)
                for proj, slicing in (
                    (self.q_proj, q_slicing),
                    (self.k_proj, kv_slicing),
                    (self.v_proj, kv_slicing),
                )
            )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
            q, k, v = (
                torch.cat(
                    [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
                    dim=-1,
                )
                for slices in (q_slices, k_slices, v_slices)
            )
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)
        else:
            q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)

        # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
        q, k, v = (
            states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
            for states, num_heads in (
                (q, self.num_heads),
                (k, self.num_key_value_heads),
                (v, self.num_key_value_heads),
            )
        )

        kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
        past_kv_len = 0
        if past_key_value is not None:
            # if `past_key_value` is not None, `kv_len` > `q_len`.
            past_kv_len = past_key_value[0].shape[-2]
            kv_len += past_kv_len

        # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
        cos, sin = self.rotary_emb(v, seq_len=kv_len)
        # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
        q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        past_key_value = (k, v) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        k = repeat_kv(
            hidden_states=k, n_rep=self.num_key_value_groups
        )  # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
        v = repeat_kv(
            hidden_states=v, n_rep=self.num_key_value_groups
        )  # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)

        key_padding_mask = attention_mask
        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
        q, k, v = (states.transpose(1, 2) for states in (q, k, v))

        if past_kv_len > 0:
            q = torch.cat(
                tensors=(
                    torch.full(
                        size=(bsz, past_kv_len, self.num_heads, self.head_dim),
                        fill_value=0.0,
                        dtype=q.dtype,
                        device=q.device,
                    ),
                    q,
                ),
                dim=1,
            )  # (bsz, past_kv_len + q_len, num_heads, head_dim)

        if key_padding_mask is None:
            # (bsz, past_kv_len + q_len, num_heads, head_dim)
            output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
            output = rearrange(
                output, pattern="... h d -> ... (h d)"
            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
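        # Some keys are padding: unpad the queries and the stacked (k, v) pair, run the
        # varlen kvpacked FlashAttention kernel over the packed tokens, then scatter the
        # outputs back into a (bsz, past_kv_len + q_len, ...) tensor.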
(h d)" ) # (bsz, past_kv_len + q_len, num_heads * head_dim) else: q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) kv, _, cu_kv_lens, max_kv_len = unpad_input( hidden_states=torch.stack(tensors=(k, v), dim=2), attention_mask=key_padding_mask, ) output_unpad = flash_attn_varlen_kvpacked_func( q=q, kv=kv, cu_seqlens_q=cu_q_lens, cu_seqlens_k=cu_kv_lens, max_seqlen_q=max_q_len, max_seqlen_k=max_kv_len, dropout_p=0.0, softmax_scale=None, causal=True, ) output = pad_input( hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), indices=indices, batch=bsz, seqlen=past_kv_len + q_len, ) # (bsz, past_kv_len + q_len, num_heads * head_dim) if past_kv_len > 0: # Strip off the zero query outputs. output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) output = self.o_proj(output) # (bsz, q_len, hidden_size) return output, None, past_key_value def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: """ Formard function for RMS Norm """ return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) def replace_with_flash_attention(model: LlamaForCausalLM) -> None: for name, module in model.named_modules(): if isinstance(module, LlamaAttention): module.forward = MethodType(attention_forward, module) if isinstance(module, LlamaModel): module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) if isinstance(module, LlamaRMSNorm): module.forward = MethodType(rms_norm_forward, module) elif get_accelerator().name == "npu": import torch_npu class NPULlamaAttention(LlamaAttention): use_flash: bool = True def __init__(self, config: LlamaConfig): super().__init__(config) self.setup() def setup(self): self._softmax_scale = 1 / math.sqrt(self.head_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = 
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            if past_key_value is not None:
                # reuse k, v, self_attention
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

            past_key_value = (key_states, value_states) if use_cache else None

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            if not self.use_flash:
                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

                if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                        f" {attn_weights.size()}"
                    )

                if attention_mask is not None:
                    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                        raise ValueError(
                            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                            f"but is {attention_mask.size()}"
                        )
                    attn_weights = attn_weights + attention_mask

                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                attn_output = torch.matmul(attn_weights, value_states)
            else:
                attn_output, *_ = torch_npu.npu_fusion_attention(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads,
                    "BNSD",
                    atten_mask=attention_mask.bool(),
                    scale=self._softmax_scale,
                    padding_mask=None,
                    pre_tockens=65535,
                    next_tockens=0,
                    keep_prob=1.0,
                    inner_precise=0,
                )

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            if self.config.pretraining_tp > 1:
                attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
                o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
                attn_output = sum(
                    [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
                )
            else:
                attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None
            return attn_output, attn_weights, past_key_value

    class NPURMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.__class__ = NPULlamaAttention
                module.setup()
            if isinstance(module, LlamaRMSNorm):
                module.__class__ = NPURMSNorm
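

# Example usage (a minimal sketch, not part of the original module; the checkpoint
# path and dtype below are illustrative assumptions). `replace_with_flash_attention`
# is only defined when the active accelerator is "cuda" or "npu":
#
#     from transformers import LlamaForCausalLM
#
#     model = LlamaForCausalLM.from_pretrained(
#         "path/to/llama-2-checkpoint", torch_dtype=torch.bfloat16
#     ).to(get_accelerator().get_current_device())
#     replace_with_flash_attention(model)  # patch attention and RMSNorm modules in place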