diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py
index 8f6d6b569..d95504903 100644
--- a/colossalai/inference/modeling/layers/attention.py
+++ b/colossalai/inference/modeling/layers/attention.py
@@ -196,6 +196,7 @@ class PagedAttention:
         v_cache: torch.Tensor,
         context_lengths: torch.Tensor,  # [num_seqs]
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
+        attn_mask: torch.Tensor = None,  # [bsz, input_lengths + output_lengths]
     ):
         # Firt, do shape verification
         bsz, seq_len, num_heads, head_size = q.shape
@@ -205,8 +206,6 @@ class PagedAttention:
         block_size = k_cache.shape[-1]
         assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
         block_tables.shape[-1] * block_size
-        shape = (bsz, seq_len, num_heads, head_size)
-        input_shape = shape[:2]

         # Copy kv to memory(rotary embedded)
         copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
@@ -217,8 +216,16 @@ class PagedAttention:
         v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
-        attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
-        attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)
+
+        if attn_mask is not None:
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
+
+        attn_mask = AttentionMaskConverter._make_causal_mask(
+            (bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
+        )
+
+        if padding_mask is not None:
+            attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)

         if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
@@ -246,27 +253,17 @@ class PagedAttention:
         v_cache: torch.Tensor,
         lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
+        attn_mask: torch.Tensor = None,  # [bsz, input_lengths + output_lengths]
     ):
         # Firt, do shape verification.
-        bsz, _, num_heads, head_size = q.shape
+        bsz, q_length, num_heads, head_size = q.shape
         num_kv_heads = k.shape[-2]
         assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
         num_kv_groups = num_heads // num_kv_heads
-        block_size = k_cache.shape[-1]
         seq_len = max(lengths)
         assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
-        block_tables.shape[-1] * block_size
-
-        attn_mask = AttentionMaskConverter._make_causal_mask(
-            q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
-        )
-        attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)
-        # cos, sin = self.rotary_emb(v, max_seq_len)
-        # position_ids = lengths - 1
-        # position_ids = position_ids.unsqueeze(1)
-        # query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)

         copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
         copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
@@ -283,8 +280,16 @@ class PagedAttention:
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")

         if attn_mask is not None:
-            attn_weights += attn_mask
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
+
+        attn_mask = AttentionMaskConverter._make_causal_mask(
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
+        )
+
+        if padding_mask is not None:
+            attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
+
+        attn_weights += attn_mask

         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
         attn_output = torch.matmul(attn_weights, v)
diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py
index 44c07b7c6..d41267138 100644
--- a/colossalai/inference/modeling/models/llama.py
+++ b/colossalai/inference/modeling/models/llama.py
@@ -1,10 +1,7 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
-import math
 from typing import List, Optional, Tuple

 import torch
-import torch.nn as nn
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -13,10 +10,10 @@ from transformers.models.llama.modeling_llama import (
     repeat_kv,
 )

-from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
+from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo

-from flash_attn.bert_padding import index_first_axis  # noqa
+from flash_attn.bert_padding import index_first_axis, pad_input  # noqa


 def rotate_half(x):
@@ -163,11 +160,11 @@ def llama_attn_forward(
     value_states = value_states.transpose(1, 2)

     if is_prompts:
-        attn_output = pad_context_forward(
+        attn_output = PagedAttention.pad_context_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )
     else:
-        attn_output = pad_decoding_forward(
+        attn_output = PagedAttention.pad_decoding_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )

@@ -182,118 +179,3 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
-
-
-def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
-    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
-    q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
-    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
-    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
-    return (q, k, v, indices, seqlens)
-
-
-def pad_decoding_forward(
-    query: torch.Tensor,  # [bsz, 1, num_heads, head_size]
-    key: torch.Tensor,
-    value: torch.Tensor,
-    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
-    v_cache: torch.Tensor,
-    lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
-    block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
-    attn_mask: torch.Tensor = None,
-):
-    bsz, query_length, num_heads, head_size = query.shape
-    seq_len = max(lengths)
-
-    copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
-    copy_to_cache(value, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
-
-    key = convert_kvcache(k_cache, lengths, block_tables)  # bsz, seqlen,
-    value = convert_kvcache(v_cache, lengths, block_tables)
-
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    value = value.transpose(1, 2)
-
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
-    if attn_weights.size() != (bsz, num_heads, 1, seq_len):
-        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
-
-    if attn_mask is not None:
-        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, query.dtype, query_length)
-
-    attn_mask = AttentionMaskConverter._make_causal_mask(
-        (bsz, query_length), query.dtype, query.device, past_key_values_length=seq_len - query_length
-    )
-
-    if padding_mask is not None:
-        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
-
-    attn_weights += attn_mask
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_output = torch.matmul(attn_weights, value)
-
-    if attn_output.size() != (bsz, num_heads, 1, head_size):
-        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
-    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
-
-    return attn_output
-
-
-def pad_context_forward(
-    q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
-    k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
-    v: torch.Tensor,
-    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
-    v_cache: torch.Tensor,
-    context_lengths: torch.Tensor,  # [num_seqs]
-    block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
-    attn_mask: torch.Tensor = None,
-):
-    # Firt, do shape verification
-    bsz, seq_len, num_heads, head_size = q.shape
-    num_kv_heads = k.shape[-2]
-    assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
-    num_kv_groups = num_heads // num_kv_heads
-    block_size = k_cache.shape[-1]
-    assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
-    block_tables.shape[-1] * block_size
-
-    # Copy kv to memory(rotary embedded)
-    copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
-    copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
-
-    q = q.transpose(1, 2)
-    k = repeat_kv(k.transpose(1, 2), num_kv_groups)
-    v = repeat_kv(v.transpose(1, 2), num_kv_groups)
-
-    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
-
-    if attn_mask is not None:
-        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
-
-    attn_mask = AttentionMaskConverter._make_causal_mask(
-        (bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
-    )
-
-    if padding_mask is not None:
-        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
-
-    if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
-        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
-    if attn_mask is not None:
-        attn_weights += attn_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-    attn_output = torch.matmul(attn_weights, v)
-
-    if attn_output.size() != (bsz, num_heads, seq_len, head_size):
-        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")
-
-    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
-
-    del attn_weights
-
-    return attn_output
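
For reference, a minimal standalone sketch of the mask construction that `PagedAttention.pad_decoding_forward` performs after this patch: the padding mask from `attn_mask` is expanded and re-applied on top of a causal mask built with `AttentionMaskConverter`, as imported in attention.py (assumes transformers >= 4.34; the `build_decoding_mask` helper name is illustrative only and not part of the patch).

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


def build_decoding_mask(attn_mask: torch.Tensor, q_length: int, seq_len: int,
                        dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # attn_mask: [bsz, seq_len] padding mask (1 = real token, 0 = padding).
    bsz = attn_mask.shape[0]
    # Expand the 2D padding mask to [bsz, 1, q_length, seq_len]; padded slots become finfo.min.
    padding_mask = AttentionMaskConverter._expand_mask(attn_mask, dtype, q_length)
    # Causal mask for the q_length new tokens, offset by the already-cached prefix.
    causal_mask = AttentionMaskConverter._make_causal_mask(
        (bsz, q_length), dtype, device, past_key_values_length=seq_len - q_length
    )
    # Re-assert padding on top of the causal mask, as the patched pad_decoding_forward does.
    return causal_mask.masked_fill(padding_mask.bool(), torch.finfo(dtype).min)
```

With `q_length == 1` this yields the `[bsz, 1, 1, seq_len]` additive mask that is added to `attn_weights` before the softmax in the decoding path.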