fix bugs related to processing padding mask

pull/5258/head
yuehuayingxueluo 2024-01-09 14:29:45 +08:00 committed by FrankLeeeee
parent e545a871b8
commit 2a73e828eb
2 changed files with 26 additions and 139 deletions

View File

@@ -196,6 +196,7 @@ class PagedAttention:
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None, # [bsz, input_lengths + output_lengths]
):
# First, do shape verification
bsz, seq_len, num_heads, head_size = q.shape
@@ -205,8 +206,6 @@ class PagedAttention:
block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
shape = (bsz, seq_len, num_heads, head_size)
input_shape = shape[:2]
# Copy kv to memory (rotary embedded)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
@@ -217,8 +216,16 @@ class PagedAttention:
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)
padding_mask = None
if attn_mask is not None:
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
)
if padding_mask is not None:
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
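
The core of this hunk is replacing the old generate_padding_mask path with a causal mask that is then overwritten with dtype-min at padded key positions. Below is a minimal plain-torch sketch of that combination; the function name and shapes are illustrative assumptions, while the actual commit goes through transformers' AttentionMaskConverter.

import torch

def build_prefill_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: [bsz, seq_len] with 1 for real tokens and 0 for padding.
    # Returns an additive mask of shape [bsz, 1, seq_len, seq_len]: 0 where a
    # query may attend, dtype-min where the key is in the future or is padding.
    bsz, seq_len = attention_mask.shape
    min_val = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len)
    padded_keys = attention_mask[:, None, None, :] == 0  # [bsz, 1, 1, seq_len]
    return causal.masked_fill(padded_keys, min_val)

# toy check: the second sequence is left-padded by one token
print(build_prefill_mask(torch.tensor([[1, 1, 1], [0, 1, 1]])).shape)  # [2, 1, 3, 3]
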
@@ -246,27 +253,17 @@ class PagedAttention:
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None, # [bsz, input_lengths + output_lengths]
):
# First, do shape verification.
bsz, _, num_heads, head_size = q.shape
bsz, q_length, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
seq_len = max(lengths)
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
attn_mask = AttentionMaskConverter._make_causal_mask(
q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)
# cos, sin = self.rotary_emb(v, max_seq_len)
# position_ids = lengths - 1
# position_ids = position_ids.unsqueeze(1)
# query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
@@ -283,8 +280,16 @@ class PagedAttention:
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
padding_mask = None
if attn_mask is not None:
attn_weights += attn_mask
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
)
if padding_mask is not None:
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
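
In the decoding hunk the query length is 1, so the causal mask built with past_key_values_length = seq_len - q_length is all zeros (the new token may attend to every cached position); only the padding fill actually blocks anything. A plain-torch stand-in for what the new lines produce in that case, with an assumed helper name:

import torch

def build_decode_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: [bsz, seq_len] over all cached keys (prompt + generated tokens).
    # With a single-token query the causal constraint is vacuous, so the additive
    # mask only has to silence padded key slots.
    bsz, seq_len = attention_mask.shape
    mask = torch.zeros(bsz, 1, 1, seq_len, dtype=dtype)
    return mask.masked_fill(attention_mask[:, None, None, :] == 0, torch.finfo(dtype).min)

mask = build_decode_mask(torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]]))
print(mask[1, 0, 0])  # the two padded keys get dtype-min, the real keys get 0
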

View File

@@ -1,10 +1,7 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
@@ -13,10 +10,10 @@ from transformers.models.llama.modeling_llama import (
repeat_kv,
)
from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from flash_attn.bert_padding import index_first_axis # noqa
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
def rotate_half(x):
@@ -163,11 +160,11 @@ def llama_attn_forward(
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = pad_context_forward(
attn_output = PagedAttention.pad_context_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
attn_output = pad_decoding_forward(
attn_output = PagedAttention.pad_decoding_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
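
Both branches receive the same attention_mask of shape [bsz, input_lengths + output_lengths]. A hypothetical illustration (not part of the commit) of how such a mask evolves across steps: prefill covers the possibly left-padded prompt, and each decoding step appends one column of ones for the newly generated token.

import torch

prompt_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])                      # prefill step
step1 = torch.cat([prompt_mask, torch.ones(2, 1, dtype=torch.long)], dim=-1)
step2 = torch.cat([step1, torch.ones(2, 1, dtype=torch.long)], dim=-1)
print(step2)  # tensor([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])
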
@@ -182,118 +179,3 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
return position_ids
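
A small usage check of the generate_padding_position_id helper kept above: real tokens get consecutive positions starting at 0, padded slots are pinned to a dummy position id of 1.

import torch

def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
    # Same logic as the helper above.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

mask = torch.tensor([[1, 1, 1, 1],
                     [0, 0, 1, 1]])  # second row is left-padded by two tokens
print(generate_padding_position_id(mask))
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
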
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
return (q, k, v, indices, seqlens)
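
For reference, a plain-torch illustration of what the removed unpading_input computed: gather only the non-padded token positions ahead of a varlen attention kernel. flash_attn's index_first_axis(x, indices) acts like x[indices] for this purpose; the tensor sizes below are made up for the example.

import torch

attention_mask = torch.tensor([[1, 1, 1],
                               [0, 1, 1]])
bsz, seq_len = attention_mask.shape
q = torch.randn(bsz, seq_len, 2, 4)  # [bsz, seq_len, num_heads, head_dim]

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                  # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
q_unpadded = q.reshape(bsz * seq_len, 2, 4)[indices]                     # shape [5, 2, 4]
print(seqlens.tolist(), indices.tolist(), list(q_unpadded.shape))
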
def pad_decoding_forward(
query: torch.Tensor, # [bsz, 1, num_heads, head_size]
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None,
):
bsz, query_length, num_heads, head_size = query.shape
seq_len = max(lengths)
copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(value, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
key = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen,
value = convert_kvcache(v_cache, lengths, block_tables)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
if attn_mask is not None:
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, query.dtype, query_length)
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, query_length), query.dtype, query.device, past_key_values_length=seq_len - query_length
)
if padding_mask is not None:
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
if attn_output.size() != (bsz, num_heads, 1, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
return attn_output
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None,
):
# First, do shape verification
bsz, seq_len, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
# Copy kv to memory (rotary embedded)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
q = q.transpose(1, 2)
k = repeat_kv(k.transpose(1, 2), num_kv_groups)
v = repeat_kv(v.transpose(1, 2), num_kv_groups)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
if attn_mask is not None:
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
)
if padding_mask is not None:
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
if attn_mask is not None:
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (bsz, num_heads, seq_len, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
del attn_weights
return attn_output
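
Finally, a self-contained toy check of the prefill masking behaviour that PagedAttention.pad_context_forward now implements: after adding the combined causal + padding mask, real query positions should pay ~0 attention probability to padded keys. This is a sketch with made-up sizes, not the ColossalAI code path.

import math
import torch

bsz, seq_len, num_heads, head_size = 2, 3, 2, 4
q = torch.randn(bsz, num_heads, seq_len, head_size)
k = torch.randn(bsz, num_heads, seq_len, head_size)
attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])  # second sequence left-padded
min_val = torch.finfo(q.dtype).min

# Combined causal + padding mask, [bsz, seq_len, seq_len].
causal = torch.triu(torch.full((seq_len, seq_len), min_val), diagonal=1)
causal = causal[None, :, :].expand(bsz, seq_len, seq_len)
mask = causal.masked_fill(attention_mask[:, None, :] == 0, min_val)

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
attn_weights = attn_weights + mask[:, None, :, :]
probs = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
print(probs[1, :, 1:, 0])  # attention from the real tokens to the padded key: ~0
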