ColossalAI/tests/test_infer/test_ops/triton/kernel_utils.py

from typing import Tuple
import torch
from torch.nn import functional as F


# This function is adapted from src/transformers/models/llama/modeling_llama.py
# in huggingface transformers repository
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
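

# A minimal pure-Python sketch of the head ordering that repeat_kv produces,
# assuming string labels as stand-ins for per-head tensors. The helper name
# `repeat_interleave_heads` is hypothetical and is not part of this module.
# It shows that copies of the same KV head stay adjacent (interleaved repeat),
# as opposed to tiling the whole list of heads.

```python
from typing import List


def repeat_interleave_heads(kv_heads: List[str], n_rep: int) -> List[str]:
    """Repeat each KV head n_rep times, keeping the copies adjacent."""
    return [head for head in kv_heads for _ in range(n_rep)]


# Two KV heads shared by four query heads (n_rep = 2).
expanded = repeat_interleave_heads(["kv0", "kv1"], n_rep=2)
print(expanded)  # ['kv0', 'kv0', 'kv1', 'kv1']
```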


def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"):
    padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device)
    for i in range(bsz):
        cur_seq_len = kv_lengths[i].item()
        assert cur_seq_len <= kv_seq_len
        # Sequences are left-padded: mask out the leading (kv_seq_len - cur_seq_len) positions
        padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
    return padding_mask
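

# A small sketch of the left-padding mask layout built by prepare_padding_mask,
# using plain lists instead of tensors so it runs without torch. The helper name
# `left_padding_mask_rows` is hypothetical; each row corresponds to the last
# dimension of one sequence's mask.

```python
from typing import List

NEG_INF = float("-inf")


def left_padding_mask_rows(kv_lengths: List[int], kv_seq_len: int) -> List[List[float]]:
    """Return one additive mask row per sequence: -inf on padding, 0 on real tokens."""
    rows = []
    for cur_seq_len in kv_lengths:
        assert cur_seq_len <= kv_seq_len
        num_pad = kv_seq_len - cur_seq_len  # padding sits on the left
        rows.append([NEG_INF] * num_pad + [0.0] * cur_seq_len)
    return rows


mask = left_padding_mask_rows([2, 4], kv_seq_len=4)
print(mask)  # [[-inf, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```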


# Attention calculation adapted from HuggingFace transformers repository
# src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
def torch_attn_ref(
    q: torch.Tensor,  # [bsz, num_heads, seq_len, head_dim]
    k: torch.Tensor,  # [bsz, num_kv_heads, kv_seq_len, head_dim]
    v: torch.Tensor,  # [bsz, num_kv_heads, kv_seq_len, head_dim]
    attention_mask: torch.Tensor,  # [bsz, 1, seq_len, kv_seq_len]
    bsz: int,
    seq_len: int,
    kv_seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim

    # repeat kv for GQA and MQA
    # k/v won't change if kv_group_num is 1
    assert num_heads % num_kv_heads == 0, "Number of heads is not a multiple of the number of kv heads"
    kv_group_num = num_heads // num_kv_heads
    k = repeat_kv(k, kv_group_num)
    v = repeat_kv(v, kv_group_num)

    qk = torch.matmul(q, k.transpose(2, 3))
    attn_scores = qk / (head_dim**0.5)
    assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"

    # for left-side padding
    if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
        )

    attn_scores = attn_scores + attention_mask
    attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
    out = torch.matmul(attn_weights, v)
    if out.size() != (bsz, num_heads, seq_len, head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is {out.size()}"
        )
    # [bsz, num_heads, seq_len, head_dim] -> [bsz, seq_len, num_heads, head_dim]
    out = out.transpose(1, 2).contiguous()
    # drop the sequence dimension when seq_len == 1 (single-token decoding)
    out = out.squeeze(1)
    return out
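

# A pure-Python sketch of the per-query computation that torch_attn_ref batches:
# scaled dot-product scores, an additive mask, softmax, then a weighted sum of
# the values. The helper name `attend_single_query` is hypothetical, and the
# sketch assumes at least one key is left unmasked.

```python
import math
from typing import List


def attend_single_query(
    q: List[float], keys: List[List[float]], values: List[List[float]], mask: List[float]
) -> List[float]:
    """Attention output for one query vector over a list of key/value vectors."""
    head_dim = len(q)
    # Scaled dot product plus additive mask (0 keeps a key, -inf removes it).
    scores = [
        sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(head_dim) + m
        for k, m in zip(keys, mask)
    ]
    # Numerically stable softmax.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors.
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
# The first key is masked out (as a left-padding position would be),
# so the output equals the second value vector.
out = attend_single_query([1.0, 0.0], keys, values, mask=[float("-inf"), 0.0])
print(out)  # [3.0, 4.0]
```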


def mock_alloc_block_table_and_kvcache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on the provided context lengths and copy KV to the blocked KV cache.

    Cache layout: [num_blocks, num_kv_heads, head_dim, block_size]
    """
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # number of blocks needed (open bound)
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for j in range(right_bound):
            if j == right_bound - 1:
                # the last block may be partially filled
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            k_cache[block_id, :, :, :allocated_locs] = k_block
            v_cache[block_id, :, :, :allocated_locs] = v_block

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables
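

# A sketch of the block arithmetic used above, without tensors: the number of
# blocks per sequence is a ceiling division, and the last block holds
# `seq_len % block_size or block_size` tokens. The helper name `plan_blocks`
# is hypothetical, and the sketch assumes every context length is positive.

```python
from typing import List, Tuple


def plan_blocks(context_lengths: List[int], block_size: int) -> List[Tuple[List[int], int]]:
    """For each sequence, return (physical block ids, tokens stored in its last block)."""
    plans = []
    next_block_id = 0
    for seq_len in context_lengths:
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil(seq_len / block_size)
        block_ids = list(range(next_block_id, next_block_id + num_blocks))
        last_block_tokens = seq_len % block_size or block_size
        plans.append((block_ids, last_block_tokens))
        next_block_id += num_blocks
    return plans


print(plan_blocks([5, 8, 3], block_size=4))
# [([0, 1], 1), ([2, 3], 4), ([4], 3)]
```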


def mock_alloc_block_table_and_kvcache_v2(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on the provided context lengths and copy KV to the blocked KV cache.

    Cache layout: [num_blocks, num_kv_heads, block_size, head_dim]
    """
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # number of blocks needed (open bound)
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for j in range(right_bound):
            if j == right_bound - 1:
                # the last block may be partially filled
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
            k_cache[block_id, :, :allocated_locs, :] = k_block
            v_cache[block_id, :, :allocated_locs, :] = v_block

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables


def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate 1 token on the block table for each sequence in the block tables.
    # It won't change the provided context_lengths.

    # Consider max_block_id as the last physical block allocated
    # NOTE It assumes all the blocks preceding this block have been allocated
    max_block_id = torch.max(block_tables).item()
    # the indices on each block table representing the cache block to be allocated one more token
    alloc_local_block_indices = context_lengths // block_size
    # offsets of the token to be allocated on the target block (for each seq)
    alloc_block_offsets = context_lengths % block_size

    require_new_block = alloc_block_offsets == 0
    new_block_ids = torch.arange(
        max_block_id + 1,
        max_block_id + 1 + require_new_block.sum(),
        dtype=block_tables.dtype,
        device=block_tables.device,
    )

    if new_block_ids.numel():
        new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
        block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
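

# A list-based sketch of the same single-token allocation rule: a sequence
# needs a new physical block only when its current length is an exact multiple
# of block_size. The helper name `alloc_single_token` is hypothetical; like the
# function above, it assumes every block id up to the current maximum is in use.

```python
from typing import List


def alloc_single_token(block_tables: List[List[int]], context_lengths: List[int], block_size: int) -> None:
    """Reserve room for one more token per sequence, updating block_tables in place."""
    next_block_id = max(max(row) for row in block_tables) + 1
    for row, seq_len in zip(block_tables, context_lengths):
        if seq_len % block_size == 0:  # every allocated block of this sequence is full
            row[seq_len // block_size] = next_block_id
            next_block_id += 1


tables = [[0, 1, -1], [2, -1, -1]]
alloc_single_token(tables, context_lengths=[8, 3], block_size=4)
print(tables)  # [[0, 1, 3], [2, -1, -1]]
```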


def generate_caches_and_block_tables(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from the provided unpadded kv and seq lengths
    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables


def generate_caches_and_block_tables_v2(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from the provided unpadded kv and seq lengths
    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache_v2(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables


def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
    # Rebuild (batched) k/v with padding to be used by torch attention
    # input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    # returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
    prev_len_sum = 0
    for i, seq_len in enumerate(kv_seq_lengths.tolist()):
        # left-side padding
        k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
        prev_len_sum += seq_len
    k_torch = k_torch.transpose(1, 2)
    return k_torch
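

# A list-based sketch of how a flat, unpadded token stream is split by sequence
# length and left-padded, mirroring convert_kv_unpad_to_padded without the head
# dimensions. The helper name `unpad_to_left_padded` is hypothetical, and None
# stands in for the zero padding used by the tensor version.

```python
from typing import List, Optional


def unpad_to_left_padded(
    tokens: List[str], seq_lengths: List[int], max_seq_len: int
) -> List[List[Optional[str]]]:
    """Split a flat token list by sequence length and left-pad each sequence."""
    padded = []
    prev_len_sum = 0
    for seq_len in seq_lengths:
        seq = tokens[prev_len_sum : prev_len_sum + seq_len]
        padded.append([None] * (max_seq_len - seq_len) + seq)  # padding on the left
        prev_len_sum += seq_len
    return padded


batch = unpad_to_left_padded(["a0", "a1", "b0", "b1", "b2"], [2, 3], max_seq_len=4)
print(batch)  # [[None, None, 'a0', 'a1'], [None, 'b0', 'b1', 'b2']]
```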