mirror of https://github.com/hpcaitech/ColossalAI
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)

* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
parent af8359c430
commit 4f28cb43c0
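The core change in this commit is that the flash-attention kernels now accept preallocated output and intermediate ("mid") tensors instead of allocating temporary space on every call. A rough, illustrative sketch of that preallocation pattern (the helper name and concrete sizes are made up; the shapes follow the tensor docstrings in the diff below):

```python
# Illustrative sketch only: preallocate flash-decoding scratch buffers once per
# batch and reuse them across layers/steps instead of allocating inside each call.
import torch

def preallocate_decoding_buffers(max_batch_size, num_heads, head_dim,
                                 kv_max_split_num, dtype, device):
    # final attention output for the single decoding token of each sequence
    output = torch.zeros((max_batch_size, 1, num_heads, head_dim), dtype=dtype, device=device)
    # partial results produced by the split-KV stage of flash decoding
    mid_output = torch.empty((max_batch_size, num_heads, kv_max_split_num, head_dim),
                             dtype=torch.float32, device=device)
    # log-sum-exp of each partial block, used when reducing the partials
    mid_output_lse = torch.empty((max_batch_size, num_heads, kv_max_split_num),
                                 dtype=torch.float32, device=device)
    return output, mid_output, mid_output_lse

# Hypothetical sizes just to show the call.
buffers = preallocate_decoding_buffers(
    max_batch_size=8, num_heads=32, head_dim=128,
    kv_max_split_num=64, dtype=torch.float16, device="cpu",
)
```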
@@ -55,6 +55,7 @@ class InferenceConfig:
def __post_init__(self):
self._init_batch_size()
self._verify_config()
self._get_dtype()

def _init_batch_size(self):
"""
@@ -84,6 +85,7 @@ class InferenceConfig:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

assert self.dtype in [
"fp16",
"fp32",
@@ -97,3 +99,11 @@ class InferenceConfig:
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."

def _get_dtype(self) -> None:
if self.dtype == "fp32" or self.dtype == torch.float32:
self.dtype = torch.float32
elif self.dtype == "fp16" or self.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16

@@ -51,17 +51,10 @@ class InferenceEngine:
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.dtype = inference_config.dtype

model = model.eval()

if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
model.to(self.dtype)

if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ class InferenceEngine:
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ class InferenceEngine:
batch,
self.k_cahce,
self.v_cache,
padding_id=self.tokenizer.pad_token_id,
)

logits = logits[:, -1, :]

@@ -4,6 +4,7 @@ import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
Args:
inference_config: Configuration for initialize and manage kv cache.
model_config: Configuration for model
dtype (torch.dtype): The data type for weights and activations.
"""

def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self._init_cache(model_config)

self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size

# initialize cache
self._init_cache(model_config)

# initialize batch
device = torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads

fd_inter_tensor = FDIntermTensors()
fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
device=device,
)

# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)

def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)

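As a side note, the kv_max_split_num computed in the constructor above is just a ceiling division of the maximum total sequence length (input plus output) by the KV-cache block size; a quick standalone check with made-up numbers:

```python
# Made-up configuration values, mirroring the ceiling division above.
max_input_len, max_output_len, block_size = 512, 256, 16
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 48  # ceil((512 + 256) / 16)
```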
@@ -58,12 +58,7 @@ class KVCacheManager:
# Parallel settings
self.tp_size = config.tp_size
# Model settings
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
@@ -50,7 +51,6 @@ def llama_causal_lm_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
@@ -58,7 +58,6 @@ def llama_causal_lm_forward(
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
padding_id=padding_id,
)
logits = self.lm_head(hidden_states)
return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask(padding_id)
attention_mask = batch.get_attn_mask()

if attention_mask is not None:
if HAS_TRITON:
@@ -84,6 +82,7 @@ def llama_model_forward(
else:
sequence_lengths = batch.get_sequence_lengths()

batch_size, _ = input_ids.shape
kv_seq_len = sequence_lengths.max().item()

if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(

hidden_states = self.embed_tokens(input_ids)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
# When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@@ -116,6 +130,9 @@ def llama_model_forward(
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = self.norm(hidden_states)

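For context on the two output_tensor allocations above: the prefill path flattens all prompt tokens into a single unpadded token dimension, while the decode path holds exactly one new token per sequence. A small illustration with toy sizes (not the actual model code):

```python
import torch

# Illustrative only: the two output-buffer shapes used above.
sequence_lengths = torch.tensor([5, 3, 7])   # made-up prompt lengths
batch_size, num_heads, head_dim = sequence_lengths.numel(), 4, 64
dtype, device = torch.float16, "cpu"

# prefill: one row per prompt token across the whole (unpadded) batch
prefill_out = torch.zeros((sequence_lengths.sum().item(), num_heads, head_dim), dtype=dtype, device=device)
# decode: exactly one new token per sequence
decode_out = torch.zeros((batch_size, 1, num_heads, head_dim), dtype=dtype, device=device)

assert prefill_out.shape[0] == 15 and decode_out.shape[0] == 3
```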
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: int = None,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -206,7 +232,17 @@ def llama_attn_forward(

if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
@@ -214,7 +250,17 @@ def llama_attn_forward(
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_

@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin for the cache, and return nopad format.
Args:
lengths: shape(num_seqs,), stores length of each sequence.
cos_cache: shape(max_rotary_position(e.g. 2048), head_dim), cos cache constructed in model.
sin_cache: shape(max_rotary_position(e.g. 2048), head_dim), sin cache constructed in model.
is_prompts: bool, mark if in prefill mode.
dtype: The data type of this inference process.
"""

if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:

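get_cos_sin above gathers per-position rotary cos/sin rows in an unpadded layout. A hedged sketch of the prefill-side gather with toy sizes (the real helper also handles the decode path and dtype conversion):

```python
import torch

# Toy rotary caches: one row per position, head_dim columns.
max_positions, head_dim = 16, 8
cos_cache = torch.randn(max_positions, head_dim)
sin_cache = torch.randn(max_positions, head_dim)

lengths = torch.tensor([3, 5])  # made-up prompt lengths

# Prefill: positions 0..len-1 for every sequence, concatenated without padding.
index = torch.cat([torch.arange(length) for length in lengths])
cos = cos_cache[index]          # shape (sum(lengths), head_dim)
sin = sin_cache[index]
assert cos.shape == (8, head_dim)
```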
@@ -5,6 +5,7 @@ from typing import Any, List, Tuple, Union
import torch
from ordered_set import OrderedSet

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)
@@ -61,6 +62,7 @@ class Sequence:
sample_params (SampleParams): The sample_params of input sequence.
block_table (torch.Tensor): The index of input sequence in block_table.
eos_token_id (int): The eos token id for this inference process.
pad_token_id (int): The pad token id for this inference process.
max_output_len (int): Maximum output length.
"""

@@ -71,6 +73,7 @@ class Sequence:
sample_params: Any # SampleParams needs to be imported later.
block_table: torch.Tensor
eos_token_id: int
pad_token_id: int
max_output_len: int = 256

def __post_init__(self):
@@ -167,15 +170,23 @@ class BatchInfo:
Information to be passed and used for a batch of sequences.
"""

max_batch_size: int
kv_max_split_num: int
num_heads: int
head_dim: int
sequences_set: OrderedSet[Sequence] = None
is_prompts: bool = True
device: torch.device = None
dtype: torch.dtype = None
fd_inter_tensor: FDIntermTensors = None

def __post_init__(self):
if self.device is None:
self.device = torch.cuda.current_device()
if self.sequences_set is None:
self.sequences_set = OrderedSet()
if self.fd_inter_tensor is None:
self.fd_inter_tensor = FDIntermTensors()

def init_batch(self, seqs: List["Sequence"] = None):
"""
@@ -185,8 +196,6 @@ class BatchInfo:
seqs (List["Sequence"]): List of input sequence.
"""

assert len(self.sequences_set) == 0, "Sequences set has been initialized."

if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
@@ -197,16 +206,30 @@ class BatchInfo:

self.sequences_set.add(seq)

def init_fd_tensors(self):
if not self.fd_inter_tensor.is_initialized:
self.fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=self.num_heads,
kv_max_split_num=self.kv_max_split_num,
head_dim=self.head_dim,
dtype=self.dtype,
device=self.device,
)

def get_block_table_tensor(self) -> None:
tesnor_list = []
block_table = None

assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

for seq in self.sequences_set:
block_table = seq.block_table
assert (
block_table is not None
), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
tesnor_list.append(seq.block_table)
assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."

block_table = torch.stack(tesnor_list)
return block_table

@@ -218,7 +241,6 @@ class BatchInfo:
"""
if self.is_prompts:
self.sequences_set.clear()

else:
for seq in self.sequences_set:
seq.mark_aborted()

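init_fd_tensors above lazily initializes the shared FDIntermTensors holder the first time a batch needs it. A minimal stand-in for that pattern, using the buffer shapes documented in the flash_decoding changes further down (this is not the ColossalAI class itself):

```python
import torch

class LazyScratchBuffers:
    """Hypothetical stand-in for a lazily initialized intermediate-tensor holder."""

    def __init__(self):
        self._initialized = False
        self.mid_output = None
        self.mid_output_lse = None

    @property
    def is_initialized(self):
        return self._initialized

    def initialize(self, max_batch_size, num_attn_heads, kv_max_split_num, head_dim, dtype, device):
        # Allocate once; later batches reuse the same storage. dtype is accepted to
        # mirror the real signature, but the mid buffers stay fp32 as in the diff.
        self.mid_output = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=device
        )
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=device
        )
        self._initialized = True

buffers = LazyScratchBuffers()
if not buffers.is_initialized:
    buffers.initialize(max_batch_size=4, num_attn_heads=8, kv_max_split_num=32,
                       head_dim=64, dtype=torch.float16, device="cpu")
```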
@@ -312,14 +334,14 @@ class BatchInfo:
"""
Get batch inputs for forward inference computation.
"""

input_list = []

assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

for seq in self.sequences_set:
if self.is_prompts:
if seq.output_len > 0:
print(seq.output_token_id)
seq_data = seq.input_token_id + seq.output_token_id
print(seq_data)
input_list.append(seq.input_token_id + seq.output_token_id)
else:
input_list.append(seq.input_token_id)
@@ -328,7 +350,8 @@ class BatchInfo:

max_seq_len = max(len(sub_list) for sub_list in input_list)

return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
# We assume that all the padding_id in seq are the same at present.
return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)

def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
@@ -336,6 +359,9 @@ class BatchInfo:
"""
input_list = []
input_len_list = []

assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

for seq in self.sequences_set:
if self.is_prompts:
input_list.extend(seq.input_token_id)
@@ -353,16 +379,23 @@ class BatchInfo:
Get the input_len of each sentence in this batch.
"""
len_list = []

assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

for seq in self.sequences_set:
len_list.append(seq.sentence_len)

return torch.tensor(len_list, dtype=torch.int, device=self.device)

def get_attn_mask(self, padding_id: int) -> torch.Tensor:
def get_attn_mask(self) -> torch.Tensor:
"""
Generate and return attention mask.
"""
assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

past_values = []
# We assume that all the padding_id in seq are the same at present.
padding_id = self.sequences_set[0].pad_token_id

for seq in self.sequences_set:
past_values.append(seq.input_token_id + seq.output_token_id)
@@ -378,7 +411,7 @@ class BatchInfo:

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
return [pad] * (max_len - len(x)) + x


def _make_tensor_with_pad(

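The _pad_to_max change above switches batch inputs from right padding to left padding, so the most recent tokens line up at the end of each row. A quick illustration:

```python
# Illustrative only.
def pad_right(x, max_len, pad):
    return x + [pad] * (max_len - len(x))

def pad_left(x, max_len, pad):
    return [pad] * (max_len - len(x)) + x

tokens = [11, 12, 13]
assert pad_right(tokens, 5, 0) == [11, 12, 13, 0, 0]
assert pad_left(tokens, 5, 0) == [0, 0, 11, 12, 13]
```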
@@ -10,7 +10,6 @@ except ImportError:
if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_attention
from .flash_decoding_utils import FDIntermTensors
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
@@ -27,7 +26,6 @@ if HAS_TRITON:
"rms_layernorm",
"gptq_fused_linear_triton",
"rotary_embedding",
"FDIntermTensors",
"fused_rotary_embedding",
"get_xine_cache",
]

@@ -5,7 +5,6 @@
#
# Inspired and modified from Triton Tutorial - Fused Attention
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
from typing import Optional

import torch
import triton
@@ -195,7 +194,9 @@ def context_attention_unpadded(
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
block_size: int,
max_seq_len_in_b: Optional[int] = None,
output: torch.Tensor = None, # [num_tokens, num_heads, head_dim]
max_seq_len: int = None,
sm_scale: int = None,
):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk == Lv
@@ -210,10 +211,9 @@ def context_attention_unpadded(
num_kv_group = num_heads // num_kv_heads

num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
sm_scale = 1.0 / (Lq**0.5)

output = torch.zeros_like(q)
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
output = torch.zeros_like(q) if output is None else output

# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
# the size of physical cache block (i.e. `block_size`)

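The pattern introduced here (and reused in flash_decoding below) is to compute max_seq_len, sm_scale, and the output buffer only when the caller does not pass them in, so preallocated tensors can be reused across calls. A generic, hedged sketch of that argument handling (names are illustrative, not the actual kernel):

```python
import torch

def kernel_wrapper(q: torch.Tensor,
                   context_lengths: torch.Tensor,
                   output: torch.Tensor = None,
                   max_seq_len: int = None,
                   sm_scale: float = None) -> torch.Tensor:
    # Fall back to locally computed defaults only when the caller did not supply them.
    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
    sm_scale = 1.0 / (q.shape[-1] ** 0.5) if sm_scale is None else sm_scale
    output = torch.zeros_like(q) if output is None else output
    # ... the real kernel launch would go here ...
    return output

q = torch.randn(6, 4, 64)            # (num_tokens, num_heads, head_dim), toy sizes
lengths = torch.tensor([2, 4])
out = kernel_wrapper(q, lengths)                                 # buffers allocated internally
out2 = kernel_wrapper(q, lengths, output=torch.zeros_like(q))    # caller-provided buffer reused
```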
@@ -195,6 +195,7 @@ def flash_decoding_attention(
block_tables: torch.Tensor,
block_size: int,
max_seq_len_in_batch: int = None,
output: torch.Tensor = None,
mid_output: torch.Tensor = None,
mid_output_lse: torch.Tensor = None,
sm_scale: int = None,
@@ -211,6 +212,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, 1, num_heads, head_dim]
mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
@@ -292,7 +294,7 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)

output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

grid = (triton.next_power_of_2(bsz), num_heads)

@@ -91,7 +91,7 @@ def benchmark_inference(args):
config.pad_token_id = config.eos_token_id
model = transformers.LlamaForCausalLM(config).cuda()
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

if args.dtype == "fp16":
model = model.half()

@@ -23,11 +23,12 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1

# benchmark llama2-7b one single GPU

for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
done


for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
done

@@ -17,6 +17,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)

@@ -28,6 +29,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)

@@ -39,6 +41,7 @@ def check_config_and_inference():
sample_params=None,
block_table=None,
eos_token_id=2,
pad_token_id=2,
max_output_len=256,
)
sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
assert sequence.output_len == 0
assert sequence.check_finish() == False

batch = BatchInfo(is_prompts=False)
batch = BatchInfo(
max_batch_size=8,
kv_max_split_num=16,
num_heads=2,
head_dim=128,
)
batch.init_batch([sequence])
batch.add_seqs([sequence2, sequence3])
batch.add_seqs([sequence])

@@ -3,8 +3,7 @@ import random
import numpy as np
import pytest
import torch
import transformers
from transformers import AutoTokenizer, GenerationConfig
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
def check_inference_engine(test_cai=False):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
).cuda()
@@ -81,4 +80,4 @@ def test_inference_engine():


if __name__ == "__main__":
test_inference_engine()
test_inference_engine()

@@ -20,6 +20,7 @@ def check_running_list():
input_token_id=[1, 2, 3],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=1,
)
@@ -56,6 +57,7 @@ def check_request_handler():
input_token_id=[1, 2, 3, 4, 5],
block_size=16,
eos_token_id=0,
pad_token_id=0,
sample_params=None,
block_table=torch.tensor([-1, -1]),
)

@@ -91,6 +91,7 @@ def test_flash_decoding(
max_seq_len_in_b = kv_seq_lengths.max().item()
# The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@@ -106,6 +107,7 @@ def test_flash_decoding(
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,
@@ -184,6 +186,7 @@ def bench_kernel(
block_tables = block_tables.to(device=device)
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@@ -199,6 +202,7 @@ def bench_kernel(
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,