mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229)
* fix accuracy * alignment in attention * fix attention * fix * fix bugs * fix bugs * fix bugspull/5258/head
parent
fa4fbdbffb
commit
e545a871b8
|
@ -1,11 +1,9 @@
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
|
||||||
|
|
||||||
|
|
||||||
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
|
@ -13,12 +11,12 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
Func: copy key/value into key/value cache.
|
Func: copy key/value into key/value cache.
|
||||||
|
|
||||||
Args: key/value(source): shape [bsz,seq_len,num_heads,head_size]
|
Args: key/value(source): shape [bsz,seq_len,num_heads,head_size]
|
||||||
cache: shape [num_blocks, num_heads, head_size, block_size]
|
cache: shape [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
lengths: key/value lengths
|
lengths: key/value lengths
|
||||||
block_tables
|
block_tables
|
||||||
"""
|
"""
|
||||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
num_blocks, num_heads, head_size, block_size = cache.shape
|
||||||
bsz, max_seq_len = block_tables.shape
|
bsz, max_blocks_per_seq = block_tables.shape
|
||||||
needed_blocks = (lengths + block_size - 1) // block_size
|
needed_blocks = (lengths + block_size - 1) // block_size
|
||||||
|
|
||||||
if type == "prefill":
|
if type == "prefill":
|
||||||
|
@ -42,13 +40,14 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
def convert_kvcache(cache, lengths, block_tables):
|
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
||||||
"""
|
"""
|
||||||
Func: convert key/value cache for calculation
|
Func: convert key/value cache for calculation
|
||||||
|
|
||||||
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
|
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
|
||||||
lengths: key/value length
|
lengths: key/value length
|
||||||
block_tables
|
block_tables
|
||||||
|
pad_id: padded_id
|
||||||
"""
|
"""
|
||||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
num_blocks, num_heads, head_size, block_size = cache.shape
|
||||||
|
|
||||||
|
@ -64,35 +63,29 @@ def convert_kvcache(cache, lengths, block_tables):
|
||||||
|
|
||||||
_cache = torch.cat(
|
_cache = torch.cat(
|
||||||
(
|
(
|
||||||
cache1,
|
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
|
||||||
cache2,
|
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
|
||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
padding = seq_len - _cache.size(0)
|
padding = seq_len - _cache.size(0)
|
||||||
if padding > 0:
|
if padding > 0:
|
||||||
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
|
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
|
||||||
padded_cache.append(_cache)
|
padded_cache.append(_cache)
|
||||||
return torch.stack(padded_cache, dim=0)
|
return torch.stack(padded_cache, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class PagedAttention(nn.Module):
|
class PagedAttention:
|
||||||
"""
|
"""
|
||||||
Pure Torch implementation version of paged_attention.
|
Pure Torch implementation version of paged_attention.
|
||||||
|
Holds different types of forward function and useful components.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None):
|
@staticmethod
|
||||||
super().__init__()
|
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
|
||||||
self.num_heads = num_heads
|
"""
|
||||||
self.head_size = head_size
|
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
|
||||||
self.scale = float(scale)
|
"""
|
||||||
self.sliding_window = sliding_window
|
|
||||||
self._init_rope()
|
|
||||||
|
|
||||||
def _init_rope(self):
|
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_size)
|
|
||||||
|
|
||||||
def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size):
|
|
||||||
bsz = len(seq_lengths)
|
bsz = len(seq_lengths)
|
||||||
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
|
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
|
||||||
|
|
||||||
|
@ -103,22 +96,49 @@ class PagedAttention(nn.Module):
|
||||||
token_idx += seq_len
|
token_idx += seq_len
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
def generate_padding_mask(self, lengths, max_seq_len):
|
@staticmethod
|
||||||
|
def generate_padding_mask(lengths, max_seq_len):
|
||||||
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
|
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
|
||||||
padding_mask = range_tensor < lengths.unsqueeze(1)
|
padding_mask = range_tensor < lengths.unsqueeze(1)
|
||||||
return padding_mask
|
return padding_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||||
|
Args: hidden_states(batch, num_key_value_heads, seqlen, head_dim)
|
||||||
|
n_rep: times of repeatition.
|
||||||
|
Output: hidden_states (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
|
||||||
|
num_attention_heads = n_rep * num_key_value_heads
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
|
||||||
|
|
||||||
|
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def nopad_context_forward(
|
def nopad_context_forward(
|
||||||
self,
|
|
||||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
k: torch.Tensor,
|
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
context_lengths: torch.Tensor, # [num_seqs]
|
context_lengths: torch.Tensor, # [num_seqs]
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
NOTE: q,k,v are projected and applied rotary embedding, all aligned with triton version.
|
||||||
|
"""
|
||||||
|
# Fisrt, do shape verification
|
||||||
num_tokens, num_heads, head_size = q.shape
|
num_tokens, num_heads, head_size = q.shape
|
||||||
|
num_kv_heads = k.shape[-2]
|
||||||
|
|
||||||
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
|
|
||||||
block_size = k_cache.shape[-1]
|
block_size = k_cache.shape[-1]
|
||||||
bsz, max_blocks_per_sequence = block_tables.shape
|
bsz, max_blocks_per_sequence = block_tables.shape
|
||||||
max_seq_len = max_blocks_per_sequence * block_size
|
max_seq_len = max_blocks_per_sequence * block_size
|
||||||
|
@ -127,80 +147,85 @@ class PagedAttention(nn.Module):
|
||||||
assert context_lengths.shape[0] == block_tables.shape[0]
|
assert context_lengths.shape[0] == block_tables.shape[0]
|
||||||
shape = (bsz, max_seq_len, num_heads, head_size)
|
shape = (bsz, max_seq_len, num_heads, head_size)
|
||||||
input_shape = shape[:2]
|
input_shape = shape[:2]
|
||||||
query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
|
|
||||||
key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
|
q = PagedAttention.pad_and_reshape(
|
||||||
value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
|
q, context_lengths, max_seq_len, num_heads, head_size
|
||||||
|
) # bsz,seqlen,num_heads,head_size
|
||||||
|
k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size)
|
||||||
|
v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size)
|
||||||
|
|
||||||
|
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
|
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
|
||||||
self.generate_padding_mask(context_lengths, max_seq_len)
|
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len)
|
||||||
|
|
||||||
position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
|
q = q.transpose(1, 2)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
|
||||||
|
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value, max_seq_len)
|
# position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
|
||||||
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
|
# position_ids = position_ids.unsqueeze(0)
|
||||||
|
# cos, sin = self.rotary_emb(value, max_seq_len)
|
||||||
copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
|
# query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
|
||||||
copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
|
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
|
||||||
if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len):
|
if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len):
|
||||||
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.")
|
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.")
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_weights += attn_mask
|
attn_weights += attn_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||||
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
|
attn_output = torch.matmul(attn_weights, v)
|
||||||
attn_output = torch.matmul(attn_weights, value)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, num_heads, max_seq_len, head_size):
|
if attn_output.size() != (bsz, num_heads, max_seq_len, head_size):
|
||||||
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.")
|
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.")
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1)
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1)
|
||||||
|
|
||||||
|
del attn_weights
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def pad_context_forward(
|
def pad_context_forward(
|
||||||
self,
|
|
||||||
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
||||||
k: torch.Tensor,
|
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
context_lengths: torch.Tensor, # [num_seqs]
|
context_lengths: torch.Tensor, # [num_seqs]
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
):
|
):
|
||||||
|
# Firt, do shape verification
|
||||||
bsz, seq_len, num_heads, head_size = q.shape
|
bsz, seq_len, num_heads, head_size = q.shape
|
||||||
|
num_kv_heads = k.shape[-2]
|
||||||
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
block_size = k_cache.shape[-1]
|
block_size = k_cache.shape[-1]
|
||||||
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
||||||
block_tables.shape[-1] * block_size
|
block_tables.shape[-1] * block_size
|
||||||
shape = (bsz, seq_len, num_heads, head_size)
|
shape = (bsz, seq_len, num_heads, head_size)
|
||||||
input_shape = shape[:2]
|
input_shape = shape[:2]
|
||||||
|
|
||||||
|
# Copy kv to memory(rotary embedded)
|
||||||
|
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
|
||||||
q = q.transpose(1, 2)
|
q = q.transpose(1, 2)
|
||||||
k = k.transpose(1, 2)
|
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
|
||||||
v = v.transpose(1, 2)
|
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
|
||||||
|
|
||||||
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device)
|
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
cos, sin = self.rotary_emb(v, seq_len)
|
|
||||||
query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
|
||||||
|
|
||||||
copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
|
|
||||||
copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
|
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
|
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
|
||||||
self.generate_padding_mask(context_lengths, seq_len)
|
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
|
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
|
||||||
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
|
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_weights += attn_mask
|
attn_weights += attn_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||||
|
|
||||||
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
|
|
||||||
attn_output = torch.matmul(attn_weights, v)
|
attn_output = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
if attn_output.size() != (bsz, num_heads, seq_len, head_size):
|
if attn_output.size() != (bsz, num_heads, seq_len, head_size):
|
||||||
|
@ -208,62 +233,70 @@ class PagedAttention(nn.Module):
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
|
||||||
|
|
||||||
|
del attn_weights
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def pad_decoding_forward(
|
def pad_decoding_forward(
|
||||||
self,
|
|
||||||
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
||||||
k: torch.Tensor,
|
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
):
|
):
|
||||||
|
# Firt, do shape verification.
|
||||||
bsz, _, num_heads, head_size = q.shape
|
bsz, _, num_heads, head_size = q.shape
|
||||||
|
|
||||||
|
num_kv_heads = k.shape[-2]
|
||||||
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
block_size = k_cache.shape[-1]
|
block_size = k_cache.shape[-1]
|
||||||
seq_len = max(lengths)
|
seq_len = max(lengths)
|
||||||
|
|
||||||
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
||||||
max_seq_len = block_tables.shape[-1] * block_size
|
block_tables.shape[-1] * block_size
|
||||||
|
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(
|
attn_mask = AttentionMaskConverter._make_causal_mask(
|
||||||
q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
|
q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
|
||||||
)
|
)
|
||||||
self.generate_padding_mask(lengths, max_seq_len)
|
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)
|
||||||
cos, sin = self.rotary_emb(v, max_seq_len)
|
# cos, sin = self.rotary_emb(v, max_seq_len)
|
||||||
|
# position_ids = lengths - 1
|
||||||
|
# position_ids = position_ids.unsqueeze(1)
|
||||||
|
# query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
|
||||||
|
|
||||||
position_ids = lengths - 1
|
copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
||||||
position_ids = position_ids.unsqueeze(1)
|
|
||||||
|
|
||||||
query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
|
|
||||||
|
|
||||||
copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
|
||||||
copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
||||||
|
|
||||||
key = convert_kvcache(key, k_cache, lengths, block_tables) # bsz, seqlen,
|
k = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen,
|
||||||
value = convert_kvcache(v, v_cache, lengths, block_tables)
|
v = convert_kvcache(v_cache, lengths, block_tables)
|
||||||
|
|
||||||
query = query.transpose(1, 2)
|
q = q.transpose(1, 2)
|
||||||
key = key.transpose(1, 2)
|
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
|
||||||
value = value.transpose(1, 2)
|
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
|
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
|
||||||
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
|
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
|
||||||
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
|
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_weights += attn_mask
|
attn_weights += attn_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||||
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
|
attn_output = torch.matmul(attn_weights, v)
|
||||||
attn_output = torch.matmul(attn_weights, value)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, num_heads, 1, head_size):
|
if attn_output.size() != (bsz, num_heads, 1, head_size):
|
||||||
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
|
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
|
||||||
|
|
||||||
|
del attn_weights
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def no_pad_decoding_forward(
|
def no_pad_decoding_forward(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
|
|
|
@ -3,7 +3,7 @@ import pytest
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.inference.config import InferenceConfig
|
from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.struct import BatchInfo, Sequence
|
from colossalai.inference.struct import BatchInfo, Sequence
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
def check_config_and_inference():
|
def check_config_and_inference():
|
||||||
|
@ -74,6 +74,7 @@ def run_dist(rank, world_size, port):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
def test_config_and_inference():
|
def test_config_and_inference():
|
||||||
spawn(run_dist, 1)
|
spawn(run_dist, 1)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@ from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.core.engine import InferenceEngine
|
from colossalai.inference.core.engine import InferenceEngine
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
def setup_seed(seed):
|
def setup_seed(seed):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
|
@ -8,7 +8,7 @@ import colossalai
|
||||||
from colossalai.inference.config import InferenceConfig
|
from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
|
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.testing import parameterize, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
@parameterize(
|
||||||
|
@ -155,6 +155,7 @@ def run_dist(rank, world_size, port):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
def test_cache_manager():
|
def test_cache_manager():
|
||||||
spawn(run_dist, 1)
|
spawn(run_dist, 1)
|
||||||
|
|
||||||
|
|
|
@ -3,15 +3,15 @@ import torch
|
||||||
from transformers.cache_utils import DynamicCache
|
from transformers.cache_utils import DynamicCache
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
def test_copy_to_cache():
|
def test_copy_to_cache():
|
||||||
key = torch.ones((2, 10, 3, 3))
|
key = torch.ones((2, 11, 3, 3))
|
||||||
key[0, 9, :, :] = 0
|
key[0, 9, :, :] = 0
|
||||||
key[1, -2:, :, :] = 0
|
key[1, -2:, :, :] = 0
|
||||||
cache = torch.zeros(8, 3, 3, 8)
|
cache = torch.zeros(8, 3, 3, 8)
|
||||||
|
@ -32,7 +32,8 @@ def test_convert_kvcache():
|
||||||
key = torch.ones(2, 1, 3, 3) + 1
|
key = torch.ones(2, 1, 3, 3) + 1
|
||||||
lengths = torch.tensor([10, 9])
|
lengths = torch.tensor([10, 9])
|
||||||
block_tables = torch.tensor([[0, 1], [2, 3]])
|
block_tables = torch.tensor([[0, 1], [2, 3]])
|
||||||
converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables)
|
copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
||||||
|
converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
|
||||||
assert converted_cache.shape == (2, 10, 3, 3)
|
assert converted_cache.shape == (2, 10, 3, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ def test_context_attention():
|
||||||
"""
|
"""
|
||||||
test config: head_num = 4, head_size = 4
|
test config: head_num = 4, head_size = 4
|
||||||
"""
|
"""
|
||||||
attn = PagedAttention(4, 4)
|
attn = PagedAttention()
|
||||||
q = k = v = torch.randn(8, 4, 4)
|
q = k = v = torch.randn(8, 4, 4)
|
||||||
k_cache = torch.empty(8, 4, 4, 8)
|
k_cache = torch.empty(8, 4, 4, 8)
|
||||||
v_cache = torch.empty(8, 4, 4, 8)
|
v_cache = torch.empty(8, 4, 4, 8)
|
||||||
|
@ -61,48 +62,72 @@ def test_context_attention():
|
||||||
|
|
||||||
# test accuracy with LlamaAttention
|
# test accuracy with LlamaAttention
|
||||||
hidden_states = torch.randn(1, 8, 16)
|
hidden_states = torch.randn(1, 8, 16)
|
||||||
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4)
|
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
|
||||||
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4)
|
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
|
||||||
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4)
|
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
|
||||||
pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables)
|
|
||||||
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
|
|
||||||
|
|
||||||
|
position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
cos, sin = transformer_attn.rotary_emb(proj_v, 8)
|
||||||
|
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
|
||||||
|
|
||||||
|
pad_attn_output = attn.pad_context_forward(
|
||||||
|
proj_q.transpose(1, 2),
|
||||||
|
proj_k.transpose(1, 2),
|
||||||
|
proj_v.transpose(1, 2),
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
context_lengths,
|
||||||
|
block_tables,
|
||||||
|
)
|
||||||
|
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(
|
attn_mask = AttentionMaskConverter._make_causal_mask(
|
||||||
hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
|
hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
|
||||||
)
|
)
|
||||||
|
attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8)
|
||||||
attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
|
attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
|
||||||
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
|
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
def test_decoding_attention():
|
def test_decoding_attention():
|
||||||
# test the pipeline of decoding attention
|
# test the pipeline of decoding attention
|
||||||
attn = PagedAttention(4, 4)
|
attn = PagedAttention()
|
||||||
q = k = v = torch.randn(2, 1, 4, 4)
|
q = k = v = torch.randn(2, 1, 4, 8)
|
||||||
k_cache = torch.empty(8, 4, 4, 8)
|
k_cache = torch.empty(8, 4, 8, 8)
|
||||||
v_cache = torch.empty(8, 4, 4, 8)
|
v_cache = torch.empty(8, 4, 8, 8)
|
||||||
past_kv = torch.randn(2, 8, 4, 4)
|
past_kv = torch.randn(2, 8, 4, 8)
|
||||||
context_lenghths = torch.tensor([8, 8])
|
context_lenghths = torch.tensor([8, 8])
|
||||||
lengths = context_lenghths + 1
|
lengths = context_lenghths + 1
|
||||||
block_tables = torch.tensor([[0, 1], [2, 3]])
|
block_tables = torch.tensor([[0, 1], [2, 3]])
|
||||||
copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables)
|
copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables)
|
||||||
copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables)
|
copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables)
|
||||||
attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
|
attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
|
||||||
|
|
||||||
# test decoding accuracy, past_kv is reused
|
# test decoding accuracy, past_kv is reused
|
||||||
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16)
|
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32)
|
||||||
transformer_attn = LlamaAttention(config)
|
transformer_attn = LlamaAttention(config)
|
||||||
transformer_attn.layer_idx = 0
|
transformer_attn.layer_idx = 0
|
||||||
transformer_attn.training = False
|
transformer_attn.training = False
|
||||||
hidden_states = torch.randn(2, 1, 16)
|
hidden_states = torch.randn(2, 1, 32)
|
||||||
proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4)
|
proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
|
||||||
proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4)
|
proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
|
||||||
proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4)
|
proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = transformer_attn.rotary_emb(proj_v, 16)
|
||||||
|
position_ids = lengths - 1
|
||||||
|
position_ids = position_ids.unsqueeze(1) # NOTE: this may be wrong
|
||||||
|
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2)
|
||||||
|
|
||||||
llama_past_kv = DynamicCache()
|
llama_past_kv = DynamicCache()
|
||||||
llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)
|
llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)
|
||||||
|
|
||||||
# past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
|
# past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
|
||||||
pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables)
|
pad_attn_output = attn.pad_decoding_forward(
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8)
|
proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables
|
||||||
|
)
|
||||||
|
attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8)
|
||||||
|
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
|
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
|
||||||
position_ids = context_lenghths.unsqueeze(1)
|
position_ids = context_lenghths.unsqueeze(1)
|
||||||
attn_output, _, _ = transformer_attn.forward(
|
attn_output, _, _ = transformer_attn.forward(
|
||||||
|
@ -112,9 +137,9 @@ def test_decoding_attention():
|
||||||
|
|
||||||
|
|
||||||
def check_attention_layer():
|
def check_attention_layer():
|
||||||
# test_copy_to_cache()
|
test_copy_to_cache()
|
||||||
# test_convert_kvcache()
|
test_convert_kvcache()
|
||||||
# test_context_attention()
|
test_context_attention()
|
||||||
test_decoding_attention()
|
test_decoding_attention()
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,6 +149,7 @@ def run_dist(rank, world_size, port):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
def test_attention_layer():
|
def test_attention_layer():
|
||||||
spawn(run_dist, 1)
|
spawn(run_dist, 1)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import colossalai
|
||||||
from colossalai.inference.config import InferenceConfig
|
from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.core.request_handler import RequestHandler, RunningList
|
from colossalai.inference.core.request_handler import RequestHandler, RunningList
|
||||||
from colossalai.inference.struct import RequestStatus, Sequence
|
from colossalai.inference.struct import RequestStatus, Sequence
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
def check_running_list():
|
def check_running_list():
|
||||||
|
@ -78,6 +78,7 @@ def run_dist(rank, world_size, port):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
def test_running_list_and_request_handler():
|
def test_running_list_and_request_handler():
|
||||||
spawn(run_dist, 1)
|
spawn(run_dist, 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue