[inference] Optimize the usage of the mid tensor space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
yuehuayingxueluo 2024-01-26 14:00:10 +08:00 committed by GitHub
parent af8359c430
commit 4f28cb43c0
16 changed files with 199 additions and 57 deletions

View File

@@ -55,6 +55,7 @@ class InferenceConfig:
     def __post_init__(self):
         self._init_batch_size()
         self._verify_config()
+        self._get_dtype()
 
     def _init_batch_size(self):
         """
@@ -84,6 +85,7 @@ class InferenceConfig:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
         assert self.dtype in [
             "fp16",
             "fp32",
@@ -97,3 +99,11 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
+
+    def _get_dtype(self) -> None:
+        if self.dtype == "fp32" or self.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif self.dtype == "fp16" or self.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
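
With this change, dtype is normalized to a torch.dtype once when the config is constructed, so downstream components (InferenceEngine, RequestHandler, KVCacheManager) can read config.dtype directly, as the later hunks show. A minimal usage sketch, assuming the remaining InferenceConfig fields keep their defaults:

    import torch
    from colossalai.inference.config import InferenceConfig

    config = InferenceConfig(dtype="fp16")  # the string form is still accepted
    assert config.dtype is torch.float16    # normalized in __post_init__ via _get_dtype()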

View File

@@ -51,17 +51,10 @@ class InferenceEngine:
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
+        self.dtype = inference_config.dtype
 
         model = model.eval()
+        model.to(self.dtype)
 
-        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
-            self.dtype = torch.float16
-            model.half()
-        else:
-            self.dtype = torch.bfloat16
-            model.to(torch.bfloat16)
-
         if model_policy is None:
             model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ class InferenceEngine:
                 None,
                 block_table,
                 self.tokenizer.eos_token_id,
+                self.tokenizer.pad_token_id,
                 self.inference_config.max_output_len,
             )
             self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
-            padding_id=self.tokenizer.pad_token_id,
         )
 
         logits = logits[:, -1, :]

View File

@@ -4,6 +4,7 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
     Args:
         inference_config: Configuration for initialize and manage kv cache.
         model_config: Configuration for model
+        dtype (torch.dtype): The data type for weights and activations.
     """
 
     def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.inference_config = inference_config
-        self._init_cache(model_config)
-
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        device = torch.cuda.current_device()
-        self.running_batch = BatchInfo(is_prompts=False, device=device)
-        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.dtype = inference_config.dtype
         self.max_batch_size = inference_config.max_batch_size
 
+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        device = torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        fd_inter_tensor = FDIntermTensors()
+        fd_inter_tensor.initialize(
+            max_batch_size=self.max_batch_size,
+            num_attn_heads=model_config.num_attention_heads,
+            kv_max_split_num=kv_max_split_num,
+            head_dim=head_dim,
+            dtype=self.dtype,
+            device=device,
+        )
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=False,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
+        self.prefill_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=True,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
 
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
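
The FDIntermTensors object created here holds the "mid tensors" from the PR title: flash-decoding scratch buffers allocated once per RequestHandler and reused on every decode step, instead of being created inside the kernel wrapper on each call. A rough sketch of the buffers it is expected to hold, with shapes taken from the flash_decoding_attention docstring and unit test further down (sizes are hypothetical; the actual FDIntermTensors implementation is not part of this diff):

    import torch

    max_batch_size, num_attn_heads, kv_max_split_num, head_dim = 8, 32, 64, 128
    device = torch.device("cuda")

    # Partial attention outputs, one slice per KV split: [max_bsz, num_heads, kv_max_split_num, head_dim]
    mid_output = torch.empty(
        (max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=device
    )
    # Per-split log-sum-exp values used to combine the partial outputs: [max_bsz, num_heads, kv_max_split_num]
    mid_output_lse = torch.empty(
        (max_batch_size, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=device
    )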

View File

@@ -58,12 +58,7 @@ class KVCacheManager:
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        if config.dtype == "fp32" or config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif config.dtype == "fp16" or config.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
+        self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
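
Because config.dtype is now guaranteed to be a torch.dtype, the element size used for KV-cache sizing follows directly from it, e.g.:

    import torch

    assert torch.tensor([], dtype=torch.float16).element_size() == 2  # fp16/bf16: 2 bytes per element
    assert torch.tensor([], dtype=torch.float32).element_size() == 4  # fp32: 4 bytes per element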

View File

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
@@ -50,7 +51,6 @@ def llama_causal_lm_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
@@ -58,7 +58,6 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
-        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    attention_mask = batch.get_attn_mask(padding_id)
+    attention_mask = batch.get_attn_mask()
 
     if attention_mask is not None:
         if HAS_TRITON:
@@ -84,6 +82,7 @@ def llama_model_forward(
     else:
         sequence_lengths = batch.get_sequence_lengths()
 
+    batch_size, _ = input_ids.shape
     kv_seq_len = sequence_lengths.max().item()
 
     if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(
     hidden_states = self.embed_tokens(input_ids)
 
-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
+    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
+    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
+    # cos_sin = (cos, sin)
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)
 
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
@@ -116,6 +130,9 @@ def llama_model_forward(
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
         )
 
     hidden_states = self.norm(hidden_states)
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
-    sequence_lengths: int = None,
+    sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
     )
 
     hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
@@ -206,7 +232,17 @@ def llama_attn_forward(
     if is_prompts:
         attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            q=query_states,
+            k=key_states,
+            v=value_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            context_lengths=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            output=output_tensor,
+            max_seq_len=kv_seq_len,
+            sm_scale=sm_scale,
         )
         if attention_mask is not None:
             attn_output = pad_input(attn_output, indices, bsz, q_len)
@@ -214,7 +250,17 @@ def llama_attn_forward(
         copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         attn_output = flash_decoding_attention(
-            query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            q=query_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            kv_seq_len=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            max_seq_len_in_batch=kv_seq_len,
+            output=output_tensor,
+            mid_output=fd_inter_tensor.mid_output,
+            mid_output_lse=fd_inter_tensor.mid_output_lse,
+            sm_scale=sm_scale,
        )
        attn_output = attn_output.squeeze(1)
    else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
 @torch.no_grad()
 def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return nopad format.
+
+    Args:
+        lengths: shape (num_seqs,), stores the length of each sequence.
+        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
+        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
+        is_prompts: bool, marks whether we are in prefill mode.
+        dtype: The data type of this inference process.
+    """
     if is_prompts:
         index_arrays = [torch.arange(length) for length in lengths]
     else:
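
The output_tensor plumbing above is the other half of the optimization: previously each attention call allocated its own output buffer inside the kernel wrapper (torch.zeros_like(q) in context_attention_unpadded, torch.empty(...) in flash_decoding_attention), once per layer per forward pass. Now a single buffer is allocated per forward pass in llama_model_forward and shared by all decoder layers, while the flash-decoding mid tensors come from the long-lived FDIntermTensors object. A rough sketch of the saving for one decode step, using hypothetical sizes:

    import torch

    num_layers, bsz, num_heads, head_dim = 32, 8, 32, 128
    dtype, device = torch.float16, torch.device("cpu")  # device is irrelevant for the arithmetic

    # One [bsz, 1, num_heads, head_dim] buffer per decode step, reused by every layer.
    output_tensor = torch.zeros((bsz, 1, num_heads, head_dim), dtype=dtype, device=device)
    sm_scale = 1.0 / (head_dim**0.5)

    # Rough upper bound on the per-step allocations this avoids (one per remaining layer).
    saved_bytes = (num_layers - 1) * output_tensor.numel() * output_tensor.element_size()
    print(f"~{saved_bytes / 2**20:.1f} MiB of transient allocations avoided per step")  # ~1.9 MiB here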

View File

@@ -5,6 +5,7 @@ from typing import Any, List, Tuple, Union
 import torch
 from ordered_set import OrderedSet
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -61,6 +62,7 @@ class Sequence:
         sample_params (SampleParams): The sample_params of input sequence.
         block_table (torch.Tensor): The index of input sequence in block_table.
         eos_token_id (int): The eos token id for this inference process.
+        pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
     """
@@ -71,6 +73,7 @@ class Sequence:
     sample_params: Any  # SampleParams needs to be imported later.
     block_table: torch.Tensor
     eos_token_id: int
+    pad_token_id: int
     max_output_len: int = 256
 
     def __post_init__(self):
@@ -167,15 +170,23 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
+    max_batch_size: int
+    kv_max_split_num: int
+    num_heads: int
+    head_dim: int
     sequences_set: OrderedSet[Sequence] = None
     is_prompts: bool = True
     device: torch.device = None
+    dtype: torch.dtype = None
+    fd_inter_tensor: FDIntermTensors = None
 
     def __post_init__(self):
         if self.device is None:
             self.device = torch.cuda.current_device()
         if self.sequences_set is None:
             self.sequences_set = OrderedSet()
+        if self.fd_inter_tensor is None:
+            self.fd_inter_tensor = FDIntermTensors()
 
     def init_batch(self, seqs: List["Sequence"] = None):
         """
@@ -185,8 +196,6 @@ class BatchInfo:
             seqs (List["Sequence"]): List of input sequence.
         """
 
-        assert len(self.sequences_set) == 0, "Sequences set has been initialized."
-
         if seqs is not None:
             if not isinstance(seqs, list):
                 seqs = [seqs]
@@ -197,16 +206,30 @@ class BatchInfo:
                 self.sequences_set.add(seq)
 
+    def init_fd_tensors(self):
+        if not self.fd_inter_tensor.is_initialized:
+            self.fd_inter_tensor.initialize(
+                max_batch_size=self.max_batch_size,
+                num_attn_heads=self.num_heads,
+                kv_max_split_num=self.kv_max_split_num,
+                head_dim=self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+
     def get_block_table_tensor(self) -> None:
         tesnor_list = []
         block_table = None
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             block_table = seq.block_table
             assert (
                 block_table is not None
             ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
 
-        assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
-
         block_table = torch.stack(tesnor_list)
         return block_table
@@ -218,7 +241,6 @@ class BatchInfo:
         """
         if self.is_prompts:
             self.sequences_set.clear()
-
         else:
             for seq in self.sequences_set:
                 seq.mark_aborted()
@@ -312,14 +334,14 @@ class BatchInfo:
         """
         Get bacth inputs for forward inference computation.
         """
         input_list = []
 
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 if seq.output_len > 0:
-                    print(seq.output_token_id)
-                    seq_data = seq.input_token_id + seq.output_token_id
-                    print(seq_data)
                     input_list.append(seq.input_token_id + seq.output_token_id)
                 else:
                     input_list.append(seq.input_token_id)
@@ -328,7 +350,8 @@ class BatchInfo:
         max_seq_len = max(len(sub_list) for sub_list in input_list)
 
-        return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
+        # We assume that all the padding_id in seq are the same at present.
+        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
 
     def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
         """
@@ -336,6 +359,9 @@ class BatchInfo:
         """
         input_list = []
         input_len_list = []
 
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
@@ -353,16 +379,23 @@ class BatchInfo:
         Get the input_len of each sentence in this batch.
         """
         len_list = []
 
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             len_list.append(seq.sentence_len)
 
         return torch.tensor(len_list, dtype=torch.int, device=self.device)
 
-    def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+    def get_attn_mask(self) -> torch.Tensor:
         """
         Generate and return attention mask.
         """
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         past_values = []
+        # We assume that all the padding_id in seq are the same at present.
+        padding_id = self.sequences_set[0].pad_token_id
+
         for seq in self.sequences_set:
             past_values.append(seq.input_token_id + seq.output_token_id)
@@ -378,7 +411,7 @@ class BatchInfo:
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
-    return x + [pad] * (max_len - len(x))
+    return [pad] * (max_len - len(x)) + x
 
 def _make_tensor_with_pad(
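
Note the behavioral change in _pad_to_max: padding is now prepended rather than appended, presumably so that the most recent tokens stay right-aligned in each padded row. A quick self-contained check of the new behavior:

    from typing import List

    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        assert len(x) <= max_len
        return [pad] * (max_len - len(x)) + x  # left padding (was right padding before this change)

    assert _pad_to_max([5, 6, 7], 5, 0) == [0, 0, 5, 6, 7]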

View File

@@ -10,7 +10,6 @@ except ImportError:
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_attention
-    from .flash_decoding_utils import FDIntermTensors
     from .fused_rotary_embedding import fused_rotary_embedding
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
@@ -27,7 +26,6 @@ if HAS_TRITON:
         "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
-        "FDIntermTensors",
         "fused_rotary_embedding",
         "get_xine_cache",
     ]

View File

@@ -5,7 +5,6 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
-from typing import Optional
 
 import torch
 import triton
@@ -195,7 +194,9 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
-    max_seq_len_in_b: Optional[int] = None,
+    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
+    max_seq_len: int = None,
+    sm_scale: int = None,
 ):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
@@ -210,10 +211,9 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads
 
     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
-    sm_scale = 1.0 / (Lq**0.5)
-
-    output = torch.zeros_like(q)
+    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    output = torch.zeros_like(q) if output is None else output
 
     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)

View File

@@ -195,6 +195,7 @@ def flash_decoding_attention(
     block_tables: torch.Tensor,
     block_size: int,
     max_seq_len_in_batch: int = None,
+    output: torch.Tensor = None,
     mid_output: torch.Tensor = None,
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
@@ -211,6 +212,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
+        output (torch.Tensor): [bsz, 1, num_heads, head_dim]
         mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
@@ -292,7 +294,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )
 
-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
 
     grid = (triton.next_power_of_2(bsz), num_heads)
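
The fallback on the changed line above is what the "fix None logic for output tensor" commit refers to: a caller-provided buffer is used as-is, and a fresh one is allocated only when output is None. A small stand-alone sketch that mirrors that logic (the helper name is hypothetical, not part of the library):

    import torch

    def resolve_output(q: torch.Tensor, output: torch.Tensor = None) -> torch.Tensor:
        # Mirrors the fallback in flash_decoding_attention: allocate a fresh
        # [bsz, 1, num_heads, head_dim] buffer only if the caller did not supply one.
        bsz, _, num_heads, head_dim = q.shape
        return torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

    q = torch.randn(4, 1, 8, 64)
    reused = torch.empty_like(q)
    assert resolve_output(q, reused) is reused       # preallocated buffer is reused
    assert resolve_output(q).shape == (4, 1, 8, 64)  # otherwise allocated on demand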

View File

@@ -91,7 +91,7 @@ def benchmark_inference(args):
     config.pad_token_id = config.eos_token_id
     model = transformers.LlamaForCausalLM(config).cuda()
     model = model.eval()
-    tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
 
     if args.dtype == "fp16":
         model = model.half()

View File

@@ -23,11 +23,12 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
 CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 
 # benchmark llama2-7b one single GPU
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
 done
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
 done

View File

@@ -17,6 +17,7 @@ def check_config_and_inference():
        sample_params=None,
        block_table=None,
        eos_token_id=2,
+        pad_token_id=2,
        max_output_len=256,
    )
 
@@ -28,6 +29,7 @@ def check_config_and_inference():
        sample_params=None,
        block_table=None,
        eos_token_id=2,
+        pad_token_id=2,
        max_output_len=256,
    )
 
@@ -39,6 +41,7 @@ def check_config_and_inference():
        sample_params=None,
        block_table=None,
        eos_token_id=2,
+        pad_token_id=2,
        max_output_len=256,
    )
    sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
    assert sequence.output_len == 0
    assert sequence.check_finish() == False
 
-    batch = BatchInfo(is_prompts=False)
+    batch = BatchInfo(
+        max_batch_size=8,
+        kv_max_split_num=16,
+        num_heads=2,
+        head_dim=128,
+    )
    batch.init_batch([sequence])
    batch.add_seqs([sequence2, sequence3])
    batch.add_seqs([sequence])

View File

@@ -3,8 +3,7 @@ import random
 import numpy as np
 import pytest
 import torch
-import transformers
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
+    model = LlamaForCausalLM(
+        LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()

View File

@@ -20,6 +20,7 @@ def check_running_list():
        input_token_id=[1, 2, 3],
        block_size=16,
        eos_token_id=0,
+        pad_token_id=0,
        sample_params=None,
        block_table=1,
    )
@@ -56,6 +57,7 @@ def check_request_handler():
        input_token_id=[1, 2, 3, 4, 5],
        block_size=16,
        eos_token_id=0,
+        pad_token_id=0,
        sample_params=None,
        block_table=torch.tensor([-1, -1]),
    )

View File

@@ -91,6 +91,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -106,6 +107,7 @@ def test_flash_decoding(
         block_tables,
         block_size,
         max_seq_len_in_b,
+        output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
@@ -184,6 +186,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -199,6 +202,7 @@ def bench_kernel(
         block_tables,
         block_size,
         max_seq_len_in_b,
+        output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
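
As a concrete check of the sizing logic used in these tests (numbers are hypothetical):

    bsz, num_attn_heads, head_dim = 16, 32, 128
    max_seq_len_in_b, block_size = 1024, 16

    # one KV split per cache block of the longest sequence
    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
    assert kv_max_split_num == 64

    # mid_output is fp32 with shape [bsz, num_heads, kv_max_split_num, head_dim]
    mid_output_bytes = bsz * num_attn_heads * kv_max_split_num * head_dim * 4
    assert mid_output_bytes == 16 * 2**20  # 16 MiB; in the engine this buffer lives in FDIntermTensors and is reused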