[Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229)

* fix accuracy

* alignment in attention

* fix attention

* fix

* fix bugs

* fix bugs

* fix bugs
pull/5258/head
Jianghai 2024-01-08 15:56:00 +08:00 committed by FrankLeeeee
parent fa4fbdbffb
commit e545a871b8
6 changed files with 168 additions and 107 deletions
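For context on the API alignment, a minimal sketch of the new calling convention (a sketch only: shapes follow the updated tests below, and the rotary-embedding calls assume the transformers API already used in this diff). The caller now projects q/k/v and applies rotary embedding itself, then calls the static PagedAttention methods, matching the Triton kernel interface.

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

from colossalai.inference.modeling.layers.attention import PagedAttention

bsz, seq_len, num_heads, head_size, block_size = 1, 8, 4, 4, 8
q = k = v = torch.randn(bsz, seq_len, num_heads, head_size)
k_cache = torch.zeros(8, num_heads, head_size, block_size)  # [num_blocks, num_kv_heads, head_size, block_size]
v_cache = torch.zeros(8, num_heads, head_size, block_size)
context_lengths = torch.tensor([seq_len])
block_tables = torch.tensor([[0, 1]])  # [num_seqs, max_blocks_per_sequence]

# Rotary embedding is now the caller's responsibility, not applied inside PagedAttention.
rotary_emb = LlamaRotaryEmbedding(head_size)
cos, sin = rotary_emb(v.transpose(1, 2), seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0)
q_rot, k_rot = apply_rotary_pos_emb(q.transpose(1, 2), k.transpose(1, 2), cos, sin, position_ids)

# The forward helpers are static methods now; no PagedAttention(num_heads, head_size) instance is needed.
out = PagedAttention.pad_context_forward(
    q_rot.transpose(1, 2), k_rot.transpose(1, 2), v, k_cache, v_cache, context_lengths, block_tables
)  # [bsz, seq_len, num_heads * head_size]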

View File

@@ -1,11 +1,9 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
@@ -13,12 +11,12 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
Func: copy key/value into key/value cache.
Args: key/value(source): shape [bsz,seq_len,num_heads,head_size]
cache: shape [num_blocks, num_heads, head_size, block_size]
cache: shape [num_blocks, num_kv_heads, head_size, block_size]
lengths: key/value lengths
block_tables
"""
num_blocks, num_heads, head_size, block_size = cache.shape
bsz, max_seq_len = block_tables.shape
bsz, max_blocks_per_seq = block_tables.shape
needed_blocks = (lengths + block_size - 1) // block_size
if type == "prefill":
@@ -42,13 +40,14 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache
def convert_kvcache(cache, lengths, block_tables):
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: gather the paged key/value cache back into a padded tensor for computation
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
lengths: key/value length
block_tables
pad_id: value used to pad the converted cache (default 0)
"""
num_blocks, num_heads, head_size, block_size = cache.shape
@@ -64,35 +63,29 @@ def convert_kvcache(cache, lengths, block_tables):
_cache = torch.cat(
(
cache1,
cache2,
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
),
dim=0,
)
padding = seq_len - _cache.size(0)
if padding > 0:
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
padded_cache.append(_cache)
return torch.stack(padded_cache, dim=0)
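As a quick aid for reviewing the two cache helpers above, a sketch of the intended round trip (shapes mirror the updated tests; exact padding behavior is defined by the helpers themselves):

import torch

from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache

key = torch.ones(2, 10, 3, 3)                  # [bsz, seq_len, num_kv_heads, head_size]
cache = torch.zeros(8, 3, 3, 8)                # [num_blocks, num_kv_heads, head_size, block_size]
lengths = torch.tensor([10, 9])
block_tables = torch.tensor([[0, 1], [2, 3]])  # two cache blocks reserved per sequence

# Prefill: scatter the first lengths[i] tokens of each sequence into its blocks.
cache = copy_to_cache(key, cache, lengths=lengths, block_tables=block_tables)
# Gather back into a padded tensor for dense computation.
padded = convert_kvcache(cache, lengths, block_tables)
assert padded.shape == (2, 10, 3, 3)           # [bsz, max(lengths), num_kv_heads, head_size]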
class PagedAttention(nn.Module):
class PagedAttention:
"""
Pure PyTorch implementation of paged attention.
Holds the different forward functions and useful components.
"""
def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.sliding_window = sliding_window
self._init_rope()
def _init_rope(self):
self.rotary_emb = LlamaRotaryEmbedding(self.head_size)
def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size):
@staticmethod
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform a packed (no-padding) tensor into a padded tensor of shape [bsz, seq_len, num_heads, head_size]
"""
bsz = len(seq_lengths)
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
@@ -103,22 +96,49 @@ class PagedAttention(nn.Module):
token_idx += seq_len
return padded_tensor
def generate_padding_mask(self, lengths, max_seq_len):
@staticmethod
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask
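A tiny worked example of the padding-mask helper above: with lengths [3, 5] and max_seq_len 5, positions beyond each sequence length come out False.

import torch

from colossalai.inference.modeling.layers.attention import PagedAttention

mask = PagedAttention.generate_padding_mask(torch.tensor([3, 5]), max_seq_len=5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])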
@staticmethod
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
Args: hidden_states(batch, num_key_value_heads, seqlen, head_dim)
n_rep: number of repetitions.
Output: hidden_states (batch, num_attention_heads, seqlen, head_dim)
"""
if n_rep == 1:
return hidden_states
batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
num_attention_heads = n_rep * num_key_value_heads
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
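As the docstring above says, repeat_kv is a reshape-based equivalent of torch.repeat_interleave along the kv-head dimension; a quick equivalence check with illustrative shapes:

import torch

from colossalai.inference.modeling.layers.attention import PagedAttention

x = torch.randn(2, 4, 8, 16)  # [batch, num_key_value_heads, seq_len, head_dim]
expanded = PagedAttention.repeat_kv(x, n_rep=3)  # [2, 12, 8, 16]
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=3, dim=1))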
@staticmethod
def nopad_context_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
"""
NOTE: q, k, v are already projected and rotary-embedded by the caller, aligned with the Triton version.
"""
# First, do shape verification
num_tokens, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
bsz, max_blocks_per_sequence = block_tables.shape
max_seq_len = max_blocks_per_sequence * block_size
@@ -127,80 +147,85 @@ class PagedAttention(nn.Module):
assert context_lengths.shape[0] == block_tables.shape[0]
shape = (bsz, max_seq_len, num_heads, head_size)
input_shape = shape[:2]
query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
q = PagedAttention.pad_and_reshape(
q, context_lengths, max_seq_len, num_heads, head_size
) # bsz,seqlen,num_heads,head_size
k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size)
v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
self.generate_padding_mask(context_lengths, max_seq_len)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len)
position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
position_ids = position_ids.unsqueeze(0)
q = q.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
cos, sin = self.rotary_emb(value, max_seq_len)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
# position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
# position_ids = position_ids.unsqueeze(0)
# cos, sin = self.rotary_emb(value, max_seq_len)
# query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.")
if attn_mask is not None:
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_output = torch.matmul(attn_weights, value)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (bsz, num_heads, max_seq_len, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1)
del attn_weights
return attn_output
@staticmethod
def pad_context_forward(
self,
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
# First, do shape verification
bsz, seq_len, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
shape = (bsz, seq_len, num_heads, head_size)
input_shape = shape[:2]
# Copy kv into the cache (already rotary-embedded by the caller)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device)
position_ids = position_ids.unsqueeze(0)
cos, sin = self.rotary_emb(v, seq_len)
query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
self.generate_padding_mask(context_lengths, seq_len)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
if attn_mask is not None:
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (bsz, num_heads, seq_len, head_size):
@@ -208,62 +233,70 @@ class PagedAttention(nn.Module):
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
del attn_weights
return attn_output
@staticmethod
def pad_decoding_forward(
self,
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
# First, do shape verification.
bsz, _, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
seq_len = max(lengths)
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
max_seq_len = block_tables.shape[-1] * block_size
block_tables.shape[-1] * block_size
attn_mask = AttentionMaskConverter._make_causal_mask(
q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
)
self.generate_padding_mask(lengths, max_seq_len)
cos, sin = self.rotary_emb(v, max_seq_len)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)
# cos, sin = self.rotary_emb(v, max_seq_len)
# position_ids = lengths - 1
# position_ids = position_ids.unsqueeze(1)
# query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
position_ids = lengths - 1
position_ids = position_ids.unsqueeze(1)
query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
key = convert_kvcache(key, k_cache, lengths, block_tables) # bsz, seqlen,
value = convert_kvcache(v, v_cache, lengths, block_tables)
k = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen,
v = convert_kvcache(v_cache, lengths, block_tables)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
q = q.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
if attn_mask is not None:
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_output = torch.matmul(attn_weights, value)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (bsz, num_heads, 1, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
del attn_weights
return attn_output
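The decoding path follows the same alignment: the caller applies rotary embedding at position lengths - 1 and passes the already-rotated new key/value, while pad_decoding_forward writes them into the paged cache and reads the earlier tokens back out. A minimal sketch mirroring the updated decoding test (shapes illustrative; the rotary calls assume the same transformers API as above):

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

from colossalai.inference.modeling.layers.attention import PagedAttention, copy_to_cache

bsz, num_heads, head_size, block_size = 2, 4, 8, 8
q = k = v = torch.randn(bsz, 1, num_heads, head_size)       # one new token per sequence
k_cache = torch.zeros(8, num_heads, head_size, block_size)
v_cache = torch.zeros(8, num_heads, head_size, block_size)
past_kv = torch.randn(bsz, 8, num_heads, head_size)         # 8 cached tokens per sequence
context_lengths = torch.tensor([8, 8])
lengths = context_lengths + 1                                # prompt tokens + the new token
block_tables = torch.tensor([[0, 1], [2, 3]])

copy_to_cache(past_kv, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lengths, block_tables=block_tables)

rotary_emb = LlamaRotaryEmbedding(head_size)
cos, sin = rotary_emb(v.transpose(1, 2), int(lengths.max()))
position_ids = (lengths - 1).unsqueeze(1)
q_rot, k_rot = apply_rotary_pos_emb(q.transpose(1, 2), k.transpose(1, 2), cos, sin, position_ids)

out = PagedAttention.pad_decoding_forward(
    q_rot.transpose(1, 2), k_rot.transpose(1, 2), v, k_cache, v_cache, lengths, block_tables
)  # [bsz, 1, num_heads * head_size]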
@staticmethod
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]

View File

@@ -3,7 +3,7 @@ import pytest
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.testing import spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_config_and_inference():
@@ -74,6 +74,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_config_and_inference():
spawn(run_dist, 1)

View File

@@ -11,7 +11,6 @@ from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@@ -8,7 +8,7 @@ import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
@@ -155,6 +155,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager():
spawn(run_dist, 1)

View File

@@ -3,15 +3,15 @@ import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
import colossalai
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
from colossalai.testing import spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_copy_to_cache():
key = torch.ones((2, 10, 3, 3))
key = torch.ones((2, 11, 3, 3))
key[0, 9, :, :] = 0
key[1, -2:, :, :] = 0
cache = torch.zeros(8, 3, 3, 8)
@@ -32,7 +32,8 @@ def test_convert_kvcache():
key = torch.ones(2, 1, 3, 3) + 1
lengths = torch.tensor([10, 9])
block_tables = torch.tensor([[0, 1], [2, 3]])
converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables)
copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding")
converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables)
assert converted_cache.shape == (2, 10, 3, 3)
@@ -40,7 +41,7 @@ def test_context_attention():
"""
test config: head_num = 4, head_size = 4
"""
attn = PagedAttention(4, 4)
attn = PagedAttention()
q = k = v = torch.randn(8, 4, 4)
k_cache = torch.empty(8, 4, 4, 8)
v_cache = torch.empty(8, 4, 4, 8)
@@ -61,48 +62,72 @@ def test_context_attention():
# test accuracy with LlamaAttention
hidden_states = torch.randn(1, 8, 16)
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4)
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4)
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4)
pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2)
position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
position_ids = position_ids.unsqueeze(0)
cos, sin = transformer_attn.rotary_emb(proj_v, 8)
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
pad_attn_output = attn.pad_context_forward(
proj_q.transpose(1, 2),
proj_k.transpose(1, 2),
proj_v.transpose(1, 2),
k_cache,
v_cache,
context_lengths,
block_tables,
)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
attn_mask = AttentionMaskConverter._make_causal_mask(
hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0
)
attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8)
attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask)
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3)
def test_decoding_attention():
# test the pipeline of decoding attention
attn = PagedAttention(4, 4)
q = k = v = torch.randn(2, 1, 4, 4)
k_cache = torch.empty(8, 4, 4, 8)
v_cache = torch.empty(8, 4, 4, 8)
past_kv = torch.randn(2, 8, 4, 4)
attn = PagedAttention()
q = k = v = torch.randn(2, 1, 4, 8)
k_cache = torch.empty(8, 4, 8, 8)
v_cache = torch.empty(8, 4, 8, 8)
past_kv = torch.randn(2, 8, 4, 8)
context_lenghths = torch.tensor([8, 8])
lengths = context_lenghths + 1
block_tables = torch.tensor([[0, 1], [2, 3]])
copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables)
copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables)
attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables)
# test decoding accuracy, past_kv is reused
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16)
config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32)
transformer_attn = LlamaAttention(config)
transformer_attn.layer_idx = 0
transformer_attn.training = False
hidden_states = torch.randn(2, 1, 16)
proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4)
proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4)
proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4)
hidden_states = torch.randn(2, 1, 32)
proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2)
cos, sin = transformer_attn.rotary_emb(proj_v, 16)
position_ids = lengths - 1
position_ids = position_ids.unsqueeze(1) # NOTE: this may be wrong
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2)
llama_past_kv = DynamicCache()
llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0)
# past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim
pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables)
attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8)
pad_attn_output = attn.pad_decoding_forward(
proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables
)
attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2)
pad_attn_output = transformer_attn.o_proj(pad_attn_output)
position_ids = context_lenghths.unsqueeze(1)
attn_output, _, _ = transformer_attn.forward(
@@ -112,9 +137,9 @@ def test_decoding_attention():
def check_attention_layer():
# test_copy_to_cache()
# test_convert_kvcache()
# test_context_attention()
test_copy_to_cache()
test_convert_kvcache()
test_context_attention()
test_decoding_attention()
@@ -124,6 +149,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_attention_layer():
spawn(run_dist, 1)

View File

@@ -6,7 +6,7 @@ import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.request_handler import RequestHandler, RunningList
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.testing import spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_running_list():
@@ -78,6 +78,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_running_list_and_request_handler():
spawn(run_dist, 1)