[Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229)

* fix accuracy

* alignment in attention

* fix attention

* fix

* fix bugs

* fix bugs

* fix bugs
Jianghai 2024-01-08 15:56:00 +08:00 committed by FrankLeeeee
parent fa4fbdbffb
commit e545a871b8
6 changed files with 168 additions and 107 deletions

View File

@@ -1,11 +1,9 @@
 import math
-from typing import Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


 def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
@@ -13,12 +11,12 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     Func: copy key/value into key/value cache.
     Args: key/value(source): shape [bsz,seq_len,num_heads,head_size]
-          cache: shape [num_blocks, num_heads, head_size, block_size]
+          cache: shape [num_blocks, num_kv_heads, head_size, block_size]
           lengths: key/value lengths
           block_tables
     """
     num_blocks, num_heads, head_size, block_size = cache.shape
-    bsz, max_seq_len = block_tables.shape
+    bsz, max_blocks_per_seq = block_tables.shape
     needed_blocks = (lengths + block_size - 1) // block_size

     if type == "prefill":
@@ -42,13 +40,14 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     return cache


-def convert_kvcache(cache, lengths, block_tables):
+def convert_kvcache(cache, lengths, block_tables, pad_id=0):
     """
     Func: convert key/value cache for calculation
     Args: cache: shape [num_blocks, num_heads, head_size, block_size]
           lengths: key/value length
           block_tables
+          pad_id: padded_id
     """
     num_blocks, num_heads, head_size, block_size = cache.shape
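
With this change convert_kvcache reads directly from the paged cache (and pads with pad_id) instead of taking the raw key tensor as its first argument. For reference, a minimal round-trip sketch with toy shapes borrowed from the unit tests further down in this commit; the sketch itself is not part of the diff:

import torch

from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache

key = torch.randn(2, 10, 3, 3)                 # [bsz, seq_len, num_heads, head_size]
cache = torch.zeros(8, 3, 3, 8)                # [num_blocks, num_kv_heads, head_size, block_size]
lengths = torch.tensor([10, 9])                # valid tokens per sequence
block_tables = torch.tensor([[0, 1], [2, 3]])  # ceil(10 / 8) = 2 blocks per sequence

# Prefill write: scatter each sequence's tokens into its assigned blocks.
copy_to_cache(key, cache, lengths=lengths, block_tables=block_tables)
# Gather back into a padded [bsz, max_len, num_heads, head_size] tensor for the torch attention path.
padded = convert_kvcache(cache, lengths=lengths, block_tables=block_tables)
assert padded.shape == (2, 10, 3, 3)
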
@@ -64,35 +63,29 @@ def convert_kvcache(cache, lengths, block_tables):
         _cache = torch.cat(
             (
-                cache1,
-                cache2,
+                cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
+                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
             ),
             dim=0,
         )
         padding = seq_len - _cache.size(0)
         if padding > 0:
-            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
         padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)


-class PagedAttention(nn.Module):
+class PagedAttention:
     """
     Pure Torch implementation version of paged_attention.
+    Holds different types of forward function and useful components.
     """

-    def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None):
-        super().__init__()
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.sliding_window = sliding_window
-        self._init_rope()
-
-    def _init_rope(self):
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_size)
-
-    def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size):
+    @staticmethod
+    def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
+        """
+        Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
+        """
         bsz = len(seq_lengths)
         padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
@@ -103,22 +96,49 @@ class PagedAttention(nn.Module):
             token_idx += seq_len
         return padded_tensor

-    def generate_padding_mask(self, lengths, max_seq_len):
+    @staticmethod
+    def generate_padding_mask(lengths, max_seq_len):
         range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
         padding_mask = range_tensor < lengths.unsqueeze(1)
         return padding_mask

+    @staticmethod
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
+        """
+        Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+        Args: hidden_states(batch, num_key_value_heads, seqlen, head_dim)
+              n_rep: times of repeatition.
+        Output: hidden_states (batch, num_attention_heads, seqlen, head_dim)
+        """
+        if n_rep == 1:
+            return hidden_states
+        batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
+        num_attention_heads = n_rep * num_key_value_heads
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
+        return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
+
+    @staticmethod
     def nopad_context_forward(
-        self,
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        k: torch.Tensor,
+        k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
         v: torch.Tensor,
         k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         v_cache: torch.Tensor,
         context_lengths: torch.Tensor,  # [num_seqs]
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
     ):
+        """
+        NOTE: q,k,v are projected and applied rotary embedding, all aligned with triton version.
+        """
+        # Fisrt, do shape verification
         num_tokens, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[-2]
+        assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
+        num_kv_groups = num_heads // num_kv_heads
         block_size = k_cache.shape[-1]
         bsz, max_blocks_per_sequence = block_tables.shape
         max_seq_len = max_blocks_per_sequence * block_size
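
The new repeat_kv helper is what lets these pure-torch forwards accept fewer KV heads than query heads (MQA/GQA); per its docstring it matches torch.repeat_interleave along the head dimension. A quick self-check sketch with assumed toy shapes, not part of the diff:

import torch

from colossalai.inference.modeling.layers.attention import PagedAttention

kv = torch.randn(2, 2, 5, 4)                 # [batch, num_key_value_heads, seq_len, head_dim]
out = PagedAttention.repeat_kv(kv, n_rep=3)  # -> [batch, 6, seq_len, head_dim]
assert torch.equal(out, torch.repeat_interleave(kv, repeats=3, dim=1))
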
@@ -127,80 +147,85 @@ class PagedAttention(nn.Module):
         assert context_lengths.shape[0] == block_tables.shape[0]
         shape = (bsz, max_seq_len, num_heads, head_size)
         input_shape = shape[:2]

-        query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
-        key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
-        value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
+        q = PagedAttention.pad_and_reshape(
+            q, context_lengths, max_seq_len, num_heads, head_size
+        )  # bsz,seqlen,num_heads,head_size
+        k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size)
+        v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size)
+
+        copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
+        copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)

         attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
-        self.generate_padding_mask(context_lengths, max_seq_len)
+        attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len)

-        position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
-        position_ids = position_ids.unsqueeze(0)
+        q = q.transpose(1, 2)
+        k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
+        v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

-        cos, sin = self.rotary_emb(value, max_seq_len)
-        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
+        # position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
+        # position_ids = position_ids.unsqueeze(0)
+        # cos, sin = self.rotary_emb(value, max_seq_len)
+        # query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

-        copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
-        copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
-
-        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
         if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.")
         if attn_mask is not None:
             attn_weights += attn_mask

-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-        # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
-        attn_output = torch.matmul(attn_weights, value)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)

         if attn_output.size() != (bsz, num_heads, max_seq_len, head_size):
             raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.")
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1)
+        del attn_weights

         return attn_output

+    @staticmethod
     def pad_context_forward(
-        self,
         q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
-        k: torch.Tensor,
+        k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
         v: torch.Tensor,
         k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         v_cache: torch.Tensor,
         context_lengths: torch.Tensor,  # [num_seqs]
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
     ):
+        # Firt, do shape verification
         bsz, seq_len, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[-2]
+        assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
+        num_kv_groups = num_heads // num_kv_heads
         block_size = k_cache.shape[-1]
         assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
         block_tables.shape[-1] * block_size
         shape = (bsz, seq_len, num_heads, head_size)
         input_shape = shape[:2]

+        # Copy kv to memory(rotary embedded)
+        copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
+        copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
+
         q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
+        v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

-        position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device)
-        position_ids = position_ids.unsqueeze(0)
-        cos, sin = self.rotary_emb(v, seq_len)
-        query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-
-        copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
-        copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
-
-        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
+
         attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
-        self.generate_padding_mask(context_lengths, seq_len)
+        attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)

         if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
         if attn_mask is not None:
             attn_weights += attn_mask

-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-        # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
         attn_output = torch.matmul(attn_weights, v)

         if attn_output.size() != (bsz, num_heads, seq_len, head_size):
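
After this hunk, pad_context_forward is a plain @staticmethod: it expects q/k/v that are already projected and rotary-embedded (matching the Triton kernel), writes k/v into the paged cache itself, and returns [bsz, seq_len, num_heads * head_size]. A usage sketch with assumed toy shapes, following the pattern of test_context_attention below; not part of the diff:

import torch

from colossalai.inference.modeling.layers.attention import PagedAttention

bsz, seq_len, num_heads, head_size, block_size = 1, 8, 4, 4, 8
# q/k/v are assumed to be projected and rotary-embedded already, per the NOTE above.
q = torch.randn(bsz, seq_len, num_heads, head_size)
k = torch.randn(bsz, seq_len, num_heads, head_size)
v = torch.randn(bsz, seq_len, num_heads, head_size)
k_cache = torch.empty(8, num_heads, head_size, block_size)  # [num_blocks, num_kv_heads, head_size, block_size]
v_cache = torch.empty(8, num_heads, head_size, block_size)
context_lengths = torch.tensor([seq_len])
block_tables = torch.tensor([[0]])                          # one block holds the 8 prefill tokens

out = PagedAttention.pad_context_forward(q, k, v, k_cache, v_cache, context_lengths, block_tables)
assert out.shape == (bsz, seq_len, num_heads * head_size)
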
@@ -208,62 +233,70 @@ class PagedAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
+        del attn_weights

         return attn_output

+    @staticmethod
     def pad_decoding_forward(
-        self,
         q: torch.Tensor,  # [bsz, 1, num_heads, head_size]
-        k: torch.Tensor,
+        k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_size]
         v: torch.Tensor,
         k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         v_cache: torch.Tensor,
         lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
     ):
+        # Firt, do shape verification.
         bsz, _, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[-2]
+        assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
+        num_kv_groups = num_heads // num_kv_heads
         block_size = k_cache.shape[-1]
         seq_len = max(lengths)
         assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
-        max_seq_len = block_tables.shape[-1] * block_size
+        block_tables.shape[-1] * block_size

         attn_mask = AttentionMaskConverter._make_causal_mask(
             q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
         )
-        self.generate_padding_mask(lengths, max_seq_len)
+        attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)

-        cos, sin = self.rotary_emb(v, max_seq_len)
+        # cos, sin = self.rotary_emb(v, max_seq_len)
+        # position_ids = lengths - 1
+        # position_ids = position_ids.unsqueeze(1)
+        # query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)

-        position_ids = lengths - 1
-        position_ids = position_ids.unsqueeze(1)
-        query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
-
-        copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
+        copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
         copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")

-        key = convert_kvcache(key, k_cache, lengths, block_tables)  # bsz, seqlen,
-        value = convert_kvcache(v, v_cache, lengths, block_tables)
+        k = convert_kvcache(k_cache, lengths, block_tables)  # bsz, seqlen,
+        v = convert_kvcache(v_cache, lengths, block_tables)

-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
+        q = q.transpose(1, 2)
+        k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
+        v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

-        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
         if attn_mask is not None:
             attn_weights += attn_mask

-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-        # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
-        attn_output = torch.matmul(attn_weights, value)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)

         if attn_output.size() != (bsz, num_heads, 1, head_size):
             raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
+        del attn_weights

         return attn_output

+    @staticmethod
     def no_pad_decoding_forward(
         self,
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
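
pad_decoding_forward follows the same convention for the decoding step: the caller passes a single rotary-embedded token per sequence, the new KV entry is appended into the paged cache, and the full context is gathered back via convert_kvcache. A sketch mirroring test_decoding_attention below (assumed toy shapes; not part of the diff):

import torch

from colossalai.inference.modeling.layers.attention import PagedAttention, copy_to_cache

# Prefill 8 tokens per sequence into the paged cache, then decode one new token.
past_kv = torch.randn(2, 8, 4, 8)                  # [bsz, seq_len, num_kv_heads, head_size]
k_cache = torch.empty(8, 4, 8, 8)                  # [num_blocks, num_kv_heads, head_size, block_size]
v_cache = torch.empty(8, 4, 8, 8)
context_lengths = torch.tensor([8, 8])
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(past_kv, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lengths, block_tables=block_tables)

q = k = v = torch.randn(2, 1, 4, 8)                # one new (already rotary-embedded) token per sequence
lengths = context_lengths + 1                      # input lengths + generated lengths
out = PagedAttention.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths, block_tables)
assert out.shape == (2, 1, 4 * 8)
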

View File

@@ -3,7 +3,7 @@ import pytest
 import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.struct import BatchInfo, Sequence
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_config_and_inference():
@@ -74,6 +74,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_config_and_inference():
     spawn(run_dist, 1)

View File

@@ -11,7 +11,6 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import rerun_if_address_is_in_use, spawn


 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

View File

@@ -8,7 +8,7 @@ import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


 @parameterize(
@@ -155,6 +155,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_cache_manager():
     spawn(run_dist, 1)

View File

@@ -3,15 +3,15 @@ import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

 import colossalai
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def test_copy_to_cache():
-    key = torch.ones((2, 10, 3, 3))
+    key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
     key[1, -2:, :, :] = 0
     cache = torch.zeros(8, 3, 3, 8)
@@ -32,7 +32,8 @@ def test_convert_kvcache():
     key = torch.ones(2, 1, 3, 3) + 1
     lengths = torch.tensor([10, 9])
     block_tables = torch.tensor([[0, 1], [2, 3]])
-    converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables)
+    copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
+    converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
     assert converted_cache.shape == (2, 10, 3, 3)
@@ -40,7 +41,7 @@ def test_context_attention():
     """
     test config: head_num = 4, head_size = 4
     """
-    attn = PagedAttention(4, 4)
+    attn = PagedAttention()
     q = k = v = torch.randn(8, 4, 4)
     k_cache = torch.empty(8, 4, 4, 8)
     v_cache = torch.empty(8, 4, 4, 8)
@@ -61,48 +62,72 @@ def test_context_attention():

     # test accuracy with LlamaAttention
     hidden_states = torch.randn(1, 8, 16)
-    proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4)
-    proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4)
-    proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4)
+    proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
+    proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
+    proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)

-    pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables)
-    pad_attn_output = transformer_attn.o_proj(pad_attn_output)
+    position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
+    position_ids = position_ids.unsqueeze(0)
+    cos, sin = transformer_attn.rotary_emb(proj_v, 8)
+    proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
+
+    pad_attn_output = attn.pad_context_forward(
+        proj_q.transpose(1, 2),
+        proj_k.transpose(1, 2),
+        proj_v.transpose(1, 2),
+        k_cache,
+        v_cache,
+        context_lengths,
+        block_tables,
+    )
+    pad_attn_output = transformer_attn.o_proj(pad_attn_output)

     attn_mask = AttentionMaskConverter._make_causal_mask(
         hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
     )
+    attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8)
     attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
-    assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
+    assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)


 def test_decoding_attention():
     # test the pipeline of decoding attention
-    attn = PagedAttention(4, 4)
-    q = k = v = torch.randn(2, 1, 4, 4)
-    k_cache = torch.empty(8, 4, 4, 8)
-    v_cache = torch.empty(8, 4, 4, 8)
-    past_kv = torch.randn(2, 8, 4, 4)
+    attn = PagedAttention()
+    q = k = v = torch.randn(2, 1, 4, 8)
+    k_cache = torch.empty(8, 4, 8, 8)
+    v_cache = torch.empty(8, 4, 8, 8)
+    past_kv = torch.randn(2, 8, 4, 8)
     context_lenghths = torch.tensor([8, 8])
     lengths = context_lenghths + 1
     block_tables = torch.tensor([[0, 1], [2, 3]])
     copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables)
     copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables)
     attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)

     # test decoding accuracy, past_kv is reused
-    config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16)
+    config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32)
     transformer_attn = LlamaAttention(config)
     transformer_attn.layer_idx = 0
     transformer_attn.training = False
-    hidden_states = torch.randn(2, 1, 16)
-    proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4)
-    proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4)
-    proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4)
+    hidden_states = torch.randn(2, 1, 32)
+    proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+    proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+    proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
+
+    cos, sin = transformer_attn.rotary_emb(proj_v, 16)
+    position_ids = lengths - 1
+    position_ids = position_ids.unsqueeze(1)  # NOTE: this may be wrong
+    proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2)

     llama_past_kv = DynamicCache()
     llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)

     # past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
-    pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables)
-    attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8)
+    pad_attn_output = attn.pad_decoding_forward(
+        proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables
+    )
+    attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8)
+    attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2)

     pad_attn_output = transformer_attn.o_proj(pad_attn_output)
     position_ids = context_lenghths.unsqueeze(1)
     attn_output, _, _ = transformer_attn.forward(
@@ -112,9 +137,9 @@ def test_decoding_attention():


 def check_attention_layer():
-    # test_copy_to_cache()
-    # test_convert_kvcache()
-    # test_context_attention()
+    test_copy_to_cache()
+    test_convert_kvcache()
+    test_context_attention()
     test_decoding_attention()
@@ -124,6 +149,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_attention_layer():
     spawn(run_dist, 1)

View File

@@ -6,7 +6,7 @@ import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.request_handler import RequestHandler, RunningList
 from colossalai.inference.struct import RequestStatus, Sequence
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_running_list():
@@ -78,6 +78,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_running_list_and_request_handler():
     spawn(run_dist, 1)