diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py index 4619e8c45..8f6d6b569 100644 --- a/colossalai/inference/modeling/layers/attention.py +++ b/colossalai/inference/modeling/layers/attention.py @@ -1,11 +1,9 @@ import math -from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): @@ -13,12 +11,12 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): Func: copy key/value into key/value cache. Args: key/value(source): shape [bsz,seq_len,num_heads,head_size] - cache: shape [num_blocks, num_heads, head_size, block_size] + cache: shape [num_blocks, num_kv_heads, head_size, block_size] lengths: key/value lengths block_tables """ num_blocks, num_heads, head_size, block_size = cache.shape - bsz, max_seq_len = block_tables.shape + bsz, max_blocks_per_seq = block_tables.shape needed_blocks = (lengths + block_size - 1) // block_size if type == "prefill": @@ -42,13 +40,14 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): return cache -def convert_kvcache(cache, lengths, block_tables): +def convert_kvcache(cache, lengths, block_tables, pad_id=0): """ Func: convert key/value cache for calculation Args: cache: shape [num_blocks, num_heads, head_size, block_size] lengths: key/value length block_tables + pad_id: padded_id """ num_blocks, num_heads, head_size, block_size = cache.shape @@ -64,35 +63,29 @@ def convert_kvcache(cache, lengths, block_tables): _cache = torch.cat( ( - cache1, - cache2, + cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size), + cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1), ), dim=0, ) padding = seq_len - _cache.size(0) if padding > 0: - _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1)) + _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id) padded_cache.append(_cache) return torch.stack(padded_cache, dim=0) -class PagedAttention(nn.Module): +class PagedAttention: """ Pure Torch implementation version of paged_attention. + Holds different types of forward function and useful components. """ - def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None): - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.sliding_window = sliding_window - self._init_rope() - - def _init_rope(self): - self.rotary_emb = LlamaRotaryEmbedding(self.head_size) - - def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size): + @staticmethod + def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size): + """ + Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size] + """ bsz = len(seq_lengths) padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size) @@ -103,22 +96,49 @@ class PagedAttention(nn.Module): token_idx += seq_len return padded_tensor - def generate_padding_mask(self, lengths, max_seq_len): + @staticmethod + def generate_padding_mask(lengths, max_seq_len): range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len) padding_mask = range_tensor < lengths.unsqueeze(1) return padding_mask + @staticmethod + def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor: + """ + Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Args: hidden_states(batch, num_key_value_heads, seqlen, head_dim) + n_rep: times of repeatition. + Output: hidden_states (batch, num_attention_heads, seqlen, head_dim) + """ + if n_rep == 1: + return hidden_states + + batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape + num_attention_heads = n_rep * num_key_value_heads + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim) + + return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim) + + @staticmethod def nopad_context_forward( - self, q: torch.Tensor, # [num_tokens, num_heads, head_size] - k: torch.Tensor, + k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] v: torch.Tensor, k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] v_cache: torch.Tensor, context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] ): + """ + NOTE: q,k,v are projected and applied rotary embedding, all aligned with triton version. + """ + # Fisrt, do shape verification num_tokens, num_heads, head_size = q.shape + num_kv_heads = k.shape[-2] + + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads + block_size = k_cache.shape[-1] bsz, max_blocks_per_sequence = block_tables.shape max_seq_len = max_blocks_per_sequence * block_size @@ -127,80 +147,85 @@ class PagedAttention(nn.Module): assert context_lengths.shape[0] == block_tables.shape[0] shape = (bsz, max_seq_len, num_heads, head_size) input_shape = shape[:2] - query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) - key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) - value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) + + q = PagedAttention.pad_and_reshape( + q, context_lengths, max_seq_len, num_heads, head_size + ) # bsz,seqlen,num_heads,head_size + k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size) + v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size) + + copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables) attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0) - self.generate_padding_mask(context_lengths, max_seq_len) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len) - position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device) - position_ids = position_ids.unsqueeze(0) + q = q.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) - cos, sin = self.rotary_emb(value, max_seq_len) - query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) - - copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables) - copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables) - - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + # position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device) + # position_ids = position_ids.unsqueeze(0) + # cos, sin = self.rotary_emb(value, max_seq_len) + # query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len): raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.") if attn_mask is not None: attn_weights += attn_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless - attn_output = torch.matmul(attn_weights, value) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) if attn_output.size() != (bsz, num_heads, max_seq_len, head_size): raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.") attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1) + del attn_weights + return attn_output + @staticmethod def pad_context_forward( - self, q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] - k: torch.Tensor, + k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size] v: torch.Tensor, k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] v_cache: torch.Tensor, context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] ): + # Firt, do shape verification bsz, seq_len, num_heads, head_size = q.shape + num_kv_heads = k.shape[-2] + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads block_size = k_cache.shape[-1] assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] block_tables.shape[-1] * block_size shape = (bsz, seq_len, num_heads, head_size) input_shape = shape[:2] + + # Copy kv to memory(rotary embedded) + copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables) + q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) - position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device) - position_ids = position_ids.unsqueeze(0) - cos, sin = self.rotary_emb(v, seq_len) - query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids) - - copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables) - copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables) - - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0) - self.generate_padding_mask(context_lengths, seq_len) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len) if attn_weights.size() != (bsz, num_heads, seq_len, seq_len): raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.") if attn_mask is not None: attn_weights += attn_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - - # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) attn_output = torch.matmul(attn_weights, v) if attn_output.size() != (bsz, num_heads, seq_len, head_size): @@ -208,62 +233,70 @@ class PagedAttention(nn.Module): attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1) + del attn_weights + return attn_output + @staticmethod def pad_decoding_forward( - self, q: torch.Tensor, # [bsz, 1, num_heads, head_size] - k: torch.Tensor, + k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size] v: torch.Tensor, k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] v_cache: torch.Tensor, lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] ): + # Firt, do shape verification. bsz, _, num_heads, head_size = q.shape + + num_kv_heads = k.shape[-2] + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads block_size = k_cache.shape[-1] seq_len = max(lengths) assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] - max_seq_len = block_tables.shape[-1] * block_size + block_tables.shape[-1] * block_size + attn_mask = AttentionMaskConverter._make_causal_mask( q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1 ) - self.generate_padding_mask(lengths, max_seq_len) - cos, sin = self.rotary_emb(v, max_seq_len) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2) + # cos, sin = self.rotary_emb(v, max_seq_len) + # position_ids = lengths - 1 + # position_ids = position_ids.unsqueeze(1) + # query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2) - position_ids = lengths - 1 - position_ids = position_ids.unsqueeze(1) - - query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2) - - copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding") + copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding") copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding") - key = convert_kvcache(key, k_cache, lengths, block_tables) # bsz, seqlen, - value = convert_kvcache(v, v_cache, lengths, block_tables) + k = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen, + v = convert_kvcache(v_cache, lengths, block_tables) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) + q = q.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) if attn_weights.size() != (bsz, num_heads, 1, seq_len): raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.") if attn_mask is not None: attn_weights += attn_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless - attn_output = torch.matmul(attn_weights, value) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) if attn_output.size() != (bsz, num_heads, 1, head_size): raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.") attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1) + del attn_weights + return attn_output + @staticmethod def no_pad_decoding_forward( self, q: torch.Tensor, # [num_tokens, num_heads, head_size] diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index 7feb1cd41..a89776b6e 100755 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -3,7 +3,7 @@ import pytest import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.struct import BatchInfo, Sequence -from colossalai.testing import spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_config_and_inference(): @@ -74,6 +74,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist +@rerun_if_address_is_in_use() def test_config_and_inference(): spawn(run_dist, 1) diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 4992fdfc7..ede4fb18a 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -11,7 +11,6 @@ from colossalai.inference.config import InferenceConfig from colossalai.inference.core.engine import InferenceEngine from colossalai.testing import rerun_if_address_is_in_use, spawn - def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index 115f5f282..9f7daa9a5 100755 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -8,7 +8,7 @@ import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.kv_cache import CacheBlock, KVCacheManager from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @parameterize( @@ -155,6 +155,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist +@rerun_if_address_is_in_use() def test_cache_manager(): spawn(run_dist, 1) diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py index f3fbd7a0e..b4754fdea 100644 --- a/tests/test_infer/test_models/test_attention.py +++ b/tests/test_infer/test_models/test_attention.py @@ -3,15 +3,15 @@ import torch from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb import colossalai from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache -from colossalai.testing import spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def test_copy_to_cache(): - key = torch.ones((2, 10, 3, 3)) + key = torch.ones((2, 11, 3, 3)) key[0, 9, :, :] = 0 key[1, -2:, :, :] = 0 cache = torch.zeros(8, 3, 3, 8) @@ -32,7 +32,8 @@ def test_convert_kvcache(): key = torch.ones(2, 1, 3, 3) + 1 lengths = torch.tensor([10, 9]) block_tables = torch.tensor([[0, 1], [2, 3]]) - converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables) + copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding") + converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables) assert converted_cache.shape == (2, 10, 3, 3) @@ -40,7 +41,7 @@ def test_context_attention(): """ test config: head_num = 4, head_size = 4 """ - attn = PagedAttention(4, 4) + attn = PagedAttention() q = k = v = torch.randn(8, 4, 4) k_cache = torch.empty(8, 4, 4, 8) v_cache = torch.empty(8, 4, 4, 8) @@ -61,48 +62,72 @@ def test_context_attention(): # test accuracy with LlamaAttention hidden_states = torch.randn(1, 8, 16) - proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4) - proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4) - proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4) - pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables) - pad_attn_output = transformer_attn.o_proj(pad_attn_output) + proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device) + position_ids = position_ids.unsqueeze(0) + cos, sin = transformer_attn.rotary_emb(proj_v, 8) + proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids) + + pad_attn_output = attn.pad_context_forward( + proj_q.transpose(1, 2), + proj_k.transpose(1, 2), + proj_v.transpose(1, 2), + k_cache, + v_cache, + context_lengths, + block_tables, + ) + pad_attn_output = transformer_attn.o_proj(pad_attn_output) attn_mask = AttentionMaskConverter._make_causal_mask( hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0 ) + attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8) attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask) - assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2) + assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3) def test_decoding_attention(): # test the pipeline of decoding attention - attn = PagedAttention(4, 4) - q = k = v = torch.randn(2, 1, 4, 4) - k_cache = torch.empty(8, 4, 4, 8) - v_cache = torch.empty(8, 4, 4, 8) - past_kv = torch.randn(2, 8, 4, 4) + attn = PagedAttention() + q = k = v = torch.randn(2, 1, 4, 8) + k_cache = torch.empty(8, 4, 8, 8) + v_cache = torch.empty(8, 4, 8, 8) + past_kv = torch.randn(2, 8, 4, 8) context_lenghths = torch.tensor([8, 8]) lengths = context_lenghths + 1 block_tables = torch.tensor([[0, 1], [2, 3]]) copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables) copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables) attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables) + # test decoding accuracy, past_kv is reused - config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16) + config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32) transformer_attn = LlamaAttention(config) transformer_attn.layer_idx = 0 transformer_attn.training = False - hidden_states = torch.randn(2, 1, 16) - proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4) - proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4) - proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4) + hidden_states = torch.randn(2, 1, 32) + proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + + cos, sin = transformer_attn.rotary_emb(proj_v, 16) + position_ids = lengths - 1 + position_ids = position_ids.unsqueeze(1) # NOTE: this may be wrong + proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2) llama_past_kv = DynamicCache() llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0) # past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim - pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables) - attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8) + pad_attn_output = attn.pad_decoding_forward( + proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables + ) + attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2) + pad_attn_output = transformer_attn.o_proj(pad_attn_output) position_ids = context_lenghths.unsqueeze(1) attn_output, _, _ = transformer_attn.forward( @@ -112,9 +137,9 @@ def test_decoding_attention(): def check_attention_layer(): - # test_copy_to_cache() - # test_convert_kvcache() - # test_context_attention() + test_copy_to_cache() + test_convert_kvcache() + test_context_attention() test_decoding_attention() @@ -124,6 +149,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist +@rerun_if_address_is_in_use() def test_attention_layer(): spawn(run_dist, 1) diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py index d6c110c96..aa2cac6cb 100644 --- a/tests/test_infer/test_request_handler.py +++ b/tests/test_infer/test_request_handler.py @@ -6,7 +6,7 @@ import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.core.request_handler import RequestHandler, RunningList from colossalai.inference.struct import RequestStatus, Sequence -from colossalai.testing import spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_running_list(): @@ -78,6 +78,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist +@rerun_if_address_is_in_use() def test_running_list_and_request_handler(): spawn(run_dist, 1)