# mirror of https://github.com/hpcaitech/ColossalAI
import math
import os
import types
from collections import defaultdict
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LLAMA_INPUTS_DOCSTRING,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRotaryEmbedding,
    repeat_kv,
    rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
    copy_kv_cache_to_dest,
    int8_rotary_embedding_fwd,
    smooth_llama_context_attn_fwd,
    smooth_token_attention_fwd,
)

from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


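# NOTE: SmoothQuant int8 inference path for LLaMA. Calibrated per-tensor scales are folded
# into W8A8 linear layers (see the pack() methods below), rotary embedding and attention run
# through ColossalAI's int8 Triton kernels, and BatchInferState tracks KV-cache placement
# across the prefill and decode stages.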
class LLamaSmoothquantAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )

        self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
        self.pv_bmm = BMM_S8T_S8N_S8T(1.0)

        self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)

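        # Per-tensor calibration scales. The defaults below are placeholders that
        # pack() overwrites with the scales measured during calibration.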
        self.register_buffer("q_output_scale", torch.tensor([1.0]))
        self.register_buffer("k_output_scale", torch.tensor([1.0]))
        self.register_buffer("v_output_scale", torch.tensor([1.0]))
        self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
        self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
        self.register_buffer("out_input_scale", torch.tensor([1.0]))
        self.register_buffer("attn_input_scale", torch.tensor([1.0]))

        self._init_rope()
        self.num_key_value_heads = num_heads

    def _init_rope(self):
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=2048,
            base=10000.0,
        )

    @staticmethod
    def pack(
        module: LlamaAttention,
        attn_input_scale: float,
        q_output_scale: float,
        k_output_scale: float,
        v_output_scale: float,
        q_rotary_output_scale: float,
        k_rotary_output_scale: float,
        out_input_scale: float,
    ):
        int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)

        int8_module.attn_input_scale = torch.tensor([attn_input_scale])

        int8_module.q_output_scale = torch.tensor([q_output_scale])
        int8_module.k_output_scale = torch.tensor([k_output_scale])
        int8_module.v_output_scale = torch.tensor([v_output_scale])

        int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
        int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])

        int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
        int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
        int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
        int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)

        int8_module.out_input_scale = torch.tensor([out_input_scale])

        return int8_module

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_emb: Tuple[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        cos = rotary_emb[0]
        sin = rotary_emb[1]

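        # The Triton kernel applies the rotary embedding in place and requantizes the int8
        # activations from the projection output scale to the rotary output scale.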
        int8_rotary_embedding_fwd(
            query_states.view(-1, self.num_heads, self.head_dim),
            cos,
            sin,
            self.q_output_scale.item(),
            self.q_rotary_output_scale.item(),
        )
        int8_rotary_embedding_fwd(
            key_states.view(-1, self.num_heads, self.head_dim),
            cos,
            sin,
            self.k_output_scale.item(),
            self.k_rotary_output_scale.item(),
        )

        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
            return

        query_states = query_states.view(-1, self.num_heads, self.head_dim)
        key_states = key_states.view(-1, self.num_heads, self.head_dim)
        value_states = value_states.view(-1, self.num_heads, self.head_dim)

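        # Prefill (context stage) writes the whole prompt's K/V into the cache and runs the
        # context-attention kernel; decode writes a single new token per sequence and runs the
        # token-attention kernel over the cached K/V.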
        if infer_state.is_context_stage:
            # first token generation (prefill)

            # copy key and value calculated in current step to memory manager
            _copy_kv_to_mem_cache(
                infer_state.decode_layer_id,
                key_states,
                value_states,
                infer_state.context_mem_index,
                infer_state.cache_manager,
            )

            attn_output = torch.empty_like(query_states)

            smooth_llama_context_attn_fwd(
                query_states,
                key_states,
                value_states,
                attn_output,
                self.q_rotary_output_scale.item(),
                self.k_rotary_output_scale.item(),
                self.v_output_scale.item(),
                self.out_input_scale.item(),
                infer_state.start_loc,
                infer_state.seq_len,
                q_len,
            )

        else:
            if infer_state.decode_is_contiguous:
                # if decode is contiguous, copy the new key and value states into the cache manager's buffers directly
                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_k.copy_(key_states)
                cache_v.copy_(value_states)
            else:
                # if decode is not contiguous, use the triton kernel to copy the key and value cache
                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
                _copy_kv_to_mem_cache(
                    infer_state.decode_layer_id,
                    key_states,
                    value_states,
                    infer_state.decode_mem_index,
                    infer_state.cache_manager,
                )

            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

            smooth_token_attention_fwd(
                query_states,
                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                attn_output,
                self.q_rotary_output_scale.item(),
                self.k_rotary_output_scale.item(),
                self.v_output_scale.item(),
                self.out_input_scale.item(),
                infer_state.block_loc,
                infer_state.start_loc,
                infer_state.seq_len,
                infer_state.max_len_in_batch,
            )

        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None


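# LayerNorm whose output scale is folded into the weight (see from_float below), so forward()
# can round its output straight to int8.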
class LlamaLayerNormQ(torch.nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.input_scale = 1.0
        self.variance_epsilon = eps
        self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))

    def forward(self, x):
        ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
        ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
        return ln_output_int8

    @staticmethod
    def from_float(module: torch.nn.LayerNorm, output_scale: float):
        assert module.weight.shape[0] == module.weight.numel()
        q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
        q_module.weight = module.weight / output_scale
        return q_module


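# Int8 SwiGLU MLP: gate_proj applies SiLU inside the fused linear, its output is multiplied
# elementwise by up_proj's output, and the product is requantized to int8 with
# down_proj_input_scale before the int8 down projection.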
class LlamaSmoothquantMLP(nn.Module):
    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
        self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
        self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
        self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))

    @staticmethod
    def pack(
        mlp_module: LlamaMLP,
        gate_proj_input_scale: float,
        up_proj_input_scale: float,
        down_proj_input_scale: float,
    ):
        int8_module = LlamaSmoothquantMLP(
            mlp_module.intermediate_size,
            mlp_module.hidden_size,
        )

        int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
        int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
        int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
        int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
        return int8_module

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        x_shape = hidden_states.shape
        gate_out = self.gate_proj(hidden_states)
        up_out = self.up_proj(hidden_states)
        inter_out = gate_out * up_out
        inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
        down_out = self.down_proj(inter_out)
        down_out = down_out.view(*x_shape[:-1], -1)
        return down_out


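# Illustrative sketch (the scale values here are hypothetical): a float LlamaDecoderLayer is
# converted with the per-layer scales collected during calibration, e.g.
#
#   int8_layer = LlamaSmoothquantDecoderLayer.pack(
#       float_layer,
#       attn_input_scale=0.02, q_output_scale=0.05, k_output_scale=0.05, v_output_scale=0.04,
#       q_rotary_output_scale=0.05, k_rotary_output_scale=0.05, out_input_scale=0.03,
#       gate_input_scale=0.02, up_input_scale=0.02, down_input_scale=0.06,
#   )
#
# See SmoothLlamaForCausalLM.quantized() below for how these scales are measured.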
class LlamaSmoothquantDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)

        self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
        self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)

        self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def pack(
        module: LlamaDecoderLayer,
        attn_input_scale: float,
        q_output_scale: float,
        k_output_scale: float,
        v_output_scale: float,
        q_rotary_output_scale: float,
        k_rotary_output_scale: float,
        out_input_scale: float,
        gate_input_scale: float,
        up_input_scale: float,
        down_input_scale: float,
    ):
        config = module.self_attn.config
        int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)

        int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
        int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
            module.self_attn,
            attn_input_scale,
            q_output_scale,
            k_output_scale,
            v_output_scale,
            q_rotary_output_scale,
            k_rotary_output_scale,
            out_input_scale,
        )

        int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
            module.post_attention_layernorm, gate_input_scale
        )

        int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
            module.mlp,
            gate_input_scale,
            up_input_scale,
            down_input_scale,
        )

        return int8_decoder_layer

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_emb: Tuple[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        padding_mask: Optional[torch.LongTensor] = None,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            rotary_emb=rotary_emb,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
            infer_state=infer_state,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None, None


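# LlamaApplyRotary wraps the rotary-position application in an nn.Module so that a forward
# hook can record its output range during calibration (see get_act_dict below).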
class LlamaApplyRotary(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, cos, sin, position_ids):
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        x_embed = (x * cos) + (rotate_half(x) * sin)

        return x_embed


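# NOTE: despite its name, the function below replaces LlamaAttention.forward during
# calibration (see get_act_dict), with the LlamaApplyRotary modules above patched in so the
# rotary outputs can be hooked.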
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
    key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def init_to_get_rotary(config, base=10000, use_elem=False):
    """
    This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer.

    Args:
        base: the base used to compute the rotary frequencies
        use_elem: enable when using ChatGLM-based models (only half of head_dim is used for the frequencies)
    """
    config.head_dim_ = config.hidden_size // config.num_attention_heads
    if not hasattr(config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0

    if hasattr(config, "max_sequence_length"):
        max_seq_len = config.max_sequence_length
    elif hasattr(config, "max_position_embeddings"):
        max_seq_len = config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

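    # NTK-aware scaling stretches the RoPE base instead of interpolating positions; for
    # example, with head_dim_=128 and ntk_alpha=2, base becomes 10000 * 2 ** (128 / 126) ≈ 2.02e4
    # and the cached sequence length doubles.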
    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    try:
        ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
        assert ntk_alpha >= 1
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2)))  # Base change formula
    except:
        pass

    n_elem = config.head_dim_
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    _cos_cached = torch.cos(freqs).to(torch.float)
    _sin_cached = torch.sin(freqs).to(torch.float)
    return _cos_cached, _sin_cached


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

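    # past_key_values_length is derived from the shared BatchInferState rather than from the
    # HF past_key_values cache: 0 during prefill, max_len_in_batch - 1 during decode.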
    infer_state = self.infer_state
    if infer_state.is_context_stage:
        past_key_values_length = 0
    else:
        past_key_values_length = infer_state.max_len_in_batch - 1

    seq_length_with_past = seq_length + past_key_values_length

    # NOTE: differentiate from the prefill stage;
    # block_loc requires a different value-assigning method for the two stages
    if infer_state.is_context_stage:
        infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
        infer_state.init_block_loc(
            infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
        )
    else:
        alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
        if alloc_mem is not None:
            infer_state.decode_is_contiguous = True
            infer_state.decode_mem_index = alloc_mem[0]
            infer_state.decode_mem_start = alloc_mem[1]
            infer_state.decode_mem_end = alloc_mem[2]
            infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
        else:
            print(" *** Encountered non-contiguous allocation")
            print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}")
            infer_state.decode_is_contiguous = False
            alloc_mem = infer_state.cache_manager.alloc(batch_size)
            infer_state.decode_mem_index = alloc_mem
            infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
        padding_mask = None
    else:
        if 0 in attention_mask:
            padding_mask = attention_mask
        else:
            padding_mask = None

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        raise NotImplementedError("gradient_checkpointing and training options are not implemented")

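    # Prefill gathers cos/sin for every prompt position; decode gathers a single position per
    # sequence, so the selected rows are reshaped to (batch_size, -1).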
    if past_key_values_length == 0:
        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
            position_ids.view(-1).shape[0], -1
        )
        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
            position_ids.view(-1).shape[0], -1
        )
    else:
        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None
    infer_state.decode_layer_id = 0
    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            rotary_emb=(position_cos, position_sin),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
            infer_state=infer_state,
        )

        hidden_states = layer_outputs[0]
        infer_state.decode_layer_id += 1

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

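    # After a full pass the shared BatchInferState switches to decode mode and the per-sequence
    # bookkeeping (start_loc, seq_len, max_len_in_batch) is advanced for the next step.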
    infer_state.is_context_stage = False
    infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
    infer_state.seq_len += 1
    infer_state.max_len_in_batch += 1

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


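# Illustrative usage sketch (the model path, tokenizer, and calibration dataset names are
# placeholders):
#
#   base_model = LlamaForCausalLM.from_pretrained("path/to/llama")  # from transformers
#   smooth_model = SmoothLlamaForCausalLM(base_model)
#   smooth_model.quantized(tokenizer, calib_dataset, num_samples=512, seq_len=512, alpha=0.5)
#
# After quantized() returns, the wrapped model runs the int8 inference path defined above.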
class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
    layer_type = "LlamaDecoderLayer"

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__(model, quantized)

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
    def get_act_dict(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
    ):
        llama_model = self.model

        llama_model.eval()
        device = next(llama_model.parameters()).device
        # print("model:", llama_model)
        act_dict = defaultdict(dict)

        def stat_io_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            if name not in act_dict or "input" not in act_dict[name]:
                act_dict[name]["input"] = x.detach().abs().max().item()
            else:
                act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
            if isinstance(y, tuple):
                y = y[0]
            if name not in act_dict or "output" not in act_dict[name]:
                act_dict[name]["output"] = y.detach().abs().max().item()
            else:
                act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())

        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaAttention):
                setattr(m, "q_apply_rotary", LlamaApplyRotary())
                setattr(m, "k_apply_rotary", LlamaApplyRotary())
                m.forward = types.MethodType(llama_decoder_layer_forward, m)

        hooks = []
        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaApplyRotary):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
            if isinstance(m, torch.nn.Linear):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))

        self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)

        for hook in hooks:
            hook.remove()
        return act_dict

    def smooth_fn(self, scales, alpha=0.5):
        model = self.model
        for name, module in model.named_modules():
            if isinstance(module, LlamaDecoderLayer):
                attn_ln = module.input_layernorm
                qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
                qkv_input_scales = scales[name + ".self_attn.q_proj"]
                self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

    def create_quantized_model(model):
        llama_config = model.config
        for i, layer in enumerate(model.model.layers):
            model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)

        model.model.forward = types.MethodType(llama_model_forward, model.model)
        cos, sin = init_to_get_rotary(llama_config)
        model.model.register_buffer("_cos_cached", cos)
        model.model.register_buffer("_sin_cached", sin)

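    # End-to-end SmoothQuant flow: collect activation scales, smooth LayerNorm/linear pairs,
    # re-measure per-tensor input/output ranges, derive int8 scales (max / 127) per layer, pack
    # every decoder layer into its int8 counterpart, and patch the model forward with the
    # quantized inference path.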
    def quantized(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
        alpha=0.5,
    ):
        llama_model = self.model
        llama_config = llama_model.config

        act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)

        self.smooth_fn(act_scales, alpha)

        act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
        decoder_layer_scales = []

        for idx in range(llama_config.num_hidden_layers):
            scale_dict = {}
            scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
            scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
            scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
            scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127

            scale_dict["q_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
            )
            scale_dict["k_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
            )

            scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127

            scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
            scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
            scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127

            decoder_layer_scales.append(scale_dict)

        for i, layer in enumerate(llama_model.model.layers):
            orig_layer = layer
            llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])

        llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)

        cos, sin = init_to_get_rotary(llama_config)
        llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
        llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))