diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py new file mode 100644 index 000000000..0a9f8566e --- /dev/null +++ b/colossalai/inference/modeling/layers/attention.py @@ -0,0 +1,276 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + + +def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): + """ + Func: copy key/value into key/value cache. + + Args: key/value(source): shape [bsz,seq_len,num_heads,head_size] + cache: shape [num_blocks, num_heads, head_size, block_size] + lengths: key/value lengths + block_tables + """ + num_blocks, num_heads, head_size, block_size = cache.shape + bsz, max_seq_len = block_tables.shape + needed_blocks = (lengths + block_size - 1) // block_size + + if type == "prefill": + for i in range(bsz): + seq_len = lengths[i] + block_num = needed_blocks[i] + token_id = 0 + for block_idx in range(block_num - 1): + cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0) + token_id += block_size + cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0) + elif type == "decoding": + assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding." + source = source.squeeze(1) + slot_idx = (lengths + block_size - 1) % block_size + for i in range(bsz): + cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1) + + return cache + + +def convert_kvcache(source, cache, lengths, block_tables): + """ + Func: convert key/value cache for calculation + + Args: key/value(source): shape [bsz, 1, num_heads, head_size] + cache: shape [num_blocks, num_heads, head_size, block_size] + lengths: key/value length + block_tables + """ + num_blocks, num_heads, head_size, block_size = cache.shape + + needed_blocks = (lengths + block_size - 1) // block_size + num_remaing_tokens = (lengths - 1) % block_size + bsz = block_tables.shape[0] + seq_len = max(lengths) + padded_cache = [] + for i in range(bsz): + _cache = torch.cat( + ( + cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size), + cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0), + ), + dim=0, + ) + concat_cache = torch.cat((_cache, source[i]), dim=0) + padding = seq_len - concat_cache.size(0) + if padding > 0: + concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1)) + padded_cache.append(concat_cache) + + return torch.stack(padded_cache, dim=0) + + +class PagedAttention(nn.Module): + """ + Pure Torch implementation version of paged_attention. + """ + + def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.sliding_window = sliding_window + self._init_rope() + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding(self.head_size) + + def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size): + bsz = len(seq_lengths) + padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size) + + token_idx = 0 + for i, seq_len in enumerate(seq_lengths): + seq_tensor = tensor[token_idx : token_idx + seq_len] + padded_tensor[i, :seq_len, :, :] = seq_tensor + token_idx += seq_len + return padded_tensor + + def generate_padding_mask(self, lengths, max_seq_len): + range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len) + padding_mask = range_tensor < lengths.unsqueeze(1) + return padding_mask + + def nopad_context_forward( + self, + q: torch.Tensor, # [num_tokens, num_heads, head_size] + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + v_cache: torch.Tensor, + context_lengths: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + num_tokens, num_heads, head_size = q.shape + block_size = k_cache.shape[-1] + bsz, max_blocks_per_sequence = block_tables.shape + max_seq_len = max_blocks_per_sequence * block_size + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.shape[0] == k.shape[0] == v.shape[0] + assert context_lengths.shape[0] == block_tables.shape[0] + shape = (bsz, max_seq_len, num_heads, head_size) + input_shape = shape[:2] + query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) + key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) + value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2) + + attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0) + self.generate_padding_mask(context_lengths, max_seq_len) + + position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device) + position_ids = position_ids.unsqueeze(0) + + cos, sin = self.rotary_emb(value, max_seq_len) + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + + if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.") + + if attn_mask is not None: + attn_weights += attn_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless + attn_output = torch.matmul(attn_weights, value) + + if attn_output.size() != (bsz, num_heads, max_seq_len, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.") + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1) + + return attn_output + + def pad_context_forward( + self, + q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + v_cache: torch.Tensor, + context_lengths: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + bsz, seq_len, num_heads, head_size = q.shape + block_size = k_cache.shape[-1] + assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] + block_tables.shape[-1] * block_size + shape = (bsz, seq_len, num_heads, head_size) + input_shape = shape[:2] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device) + position_ids = position_ids.unsqueeze(0) + cos, sin = self.rotary_emb(v, seq_len) + query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0) + self.generate_padding_mask(context_lengths, seq_len) + + if attn_weights.size() != (bsz, num_heads, seq_len, seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.") + if attn_mask is not None: + attn_weights += attn_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != (bsz, num_heads, seq_len, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.") + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1) + + return attn_output + + def pad_decoding_forward( + self, + q: torch.Tensor, # [bsz, 1, num_heads, head_size] + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + v_cache: torch.Tensor, + lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + bsz, _, num_heads, head_size = q.shape + block_size = k_cache.shape[-1] + seq_len = max(lengths) + + assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] + max_seq_len = block_tables.shape[-1] * block_size + attn_mask = AttentionMaskConverter._make_causal_mask( + q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1 + ) + self.generate_padding_mask(lengths, max_seq_len) + cos, sin = self.rotary_emb(v, max_seq_len) + + position_ids = lengths - 1 + position_ids = position_ids.unsqueeze(1) + + query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2) + + copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding") + copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding") + + key = convert_kvcache(key, k_cache, lengths, block_tables) # bsz, seqlen, + value = convert_kvcache(v, v_cache, lengths, block_tables) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size) + if attn_weights.size() != (bsz, num_heads, 1, seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.") + + if attn_mask is not None: + attn_weights += attn_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + # attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless + attn_output = torch.matmul(attn_weights, value) + + if attn_output.size() != (bsz, num_heads, 1, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.") + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1) + + return attn_output + + def no_pad_decoding_forward( + self, + q: torch.Tensor, # [num_tokens, num_heads, head_size] + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + v_cache: torch.Tensor, + lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + return self.pad_decoding_forward( + q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), k_cache, v_cache, lengths, block_tables + ) diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py new file mode 100644 index 000000000..f3fbd7a0e --- /dev/null +++ b/tests/test_infer/test_models/test_attention.py @@ -0,0 +1,132 @@ +import pytest +import torch +from transformers.cache_utils import DynamicCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention + +import colossalai +from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache +from colossalai.testing import spawn + + +def test_copy_to_cache(): + key = torch.ones((2, 10, 3, 3)) + key[0, 9, :, :] = 0 + key[1, -2:, :, :] = 0 + cache = torch.zeros(8, 3, 3, 8) + block_tables = torch.tensor([[0, 1], [2, 3]]) + lengths = torch.tensor([9, 8]) + cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill") + assert cache[1, 0, 0, 0] == 1 + assert cache[3, 0, 0, 0] == 0 + + decoding_key = torch.ones((2, 1, 3, 3)) + cache = copy_to_cache(decoding_key, cache=cache, lengths=lengths + 1, block_tables=block_tables, type="decoding") + assert cache[1, 0, 0, 1] == 1 + assert cache[3, 0, 0, 0] == 1 + + +def test_convert_kvcache(): + cache = torch.ones(8, 3, 3, 8) + key = torch.ones(2, 1, 3, 3) + 1 + lengths = torch.tensor([10, 9]) + block_tables = torch.tensor([[0, 1], [2, 3]]) + converted_cache = convert_kvcache(key, cache=cache, lengths=lengths, block_tables=block_tables) + assert converted_cache.shape == (2, 10, 3, 3) + + +def test_context_attention(): + """ + test config: head_num = 4, head_size = 4 + """ + attn = PagedAttention(4, 4) + q = k = v = torch.randn(8, 4, 4) + k_cache = torch.empty(8, 4, 4, 8) + v_cache = torch.empty(8, 4, 4, 8) + context_lengths = torch.tensor( + [ + 8, + ] + ) + block_tables = torch.tensor([[0, 1]]) + attn.nopad_context_forward(q, k, v, k_cache, v_cache, context_lengths, block_tables) + # test padded q/k/v + pad_q = pad_k = pad_v = q.unsqueeze(0) + attn.pad_context_forward(pad_q, pad_k, pad_v, k_cache, v_cache, context_lengths, block_tables) + + config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16) + transformer_attn = LlamaAttention(config) + transformer_attn.training = False + + # test accuracy with LlamaAttention + hidden_states = torch.randn(1, 8, 16) + proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4) + proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4) + proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4) + pad_attn_output = attn.pad_context_forward(proj_q, proj_k, proj_v, k_cache, v_cache, context_lengths, block_tables) + pad_attn_output = transformer_attn.o_proj(pad_attn_output) + + attn_mask = AttentionMaskConverter._make_causal_mask( + hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0 + ) + attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask) + assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2) + + +def test_decoding_attention(): + # test the pipeline of decoding attention + attn = PagedAttention(4, 4) + q = k = v = torch.randn(2, 1, 4, 4) + k_cache = torch.empty(8, 4, 4, 8) + v_cache = torch.empty(8, 4, 4, 8) + past_kv = torch.randn(2, 8, 4, 4) + context_lenghths = torch.tensor([8, 8]) + lengths = context_lenghths + 1 + block_tables = torch.tensor([[0, 1], [2, 3]]) + copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables) + copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables) + attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables) + # test decoding accuracy, past_kv is reused + config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16) + transformer_attn = LlamaAttention(config) + transformer_attn.layer_idx = 0 + transformer_attn.training = False + hidden_states = torch.randn(2, 1, 16) + proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 4) + proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 4) + proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 4) + + llama_past_kv = DynamicCache() + llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0) + + # past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim + pad_attn_output = attn.pad_decoding_forward(proj_q, proj_k, proj_v, k_cache, v_cache, lengths, block_tables) + attn_mask = AttentionMaskConverter._make_causal_mask(proj_q.shape[:2], q.dtype, q.device, past_key_values_length=8) + pad_attn_output = transformer_attn.o_proj(pad_attn_output) + position_ids = context_lenghths.unsqueeze(1) + attn_output, _, _ = transformer_attn.forward( + hidden_states, past_key_value=llama_past_kv, position_ids=position_ids, attention_mask=attn_mask + ) + assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2) + + +def check_attention_layer(): + # test_copy_to_cache() + # test_convert_kvcache() + # test_context_attention() + test_decoding_attention() + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_attention_layer() + + +@pytest.mark.dist +def test_attention_layer(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_attention_layer()