mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels * fix pad input * adapted to copy_kv_to_blocked_cache * fix ci test * update kv memcpy * remove printpull/5270/head
parent
0f2b46a41c
commit
86b63f720c
|
@ -236,6 +236,7 @@ class InferenceEngine:
|
||||||
output_list = []
|
output_list = []
|
||||||
batch = self.request_handler.schedule()
|
batch = self.request_handler.schedule()
|
||||||
|
|
||||||
|
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||||
logits = self.model(
|
logits = self.model(
|
||||||
batch,
|
batch,
|
||||||
self.k_cahce,
|
self.k_cahce,
|
||||||
|
|
|
@ -57,9 +57,6 @@ class RunningList:
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return not self.decoding and not self.prefill
|
return not self.decoding and not self.prefill
|
||||||
|
|
||||||
def total_seq_num(self):
|
|
||||||
return len(self.decoding) + len(self.prefill)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler:
|
class RequestHandler:
|
||||||
"""
|
"""
|
||||||
|
@ -81,6 +78,7 @@ class RequestHandler:
|
||||||
device = torch.cuda.current_device()
|
device = torch.cuda.current_device()
|
||||||
self.running_batch = BatchInfo(is_prompts=False, device=device)
|
self.running_batch = BatchInfo(is_prompts=False, device=device)
|
||||||
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
|
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
|
||||||
|
self.max_batch_size = inference_config.max_batch_size
|
||||||
|
|
||||||
def _init_cache(self, model_config):
|
def _init_cache(self, model_config):
|
||||||
self.cache_manager = KVCacheManager(self.inference_config, model_config)
|
self.cache_manager = KVCacheManager(self.inference_config, model_config)
|
||||||
|
@ -108,20 +106,18 @@ class RequestHandler:
|
||||||
)
|
)
|
||||||
self.abort_sequence(seq.request_id)
|
self.abort_sequence(seq.request_id)
|
||||||
break
|
break
|
||||||
|
|
||||||
# stop feeding new sequence into running list to assure
|
|
||||||
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Try to allocate cache blocks for the sequence.
|
# Try to allocate cache blocks for the sequence.
|
||||||
if self.cache_manager.check_allocation(seq):
|
if (
|
||||||
|
self.cache_manager.check_allocation(seq)
|
||||||
|
and (len(self.running_list.prefill) + len(self.running_list.decoding))
|
||||||
|
< self.max_batch_size # There some bugs in continous batching, so we disable it here.
|
||||||
|
):
|
||||||
# If succeed, add the sequence to running list.
|
# If succeed, add the sequence to running list.
|
||||||
remove_list.append(seq)
|
remove_list.append(seq)
|
||||||
self.running_list.append(seq)
|
self.running_list.append(seq)
|
||||||
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
|
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
|
||||||
for seq in remove_list:
|
for seq in remove_list:
|
||||||
lst.remove(seq)
|
lst.remove(seq)
|
||||||
|
|
||||||
if self.running_list.ready_for_prefill():
|
if self.running_list.ready_for_prefill():
|
||||||
for seq in self.running_list.prefill:
|
for seq in self.running_list.prefill:
|
||||||
seq.mark_running()
|
seq.mark_running()
|
||||||
|
@ -130,12 +126,7 @@ class RequestHandler:
|
||||||
|
|
||||||
if not self.running_batch.is_empty:
|
if not self.running_batch.is_empty:
|
||||||
for seq in self.running_batch.sequences_set:
|
for seq in self.running_batch.sequences_set:
|
||||||
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
|
self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
|
||||||
if recycle:
|
|
||||||
seq.recycle()
|
|
||||||
self.running_batch.remove(seq)
|
|
||||||
self.waiting_list[-1].append(seq)
|
|
||||||
# the recycled sequences are handled with highest priority.
|
|
||||||
|
|
||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
"""
|
"""
|
||||||
Func: copy key/value into key/value cache.
|
Func: copy key/value into key/value cache.
|
||||||
|
@ -40,6 +41,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
||||||
"""
|
"""
|
||||||
Func: convert key/value cache for calculation
|
Func: convert key/value cache for calculation
|
||||||
|
@ -79,6 +81,7 @@ class PagedAttention:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
|
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
|
||||||
"""
|
"""
|
||||||
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
|
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
|
||||||
|
@ -94,12 +97,14 @@ class PagedAttention:
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def generate_padding_mask(lengths, max_seq_len):
|
def generate_padding_mask(lengths, max_seq_len):
|
||||||
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
|
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
|
||||||
padding_mask = range_tensor < lengths.unsqueeze(1)
|
padding_mask = range_tensor < lengths.unsqueeze(1)
|
||||||
return padding_mask
|
return padding_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||||
|
@ -117,6 +122,7 @@ class PagedAttention:
|
||||||
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
|
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def nopad_context_forward(
|
def nopad_context_forward(
|
||||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||||
|
@ -185,6 +191,7 @@ class PagedAttention:
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def pad_context_forward(
|
def pad_context_forward(
|
||||||
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
||||||
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
||||||
|
@ -239,11 +246,10 @@ class PagedAttention:
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
|
||||||
|
|
||||||
del attn_weights
|
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def pad_decoding_forward(
|
def pad_decoding_forward(
|
||||||
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
||||||
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
||||||
|
@ -297,11 +303,10 @@ class PagedAttention:
|
||||||
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
|
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
|
||||||
|
|
||||||
del attn_weights
|
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@torch.no_grad
|
||||||
def no_pad_decoding_forward(
|
def no_pad_decoding_forward(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
|
|
|
@ -2,19 +2,23 @@
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
||||||
LlamaAttention,
|
|
||||||
LlamaDecoderLayer,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaModel,
|
|
||||||
repeat_kv,
|
|
||||||
)
|
|
||||||
|
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||||
from colossalai.inference.struct import BatchInfo
|
from colossalai.inference.struct import BatchInfo
|
||||||
|
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
||||||
|
|
||||||
|
logger = get_dist_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
HAS_TRITON = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TRITON = False
|
||||||
|
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
"""Rotates half the hidden dims of the input."""
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
@ -35,6 +39,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def llama_causal_lm_forward(
|
def llama_causal_lm_forward(
|
||||||
self: LlamaForCausalLM,
|
self: LlamaForCausalLM,
|
||||||
batch: BatchInfo = None,
|
batch: BatchInfo = None,
|
||||||
|
@ -54,6 +59,7 @@ def llama_causal_lm_forward(
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def llama_model_forward(
|
def llama_model_forward(
|
||||||
self: LlamaModel,
|
self: LlamaModel,
|
||||||
batch: BatchInfo = None,
|
batch: BatchInfo = None,
|
||||||
|
@ -63,15 +69,30 @@ def llama_model_forward(
|
||||||
):
|
):
|
||||||
input_ids = batch.get_batch_inputs()
|
input_ids = batch.get_batch_inputs()
|
||||||
block_tables = batch.get_block_table_tensor()
|
block_tables = batch.get_block_table_tensor()
|
||||||
sequence_lengths = batch.get_sequence_lengths()
|
|
||||||
|
|
||||||
attention_mask = batch.get_attn_mask(padding_id)
|
attention_mask = batch.get_attn_mask(padding_id)
|
||||||
|
|
||||||
if batch.is_prompts:
|
if attention_mask is not None:
|
||||||
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
# TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
|
||||||
position_ids = generate_padding_position_id(attention_mask)
|
# sequence_lengths = batch.get_sequence_lengths()
|
||||||
|
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
else:
|
else:
|
||||||
position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
|
sequence_lengths = batch.get_sequence_lengths()
|
||||||
|
|
||||||
|
kv_seq_len = sequence_lengths.max().item()
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if batch.is_prompts:
|
||||||
|
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
||||||
|
position_ids = generate_padding_position_id(attention_mask)
|
||||||
|
else:
|
||||||
|
position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
|
||||||
|
else:
|
||||||
|
if batch.is_prompts:
|
||||||
|
position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
@ -85,13 +106,14 @@ def llama_model_forward(
|
||||||
is_prompts=batch.is_prompts,
|
is_prompts=batch.is_prompts,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
kv_seq_len=kv_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def llama_decoder_layer_forward(
|
def llama_decoder_layer_forward(
|
||||||
self: LlamaDecoderLayer,
|
self: LlamaDecoderLayer,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
@ -102,6 +124,7 @@ def llama_decoder_layer_forward(
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
sequence_lengths: int = None,
|
sequence_lengths: int = None,
|
||||||
attention_mask: torch.Tensor = None,
|
attention_mask: torch.Tensor = None,
|
||||||
|
kv_seq_len: int = 0,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
@ -116,6 +139,7 @@ def llama_decoder_layer_forward(
|
||||||
is_prompts=is_prompts,
|
is_prompts=is_prompts,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
kv_seq_len=kv_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
@ -130,6 +154,7 @@ def llama_decoder_layer_forward(
|
||||||
|
|
||||||
|
|
||||||
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
|
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
|
||||||
|
@torch.no_grad()
|
||||||
def llama_attn_forward(
|
def llama_attn_forward(
|
||||||
self: LlamaAttention,
|
self: LlamaAttention,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
@ -140,6 +165,7 @@ def llama_attn_forward(
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
sequence_lengths: torch.Tensor = None,
|
sequence_lengths: torch.Tensor = None,
|
||||||
attention_mask: torch.Tensor = None,
|
attention_mask: torch.Tensor = None,
|
||||||
|
kv_seq_len: int = 0,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -147,26 +173,44 @@ def llama_attn_forward(
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = sequence_lengths[0].item()
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
key_states = key_states.transpose(1, 2)
|
key_states = key_states.transpose(1, 2)
|
||||||
value_states = value_states.transpose(1, 2)
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
_, _, _, block_size = k_cache.shape
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
attn_output = PagedAttention.pad_context_forward(
|
if HAS_TRITON:
|
||||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
if attention_mask is not None:
|
||||||
)
|
query_states, key_states, value_states, indices = unpading_input(
|
||||||
|
query_states, key_states, value_states, attention_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
attn_output = context_attention_unpadded(
|
||||||
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
||||||
|
)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
||||||
|
else:
|
||||||
|
attn_output = PagedAttention.pad_context_forward(
|
||||||
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
attn_output = PagedAttention.pad_decoding_forward(
|
if HAS_TRITON:
|
||||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
)
|
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
|
attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||||
|
else:
|
||||||
|
attn_output = PagedAttention.pad_decoding_forward(
|
||||||
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
@ -175,7 +219,18 @@ def llama_attn_forward(
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
return position_ids
|
return position_ids
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
|
||||||
|
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
return (q, k, v, indices)
|
||||||
|
|
|
@ -332,12 +332,20 @@ class BatchInfo:
|
||||||
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
||||||
|
|
||||||
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
|
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Generate and return attention mask.
|
||||||
|
"""
|
||||||
past_values = []
|
past_values = []
|
||||||
|
|
||||||
for seq in self.sequences_set:
|
for seq in self.sequences_set:
|
||||||
past_values.append(seq.input_token_id + seq.output_token_id)
|
past_values.append(seq.input_token_id + seq.output_token_id)
|
||||||
|
|
||||||
return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
|
attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
|
||||||
|
|
||||||
|
if torch.any(attn_mask == 0):
|
||||||
|
return attn_mask
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers import AutoTokenizer, GenerationConfig
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
import colossalai.utils.device as device_utils
|
import colossalai.utils.device as device_utils
|
||||||
from colossalai.inference import InferenceEngine
|
from colossalai.inference.config import InferenceConfig
|
||||||
|
from colossalai.inference.core.engine import InferenceEngine
|
||||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils.device import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
@ -53,36 +56,14 @@ CONFIG_MAP = {
|
||||||
|
|
||||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||||
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
|
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
|
||||||
attention_mask = torch.ones_like(input_ids)
|
return input_ids
|
||||||
data = dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def print_details_info(outputs, model_config, args, whole_end2end):
|
def print_details_info(model_config, args, whole_end2end):
|
||||||
msg: str = ""
|
msg: str = ""
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
msg += "-------Perf Summary-------\n"
|
msg += "-------Perf Summary-------\n"
|
||||||
if args.verbose:
|
|
||||||
timestamps = outputs[1]
|
|
||||||
prefill = []
|
|
||||||
encoder = []
|
|
||||||
end2end = []
|
|
||||||
for timestamp in timestamps:
|
|
||||||
prefill.append(timestamp[1] - timestamp[0])
|
|
||||||
encoder.append(
|
|
||||||
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
|
|
||||||
)
|
|
||||||
end2end.append(timestamp[-1] - timestamp[0])
|
|
||||||
|
|
||||||
mb_avg_end2end = sum(end2end) / len(end2end)
|
|
||||||
mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
|
|
||||||
|
|
||||||
msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
|
|
||||||
msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
|
|
||||||
msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
|
|
||||||
msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"
|
|
||||||
|
|
||||||
whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
|
whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
|
||||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
||||||
|
@ -105,35 +86,87 @@ def print_details_info(outputs, model_config, args, whole_end2end):
|
||||||
|
|
||||||
|
|
||||||
def benchmark_inference(args):
|
def benchmark_inference(args):
|
||||||
config = CONFIG_MAP[args.model]
|
with torch.no_grad():
|
||||||
model = transformers.LlamaForCausalLM(config)
|
config = CONFIG_MAP[args.model]
|
||||||
if dist.get_rank() == 0:
|
config.pad_token_id = config.eos_token_id
|
||||||
print("Model loaded")
|
model = transformers.LlamaForCausalLM(config).cuda()
|
||||||
engine = InferenceEngine(
|
model = model.eval()
|
||||||
pp_size=args.pp_size,
|
tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
|
||||||
tp_size=args.tp_size,
|
|
||||||
dtype=args.dtype,
|
|
||||||
micro_batch_size=args.mb_size,
|
|
||||||
model=model,
|
|
||||||
verbose=args.verbose,
|
|
||||||
max_batch_size=args.batch_size,
|
|
||||||
max_input_len=args.seq_len,
|
|
||||||
max_output_len=args.output_len,
|
|
||||||
)
|
|
||||||
data = data_gen(args.batch_size, args.seq_len)
|
|
||||||
|
|
||||||
N_WARMUP_STEPS = 2
|
if args.dtype == "fp16":
|
||||||
|
model = model.half()
|
||||||
|
elif args.dtype == "bf16":
|
||||||
|
model = model.to(torch.bfloat16)
|
||||||
|
|
||||||
for _ in range(N_WARMUP_STEPS):
|
# mbsz = args.mbsz
|
||||||
engine.generate(data)
|
mbsz = args.batch_size
|
||||||
|
if args.mode == "caiinference":
|
||||||
|
inference_config = InferenceConfig(
|
||||||
|
dtype=args.dtype,
|
||||||
|
micro_batch_size=args.mb_size,
|
||||||
|
max_batch_size=mbsz,
|
||||||
|
max_input_len=args.seq_len,
|
||||||
|
max_output_len=args.output_len,
|
||||||
|
prefill_ratio=1.2,
|
||||||
|
)
|
||||||
|
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||||
|
else:
|
||||||
|
engine = model
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
data = data_gen(mbsz, args.seq_len)
|
||||||
whole_end2end = time.time()
|
generation_config = GenerationConfig(
|
||||||
outputs = engine.generate(data)
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
torch.cuda.synchronize()
|
max_new_tokens=args.output_len,
|
||||||
whole_end2end = time.time() - whole_end2end
|
)
|
||||||
|
|
||||||
print_details_info(outputs, model.config, args, whole_end2end)
|
N_WARMUP_STEPS = 2
|
||||||
|
|
||||||
|
ctx = (
|
||||||
|
torch.profiler.profile(
|
||||||
|
record_shapes=True,
|
||||||
|
with_stack=True,
|
||||||
|
with_modules=True,
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_" + args.mode),
|
||||||
|
)
|
||||||
|
if args.profile
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
|
||||||
|
with ctx:
|
||||||
|
for _ in range(N_WARMUP_STEPS):
|
||||||
|
if args.mode == "caiinference":
|
||||||
|
engine.add_request(prompts_token_ids=data)
|
||||||
|
engine.generate(generation_config)
|
||||||
|
else:
|
||||||
|
engine.generate(data, generation_config=generation_config)
|
||||||
|
if args.profile:
|
||||||
|
ctx.step()
|
||||||
|
|
||||||
|
if args.nsys:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
whole_end2end = time.perf_counter()
|
||||||
|
if args.mode == "caiinference":
|
||||||
|
for _ in range(args.batch_size // mbsz):
|
||||||
|
engine.add_request(prompts_token_ids=data)
|
||||||
|
engine.generate(generation_config)
|
||||||
|
else:
|
||||||
|
for _ in range(args.batch_size // mbsz):
|
||||||
|
engine.generate(data, generation_config=generation_config)
|
||||||
|
whole_end2end = time.perf_counter() - whole_end2end
|
||||||
|
if args.nsys:
|
||||||
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
|
if args.profile:
|
||||||
|
ctx.step()
|
||||||
|
|
||||||
|
print_details_info(model.config, args, whole_end2end)
|
||||||
|
|
||||||
|
|
||||||
def hybrid_inference(rank, world_size, port, args):
|
def hybrid_inference(rank, world_size, port, args):
|
||||||
|
@ -157,12 +190,21 @@ if __name__ == "__main__":
|
||||||
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
|
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
|
||||||
)
|
)
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
||||||
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
|
parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
|
||||||
|
parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
|
||||||
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
|
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
|
||||||
parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
|
parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
|
||||||
parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
|
parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
|
||||||
parser.add_argument("--output_len", type=int, default=128, help="Output length")
|
parser.add_argument("--output_len", type=int, default=128, help="Output length")
|
||||||
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
|
parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
|
||||||
parser.add_argument("-v", "--verbose", default=False, action="store_true")
|
parser.add_argument("-v", "--verbose", default=False, action="store_true")
|
||||||
|
parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
|
||||||
|
parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
default="caiinference",
|
||||||
|
choices=["caiinference", "transformers"],
|
||||||
|
help="decide which inference framework to run",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
benchmark(args)
|
benchmark(args)
|
||||||
|
|
|
@ -1,15 +1,33 @@
|
||||||
ROOT=$(realpath $(dirname $0))
|
ROOT=$(realpath $(dirname $0))
|
||||||
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
||||||
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
||||||
|
mode=$1
|
||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
|
||||||
|
local n=${1:-"9999"}
|
||||||
|
echo "GPU Memory Usage:"
|
||||||
|
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||||
|
| tail -n +2 \
|
||||||
|
| nl -v 0 \
|
||||||
|
| tee /dev/tty \
|
||||||
|
| sort -g -k 2 \
|
||||||
|
| awk '{print $1}' \
|
||||||
|
| head -n $n)
|
||||||
|
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||||
|
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||||
|
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
||||||
|
|
||||||
# benchmark llama2-7b one single GPU
|
# benchmark llama2-7b one single GPU
|
||||||
for bsz in 16 32 64; do
|
for bsz in 16 32 64; do
|
||||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|
||||||
for bsz in 4 8 16 32 64; do
|
for bsz in 16 32 64; do
|
||||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
|
||||||
done
|
done
|
||||||
|
|
Loading…
Reference in New Issue