
215 lines
9.7 KiB

from typing import Tuple
import torch
from torch.nn import functional as F
# This function is adapted from src/transformers/models/llama/
# in huggingface transformers repository
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head dimension.

    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to
    (bsz, num_attention_heads, seq_len, head_dim).

    Args:
        hidden_states: 4-D tensor (bsz, num_key_value_heads, seq_len, head_dim).
        n_rep: number of times each key/value head is repeated.

    Returns:
        The input unchanged (same object) when ``n_rep == 1``; otherwise a new
        tensor of shape (bsz, num_key_value_heads * n_rep, seq_len, head_dim).
    """
    if n_rep == 1:
        return hidden_states
    bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    # expand() creates a broadcast view (no copy); reshape() then materializes it.
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"):
    """Build an additive left-padding mask for ``bsz`` sequences.

    For sequence ``i`` with length ``kv_lengths[i]``, the first
    ``kv_seq_len - kv_lengths[i]`` key positions are set to ``-inf`` and the
    remaining positions to 0. Adding the mask to attention scores therefore
    hides the left padding.

    Args:
        kv_lengths: per-sequence key/value lengths, indexed by batch position;
            only the first ``bsz`` entries are used.
        bsz: batch size.
        kv_seq_len: padded key/value length (last dimension of the mask).
        device: device on which the mask is created.

    Returns:
        float32 tensor of shape [bsz, 1, 1, kv_seq_len].

    Raises:
        AssertionError: if any of the used lengths exceeds ``kv_seq_len``.
    """
    # Vectorized: a per-row `.item()` loop would force one host/device sync per sequence.
    lengths = kv_lengths[:bsz].to(device=device)
    assert bool((lengths <= kv_seq_len).all())
    num_pad = (kv_seq_len - lengths).view(bsz, 1, 1, 1)
    positions = torch.arange(kv_seq_len, device=device).view(1, 1, 1, kv_seq_len)
    padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device)
    padding_mask.masked_fill_(positions < num_pad, float("-inf"))
    return padding_mask
# Attention calculation adapted from HuggingFace transformers repository
# src/transformers/models/llama/
def torch_attn_ref(
    q: torch.Tensor,  # [bsz, num_heads, q_len, head_dim]
    k: torch.Tensor,  # [bsz, num_kv_heads, kv_seq_len, head_dim]
    v: torch.Tensor,  # [bsz, num_kv_heads, kv_seq_len, head_dim]
    attention_mask: torch.Tensor,  # [bsz, 1, seq_len, kv_seq_len]
    bsz: int,
    seq_len: int,
    kv_seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    """Eager PyTorch reference for scaled dot-product attention with an additive mask.

    K/V heads are repeated via ``repeat_kv`` to support GQA/MQA. Scores are
    scaled by ``1 / sqrt(head_dim)``, the additive ``attention_mask`` is added,
    and softmax is applied in float32.

    Returns:
        Tensor of shape [bsz, seq_len, num_heads, head_dim]. Dimension 1 is
        squeezed, so it becomes [bsz, num_heads, head_dim] when ``seq_len == 1``.

    Raises:
        AssertionError: on head_dim / head-count / score-shape mismatches.
        ValueError: if the mask or output has an unexpected size.
    """
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim

    # repeat kv for GQA and MQA
    # k/v won't change if kv_group_num is 1
    assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
    kv_group_num = num_heads // num_kv_heads
    k = repeat_kv(k, kv_group_num)
    v = repeat_kv(v, kv_group_num)

    qk = torch.matmul(q, k.transpose(2, 3))
    attn_scores = qk / (head_dim**0.5)
    assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"

    # for left-side padding
    if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
        )
    attn_scores = attn_scores + attention_mask

    # Normalize in float32 for numerical stability, then cast back to the query dtype.
    attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(dtype=q.dtype)
    out = torch.matmul(attn_weights, v)
    if out.size() != (bsz, num_heads, seq_len, head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
        )
    out = out.transpose(1, 2).contiguous()
    # squeeze(1) only removes the seq dimension when it has size 1 (single-token decoding).
    out = out.squeeze(1)
    return out
def mock_alloc_block_table_and_kvcache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.

    Blocks are handed out contiguously, starting at physical block 0, in sequence order.
    ``k``/``v`` hold the unpadded tokens of all sequences back to back, indexed as
    [num_tokens, num_kv_heads, head_dim]. Each cache block is written as
    [num_kv_heads, head_dim, block_size], with tokens along the last axis.

    Returns:
        int32 block table of shape [num_seqs, max_num_blocks_per_seq]; unused entries are -1.
    """
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for seq_idx, seq_len in enumerate(context_lengths.tolist()):
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil(seq_len / block_size)
        block_tables[seq_idx, :num_blocks] = torch.arange(block_id, block_id + num_blocks, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for local_idx in range(num_blocks):
            # Every block is full except possibly the last block of the sequence.
            if local_idx == num_blocks - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            token_end = num_tokens_processed + allocated_locs
            k_block = k[num_tokens_processed:token_end, :, :].permute(1, 2, 0)
            v_block = v[num_tokens_processed:token_end, :, :].permute(1, 2, 0)
            k_cache[block_id, :, :, :allocated_locs] = k_block
            v_cache[block_id, :, :, :allocated_locs] = v_block
            num_tokens_processed = token_end
            block_id += 1
    return block_tables
def mock_alloc_block_table_and_kvcache_v2(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.

    Blocks are handed out contiguously, starting at physical block 0, in sequence order.
    ``k``/``v`` hold the unpadded tokens of all sequences back to back, indexed as
    [num_tokens, num_kv_heads, head_dim]. Each cache block is written as
    [num_kv_heads, block_size, head_dim], with tokens along the third axis.

    Returns:
        int32 block table of shape [num_seqs, max_num_blocks_per_seq]; unused entries are -1.
    """
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for seq_idx, seq_len in enumerate(context_lengths.tolist()):
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil(seq_len / block_size)
        block_tables[seq_idx, :num_blocks] = torch.arange(block_id, block_id + num_blocks, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for local_idx in range(num_blocks):
            # Every block is full except possibly the last block of the sequence.
            if local_idx == num_blocks - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            token_end = num_tokens_processed + allocated_locs
            k_block = k[num_tokens_processed:token_end, :, :].permute(1, 0, 2)
            v_block = v[num_tokens_processed:token_end, :, :].permute(1, 0, 2)
            k_cache[block_id, :, :allocated_locs, :] = k_block
            v_cache[block_id, :, :allocated_locs, :] = v_block
            num_tokens_processed = token_end
            block_id += 1
    return block_tables
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    """Allocate room for one more token per sequence, updating ``block_tables`` in place.

    A sequence whose current length is a multiple of ``block_size`` needs a new
    block. New physical block ids are assigned after the current maximum id in
    ``block_tables``, in batch order.
    It won't change provided context_lengths.

    NOTE It assumes all the blocks preceding the max block id have been allocated.
    """
    # Consider max_block_id as the last physical block allocated
    max_block_id = int(torch.max(block_tables).item())
    # the indices on each block table representing the cache block to be allocated one more token
    alloc_local_block_indices = context_lengths // block_size
    # offsets of the token to be allocated on the target block (for each seq)
    alloc_block_offsets = context_lengths % block_size

    require_new_block = alloc_block_offsets == 0
    num_new_blocks = int(require_new_block.sum().item())
    if num_new_blocks == 0:
        return
    new_block_ids = torch.arange(
        max_block_id + 1,
        max_block_id + 1 + num_new_blocks,
        dtype=block_tables.dtype,
        device=block_tables.device,
    )
    new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
    block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
def generate_caches_and_block_tables(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    """Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths.

    Args:
        k_unpad / v_unpad: [num_total_tokens, num_kv_heads, head_dim]
        kv_lengths: per-sequence token counts, in the same order as the packed tokens.
        bsz, max_num_blocks_per_seq, block_size: sizes of the mocked block pool.
        dtype, device: dtype/device of the allocated caches.

    Returns:
        (k_cache, v_cache, block_tables). Each cache has shape
        [bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size].
        block_tables is whatever ``mock_alloc_block_table_and_kvcache`` returns.
    """
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v2(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    """Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths.

    Same as ``generate_caches_and_block_tables`` but with a token-major block layout.

    Args:
        k_unpad / v_unpad: [num_total_tokens, num_kv_heads, head_dim]
        kv_lengths: per-sequence token counts, in the same order as the packed tokens.
        bsz, max_num_blocks_per_seq, block_size: sizes of the mocked block pool.
        dtype, device: dtype/device of the allocated caches.

    Returns:
        (k_cache, v_cache, block_tables). Each cache has shape
        [bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim].
        block_tables is whatever ``mock_alloc_block_table_and_kvcache_v2`` returns.
    """
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache_v2(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables
def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
    """Rebuild (batched) k/v with left-side padding, for use by torch attention.

    Args:
        k_unpad: packed tokens of all sequences, [num_total_tokens, num_kv_heads, head_dim].
        kv_seq_lengths: per-sequence token counts, in packing order.
        bsz: batch size (first dim of the output).
        max_seq_len: padded sequence length.

    Returns:
        k/v padded [bsz, num_kv_heads, max_seq_len, head_dim], zero-filled on the left.
        The result is a transposed view and is not contiguous.
    """
    _, num_kv_heads, head_dim = k_unpad.shape
    k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
    prev_len_sum = 0
    for i, seq_len in enumerate(kv_seq_lengths.tolist()):
        # left-side padding: valid tokens occupy the last seq_len positions.
        # An explicit start index is required: `-seq_len` is `0` for an empty
        # sequence, which would select the whole row instead of nothing.
        k_torch[i, max_seq_len - seq_len :, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
        prev_len_sum += seq_len
    k_torch = k_torch.transpose(1, 2)
    return k_torch