mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
224 lines
9.9 KiB
224 lines
9.9 KiB
from typing import Tuple
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
|
|
# This function is adapted from src/transformers/models/llama/modeling_llama.py
|
|
# in huggingface transformers repository
|
|
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Produces the same result as torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep):
    the input of shape (bsz, num_key_value_heads, seq_len, head_dim) becomes
    (bsz, num_key_value_heads * n_rep, seq_len, head_dim).
    When n_rep is 1 the input tensor itself is returned unchanged.
    """
    if n_rep == 1:
        return hidden_states

    batch, kv_heads, length, dim = hidden_states.shape
    # Insert a repeat axis right after the head axis and broadcast it to n_rep
    # (a view, no copy); the reshape then folds it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(-1, -1, n_rep, -1, -1)
    return expanded.reshape(batch, kv_heads * n_rep, length, dim)
|
|
|
|
|
|
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
    """
    Build an additive float32 attention mask of shape (bsz, 1, q_len, kv_len).

    Masked positions hold -inf and visible positions hold 0. Two masks are combined:
      * a padding mask that hides the first `kv_len - kv_lengths[i]` key positions
        of sequence `i` (i.e. sequences are treated as left-padded), and
      * a causal mask applied to the last `q_len` key positions.

    Args:
        kv_lengths: per-sequence kv lengths; entries 0..bsz-1 are read.
        bsz: number of sequences.
        q_len: number of query positions (must not exceed `kv_len`).
        kv_len: number of (padded) key positions.
        device: device on which the mask is created.

    Returns:
        The combined mask tensor.
    """
    assert q_len <= kv_len

    neg_inf = float("-inf")
    causal_mask = torch.full((q_len, q_len), fill_value=neg_inf, device=device).triu(diagonal=1)

    padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
    for i in range(bsz):
        cur_seq_len = kv_lengths[i].item()
        assert cur_seq_len <= kv_len
        padding_mask[i, :, :, : kv_len - cur_seq_len] = neg_inf

    # Use an explicit start index for the trailing q_len columns: a `-q_len:`
    # slice degenerates to `0:` (the whole axis) when q_len == 0, which makes
    # the in-place add fail to broadcast.
    padding_mask[:, :, :, kv_len - q_len :] += causal_mask

    return padding_mask
|
|
|
|
|
|
# Attention calculation adapted from HuggingFace transformers repository
|
|
# src/transformers/models/llama/modeling_llama.py
|
|
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
|
|
def torch_attn_ref(
    q: torch.Tensor,  # [bsz, num_heads, q_len, head_dim]
    k: torch.Tensor,  # [bsz, num_kv_heads, kv_len, head_dim]
    v: torch.Tensor,  # [bsz, num_kv_heads, kv_len, head_dim]
    attention_mask: torch.Tensor,  # [bsz, 1, q_len, kv_len]
    bsz: int,
    q_len: int,
    kv_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    """
    Reference scaled dot-product attention written with plain torch ops.

    k/v heads are first repeated `num_heads // num_kv_heads` times (GQA/MQA), the
    scores `q @ k^T / sqrt(head_dim)` optionally receive the additive
    `attention_mask`, softmax is evaluated in float32 and cast back to q's dtype,
    and the weighted sum over v is returned.

    Returns:
        Tensor of shape [bsz * q_len, num_heads, head_dim].

    Raises:
        AssertionError: on mismatching head_dim, a head count that is not a
            multiple of the kv head count, or an unexpected score shape.
        ValueError: if the mask or the attention output has an unexpected size.
    """
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim

    # Expand kv heads so that every query head has a matching kv head.
    # With a group size of 1, k and v are passed through untouched.
    assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
    group_size = num_heads // num_kv_heads
    k, v = repeat_kv(k, group_size), repeat_kv(v, group_size)

    attn_scores = torch.matmul(q, k.transpose(2, 3)) / (head_dim**0.5)
    assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"

    if attention_mask is not None:
        expected_mask_size = (bsz, 1, q_len, kv_len)
        if attention_mask.size() != expected_mask_size:
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
            )
        attn_scores = attn_scores + attention_mask

    # Softmax in float32 for numerical stability, then back to the input dtype.
    probs = F.softmax(attn_scores.float(), dim=-1).to(dtype=q.dtype)
    out = torch.matmul(probs, v)
    if out.size() != (bsz, num_heads, q_len, head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
        )

    # [bsz, num_heads, q_len, head_dim] -> [bsz * q_len, num_heads, head_dim]
    out = out.transpose(1, 2).contiguous()
    return out.view(-1, out.size(-2), out.size(-1))
|
|
|
|
|
|
def mock_alloc_block_table_and_kvcache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.

    Physical blocks are handed out consecutively (0, 1, 2, ...) in sequence order.
    `k`/`v` are read as unpadded tensors indexed by token on dim 0; each chunk of
    up to `block_size` tokens is permuted with (1, 2, 0) and written in place to
    `cache[block_id, :, :, :num_tokens_in_block]`, i.e. the token axis is the
    last axis of a cache block.

    Args:
        k, v: unpadded keys/values, tokens of all sequences concatenated on dim 0.
        k_cache, v_cache: blocked caches, modified in place.
        context_lengths: number of tokens of each sequence.
        num_seqs: number of rows of the returned block table.
        max_num_blocks_per_seq: number of columns of the returned block table.
        block_size: number of token slots per cache block.

    Returns:
        int32 tensor of shape (num_seqs, max_num_blocks_per_seq) holding the
        physical block ids of each sequence; unused entries are -1. It is
        created without an explicit device.
    """
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    block_id = 0
    num_tokens_processed = 0
    for seq_idx, seq_len in enumerate(context_lengths.tolist()):
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil division
        block_tables[seq_idx, :num_blocks] = torch.arange(block_id, block_id + num_blocks, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v.
        # NOTE: the block loop must not reuse the sequence index name.
        seq_end = num_tokens_processed + seq_len
        for _ in range(num_blocks):
            # Every block is full except possibly the last one of the sequence.
            allocated_locs = min(block_size, seq_end - num_tokens_processed)
            tokens = slice(num_tokens_processed, num_tokens_processed + allocated_locs)
            k_cache[block_id, :, :, :allocated_locs] = k[tokens, :, :].permute(1, 2, 0)
            v_cache[block_id, :, :, :allocated_locs] = v[tokens, :, :].permute(1, 2, 0)

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables
|
|
|
|
|
|
def mock_alloc_block_table_and_kvcache_v2(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.

    Physical blocks are handed out consecutively (0, 1, 2, ...) in sequence order.
    `k`/`v` are read as unpadded tensors indexed by token on dim 0; each chunk of
    up to `block_size` tokens is permuted with (1, 0, 2) and written in place to
    `cache[block_id, :, :num_tokens_in_block, :]`, i.e. the token axis is the
    second-to-last axis of a cache block.

    Args:
        k, v: unpadded keys/values, tokens of all sequences concatenated on dim 0.
        k_cache, v_cache: blocked caches, modified in place.
        context_lengths: number of tokens of each sequence.
        num_seqs: number of rows of the returned block table.
        max_num_blocks_per_seq: number of columns of the returned block table.
        block_size: number of token slots per cache block.

    Returns:
        int32 tensor of shape (num_seqs, max_num_blocks_per_seq) holding the
        physical block ids of each sequence; unused entries are -1. It is
        created without an explicit device.
    """
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    block_id = 0
    num_tokens_processed = 0
    for seq_idx, seq_len in enumerate(context_lengths.tolist()):
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil division
        block_tables[seq_idx, :num_blocks] = torch.arange(block_id, block_id + num_blocks, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v.
        # NOTE: the block loop must not reuse the sequence index name.
        seq_end = num_tokens_processed + seq_len
        for _ in range(num_blocks):
            # Every block is full except possibly the last one of the sequence.
            allocated_locs = min(block_size, seq_end - num_tokens_processed)
            tokens = slice(num_tokens_processed, num_tokens_processed + allocated_locs)
            k_cache[block_id, :, :allocated_locs, :] = k[tokens, :, :].permute(1, 0, 2)
            v_cache[block_id, :, :allocated_locs, :] = v[tokens, :, :].permute(1, 0, 2)

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables
|
|
|
|
|
|
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    """Reserve room for one more token per sequence, updating `block_tables` in place.

    A sequence whose context length is an exact multiple of `block_size` has no
    free slot left in its current block, so it receives a fresh physical block id
    at local index `context_length // block_size`. Fresh ids are consecutive and
    start right after the largest id currently stored in `block_tables`.
    `context_lengths` is not modified and nothing is returned.

    NOTE It assumes all the blocks preceding the largest stored block id have
    been allocated.
    """
    # Largest block id seen so far is treated as the last physical block allocated.
    last_block_id = torch.max(block_tables).item()

    # Local (per-sequence) block index and in-block offset of the incoming token.
    local_block_idx = context_lengths // block_size
    in_block_offset = context_lengths % block_size

    # Offset 0 means the incoming token starts a brand-new block.
    needs_block = in_block_offset == 0
    first_new_id = last_block_id + 1
    new_block_ids = torch.arange(
        first_new_id,
        first_new_id + needs_block.sum(),
        dtype=block_tables.dtype,
        device=block_tables.device,
    )

    if not new_block_ids.numel():
        return

    block_tables[needs_block, local_block_idx[needs_block]] = new_block_ids
|
|
|
|
|
|
def generate_caches_and_block_tables(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    """Mock the blocked k/v caches and block tables for the given unpadded kv and sequence lengths.

    k_unpad/v_unpad: [num_total_tokens, num_kv_heads, head_dim]

    Zero-initialised caches of shape
    [bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size] are created
    with the requested dtype/device and then filled by
    `mock_alloc_block_table_and_kvcache`.

    Returns:
        (k_cache, v_cache, block_tables)
    """
    num_kv_heads, head_dim = k_unpad.shape[1:]
    num_blocks = bsz * max_num_blocks_per_seq
    k_cache = torch.zeros(size=(num_blocks, num_kv_heads, head_dim, block_size), dtype=dtype, device=device)
    v_cache = torch.zeros_like(k_cache)

    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
        k_unpad,
        v_unpad,
        k_cache,
        v_cache,
        kv_lengths,
        bsz,
        max_num_blocks_per_seq,
        block_size,
    )
    return k_cache, v_cache, block_tables
|
|
|
|
|
|
def generate_caches_and_block_tables_v2(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    """Mock the blocked k/v caches and block tables for the given unpadded kv and sequence lengths.

    k_unpad/v_unpad: [num_total_tokens, num_kv_heads, head_dim]

    Zero-initialised caches of shape
    [bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim] are created
    with the requested dtype/device and then filled by
    `mock_alloc_block_table_and_kvcache_v2`.

    Returns:
        (k_cache, v_cache, block_tables)
    """
    num_kv_heads, head_dim = k_unpad.shape[1:]
    num_blocks = bsz * max_num_blocks_per_seq
    k_cache = torch.zeros(size=(num_blocks, num_kv_heads, block_size, head_dim), dtype=dtype, device=device)
    v_cache = torch.zeros_like(k_cache)

    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache_v2(
        k_unpad,
        v_unpad,
        k_cache,
        v_cache,
        kv_lengths,
        bsz,
        max_num_blocks_per_seq,
        block_size,
    )
    return k_cache, v_cache, block_tables
|
|
|
|
|
|
def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
    """Rebuild (batched) k/v with padding to be used by torch attention.

    Args:
        k_unpad: unpadded k or v, [num_total_tokens, num_kv_heads, head_dim],
            with the tokens of all sequences concatenated on dim 0.
        kv_seq_lengths: number of tokens of each sequence.
        bsz: batch size of the padded output.
        max_seq_len: padded sequence length of the output.

    Returns:
        Left-padded (zero-filled) tensor of shape
        [bsz, num_kv_heads, max_seq_len, head_dim] with the dtype and device of
        `k_unpad`.
    """
    _, num_kv_heads, head_dim = k_unpad.shape
    k_padded = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
    offset = 0
    for i, seq_len in enumerate(kv_seq_lengths.tolist()):
        # Left-side padding: the tokens occupy the last `seq_len` positions.
        # An explicit start index is used because `-seq_len:` turns into `0:`
        # (the whole row) for an empty sequence and the assignment would fail.
        k_padded[i, max_seq_len - seq_len :, :, :] = k_unpad[offset : offset + seq_len]
        offset += seq_len
    return k_padded.transpose(1, 2)
|