[Inference] Adapted to the Triton attn kernels (#5264)

* adapted to the triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache

* fix ci test

* update kv memcpy

* remove print
pull/5270/head
yuehuayingxueluo 2024-01-17 16:03:10 +08:00 committed by GitHub
parent 0f2b46a41c
commit 86b63f720c
7 changed files with 221 additions and 101 deletions
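Taken together, the diffs below make the Llama attention path dispatch to the Triton kernels when they are available and fall back to the existing PagedAttention implementation otherwise. A condensed sketch of that dispatch follows; the wrapper function and the `fallback` parameter are illustrative, while the kernel names and argument order are taken from the diffs themselves:

# Condensed sketch of the prefill/decode dispatch introduced by this commit.
# `fallback` stands in for the PagedAttention class used in the diffs; it is a
# placeholder, not the committed code.
try:
    from colossalai.kernel.triton import (
        context_attention_unpadded,
        copy_kv_to_blocked_cache,
        flash_decoding_fwd,
    )
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


def attention_dispatch(q, k, v, k_cache, v_cache, seq_lengths, block_tables, is_prompts, fallback):
    # Block size is read from the cache layout, exactly as in the diff.
    _, _, _, block_size = k_cache.shape
    if is_prompts:  # prefill: one unpadded context-attention kernel call
        if HAS_TRITON:
            return context_attention_unpadded(q, k, v, k_cache, v_cache, seq_lengths, block_tables, block_size)
        return fallback.pad_context_forward(q, k, v, k_cache, v_cache, seq_lengths, block_tables, None)
    # decode: copy the new K/V token into the blocked cache, then flash-decode
    if HAS_TRITON:
        copy_kv_to_blocked_cache(k, k_cache, kv_lengths=seq_lengths, block_tables=block_tables)
        copy_kv_to_blocked_cache(v, v_cache, kv_lengths=seq_lengths, block_tables=block_tables)
        return flash_decoding_fwd(q, k_cache, v_cache, seq_lengths, block_tables, block_size)
    return fallback.pad_decoding_forward(q, k, v, k_cache, v_cache, seq_lengths, block_tables, None)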

View File

@@ -236,6 +236,7 @@ class InferenceEngine:
         output_list = []
         batch = self.request_handler.schedule()

+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = self.model(
             batch,
             self.k_cahce,

View File

@@ -57,9 +57,6 @@ class RunningList:
     def is_empty(self):
         return not self.decoding and not self.prefill

-    def total_seq_num(self):
-        return len(self.decoding) + len(self.prefill)
-
 class RequestHandler:
     """
@@ -81,6 +78,7 @@ class RequestHandler:
         device = torch.cuda.current_device()
         self.running_batch = BatchInfo(is_prompts=False, device=device)
         self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.max_batch_size = inference_config.max_batch_size

     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -108,20 +106,18 @@ class RequestHandler:
                            )
                            self.abort_sequence(seq.request_id)
                            break
-                        # stop feeding new sequence into running list to assure
-                        if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
-                            break
                         # Try to allocate cache blocks for the sequence.
-                        if self.cache_manager.check_allocation(seq):
+                        if (
+                            self.cache_manager.check_allocation(seq)
+                            and (len(self.running_list.prefill) + len(self.running_list.decoding))
+                            < self.max_batch_size  # There some bugs in continous batching, so we disable it here.
+                        ):
                             # If succeed, add the sequence to running list.
                             remove_list.append(seq)
                             self.running_list.append(seq)
                             self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
                    for seq in remove_list:
                        lst.remove(seq)

        if self.running_list.ready_for_prefill():
            for seq in self.running_list.prefill:
                seq.mark_running()
@@ -130,12 +126,7 @@
        if not self.running_batch.is_empty:
            for seq in self.running_batch.sequences_set:
-                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
-                if recycle:
-                    seq.recycle()
-                    self.running_batch.remove(seq)
-                    self.waiting_list[-1].append(seq)
-                    # the recycled sequences are handled with highest priority.
+                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)

        return self.running_batch
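The scheduling change above admits a waiting sequence only while the running list stays under the configured max_batch_size, in addition to the existing KV-cache allocation check. A minimal standalone sketch of that admission rule; can_allocate stands in for cache_manager.check_allocation, and nothing here is the committed class:

from typing import List


def may_admit(seq_id: int, prefill: List[int], decoding: List[int], can_allocate: bool, max_batch_size: int) -> bool:
    # Admit only if KV blocks fit AND the running list is not already full.
    if can_allocate and (len(prefill) + len(decoding)) < max_batch_size:
        prefill.append(seq_id)
        return True
    return False


prefill, decoding = [], []
assert may_admit(0, prefill, decoding, can_allocate=True, max_batch_size=1)
assert not may_admit(1, prefill, decoding, can_allocate=True, max_batch_size=1)  # batch already full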

View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter

+@torch.no_grad
 def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     """
     Func: copy key/value into key/value cache.
@@ -40,6 +41,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
     return cache

+@torch.no_grad
 def convert_kvcache(cache, lengths, block_tables, pad_id=0):
     """
     Func: convert key/value cache for calculation
@@ -79,6 +81,7 @@ class PagedAttention:
     """

     @staticmethod
+    @torch.no_grad
     def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
         """
         Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
@@ -94,12 +97,14 @@ class PagedAttention:
         return padded_tensor

     @staticmethod
+    @torch.no_grad
     def generate_padding_mask(lengths, max_seq_len):
         range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
         padding_mask = range_tensor < lengths.unsqueeze(1)
         return padding_mask

     @staticmethod
+    @torch.no_grad
     def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
         """
         Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -117,6 +122,7 @@ class PagedAttention:
         return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)

     @staticmethod
+    @torch.no_grad
     def nopad_context_forward(
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
         k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
@@ -185,6 +191,7 @@ class PagedAttention:
         return attn_output

     @staticmethod
+    @torch.no_grad
     def pad_context_forward(
         q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
         k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
@@ -239,11 +246,10 @@ class PagedAttention:
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
-        del attn_weights
         return attn_output

     @staticmethod
+    @torch.no_grad
     def pad_decoding_forward(
         q: torch.Tensor,  # [bsz, 1, num_heads, head_size]
         k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_size]
@@ -297,11 +303,10 @@ class PagedAttention:
             raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
-        del attn_weights
         return attn_output

     @staticmethod
+    @torch.no_grad
     def no_pad_decoding_forward(
         self,
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
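The decorators added above rely on torch.no_grad being usable as a function decorator, which disables autograd tracking inside the wrapped inference helper. A small self-contained example (the helper below is illustrative, not part of the diff; in recent PyTorch releases both @torch.no_grad and @torch.no_grad() work as decorators):

import torch


@torch.no_grad()
def greedy_logits_to_token(logits: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded here, even if `logits` requires grad.
    return logits.argmax(dim=-1)


x = torch.randn(2, 8, requires_grad=True)
token = greedy_logits_to_token(x)
print(token.requires_grad)  # False: the helper ran under no_grad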

View File

@@ -2,19 +2,23 @@
 from typing import List, Optional, Tuple

 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
-    repeat_kv,
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
+from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
+from colossalai.logging import get_dist_logger
 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

+logger = get_dist_logger(__name__)
+
+try:
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -35,6 +39,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed

+@torch.no_grad()
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
     batch: BatchInfo = None,
@@ -54,6 +59,7 @@ def llama_causal_lm_forward(
     return logits

+@torch.no_grad()
 def llama_model_forward(
     self: LlamaModel,
     batch: BatchInfo = None,
@@ -63,15 +69,30 @@ def llama_model_forward(
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    sequence_lengths = batch.get_sequence_lengths()
     attention_mask = batch.get_attn_mask(padding_id)

-    if batch.is_prompts:
-        # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
-        position_ids = generate_padding_position_id(attention_mask)
+    if attention_mask is not None:
+        # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths.
+        # sequence_lengths = batch.get_sequence_lengths()
+        sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
     else:
-        position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
+        sequence_lengths = batch.get_sequence_lengths()
+
+    kv_seq_len = sequence_lengths.max().item()
+
+    if attention_mask is not None:
+        if batch.is_prompts:
+            # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
+            position_ids = generate_padding_position_id(attention_mask)
+        else:
+            position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
+    else:
+        if batch.is_prompts:
+            position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device)
+            position_ids = position_ids.unsqueeze(0)
+        else:
+            position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device)
+            position_ids = position_ids.unsqueeze(0)

     hidden_states = self.embed_tokens(input_ids)
@@ -85,13 +106,14 @@ def llama_model_forward(
             is_prompts=batch.is_prompts,
             sequence_lengths=sequence_lengths,
             attention_mask=attention_mask,
+            kv_seq_len=kv_seq_len,
         )

     hidden_states = self.norm(hidden_states)

     return hidden_states

+@torch.no_grad()
 def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
@@ -102,6 +124,7 @@ def llama_decoder_layer_forward(
     is_prompts: bool = True,
     sequence_lengths: int = None,
     attention_mask: torch.Tensor = None,
+    kv_seq_len: int = 0,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
@@ -116,6 +139,7 @@ def llama_decoder_layer_forward(
         is_prompts=is_prompts,
         sequence_lengths=sequence_lengths,
         attention_mask=attention_mask,
+        kv_seq_len=kv_seq_len,
     )

     hidden_states = residual + hidden_states
@@ -130,6 +154,7 @@ def llama_decoder_layer_forward(

 # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
+@torch.no_grad()
 def llama_attn_forward(
     self: LlamaAttention,
     hidden_states: torch.Tensor,
@@ -140,6 +165,7 @@ def llama_attn_forward(
     is_prompts: bool = True,
     sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
+    kv_seq_len: int = 0,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
@@ -147,26 +173,44 @@ def llama_attn_forward(
     key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-    kv_seq_len = sequence_lengths[0].item()
-
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
     value_states = value_states.transpose(1, 2)

+    _, _, _, block_size = k_cache.shape
+
     if is_prompts:
-        attn_output = PagedAttention.pad_context_forward(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-        )
+        if HAS_TRITON:
+            if attention_mask is not None:
+                query_states, key_states, value_states, indices = unpading_input(
+                    query_states, key_states, value_states, attention_mask
+                )
+            else:
+                query_states = query_states.view(-1, self.num_heads, self.head_dim)
+                key_states = key_states.view(-1, self.num_heads, self.head_dim)
+                value_states = value_states.view(-1, self.num_heads, self.head_dim)
+            attn_output = context_attention_unpadded(
+                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            )
+            if attention_mask is not None:
+                attn_output = pad_input(attn_output, indices, bsz, q_len)
+        else:
+            attn_output = PagedAttention.pad_context_forward(
+                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+            )
     else:
-        attn_output = PagedAttention.pad_decoding_forward(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-        )
+        if HAS_TRITON:
+            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+        else:
+            attn_output = PagedAttention.pad_decoding_forward(
+                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+            )

     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -175,7 +219,18 @@ def llama_attn_forward(
     return attn_output

+@torch.no_grad()
 def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids

+
+@torch.no_grad()
+def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
+    q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    return (q, k, v, indices)
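The new unpading_input helper above strips padded positions before the unpadded context-attention kernel runs, and pad_input scatters the outputs back afterwards. A plain-torch sketch of that roundtrip with illustrative shapes; the committed code uses flash_attn.bert_padding's index_first_axis / pad_input rather than these toy helpers:

import torch


def unpad(x: torch.Tensor, attention_mask: torch.Tensor):
    # x: [bsz, seq_len, heads, head_dim]; mask: [bsz, seq_len] with 1 = real token
    bsz, seq_len, heads, head_dim = x.shape
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    return x.reshape(bsz * seq_len, heads, head_dim).index_select(0, indices), indices


def repad(x_unpadded: torch.Tensor, indices: torch.Tensor, bsz: int, seq_len: int):
    # Scatter per-token outputs back into a zero-padded [bsz, seq_len, ...] tensor.
    out = x_unpadded.new_zeros(bsz * seq_len, *x_unpadded.shape[1:])
    out[indices] = x_unpadded
    return out.reshape(bsz, seq_len, *x_unpadded.shape[1:])


mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
q = torch.randn(2, 4, 8, 16)
q_unpad, idx = unpad(q, mask)       # [5, 8, 16]: only the real tokens
q_back = repad(q_unpad, idx, 2, 4)  # padded positions come back as zeros
assert torch.equal(q_back[mask.bool()], q[mask.bool()])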

View File

@@ -332,12 +332,20 @@ class BatchInfo:
         return torch.tensor(len_list, dtype=torch.int, device=self.device)

     def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+        """
+        Generate and return attention mask.
+        """
         past_values = []
         for seq in self.sequences_set:
             past_values.append(seq.input_token_id + seq.output_token_id)

-        return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
+        attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
+
+        if torch.any(attn_mask == 0):
+            return attn_mask
+        else:
+            return None

     def __repr__(self) -> str:
         return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"

View File

@@ -1,13 +1,16 @@
 import argparse
 import time
+from contextlib import nullcontext

 import torch
 import torch.distributed as dist
 import transformers
+from transformers import AutoTokenizer, GenerationConfig

 import colossalai
 import colossalai.utils.device as device_utils
-from colossalai.inference import InferenceEngine
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.utils.device import get_current_device
@@ -53,36 +56,14 @@ CONFIG_MAP = {
 def data_gen(batch_size: int = 4, seq_len: int = 512):
     input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
-    attention_mask = torch.ones_like(input_ids)
-    data = dict(input_ids=input_ids, attention_mask=attention_mask)
-    return data
+    return input_ids

-def print_details_info(outputs, model_config, args, whole_end2end):
+def print_details_info(model_config, args, whole_end2end):
     msg: str = ""

     if dist.get_rank() == 0:
         msg += "-------Perf Summary-------\n"
-        if args.verbose:
-            timestamps = outputs[1]
-            prefill = []
-            encoder = []
-            end2end = []
-            for timestamp in timestamps:
-                prefill.append(timestamp[1] - timestamp[0])
-                encoder.append(
-                    sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
-                )
-                end2end.append(timestamp[-1] - timestamp[0])
-            mb_avg_end2end = sum(end2end) / len(end2end)
-            mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
-            msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
-            msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
-            msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
-            msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"

         whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
         num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
         num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
@@ -105,35 +86,87 @@ def print_details_info(outputs, model_config, args, whole_end2end):
 def benchmark_inference(args):
-    config = CONFIG_MAP[args.model]
-    model = transformers.LlamaForCausalLM(config)
-    if dist.get_rank() == 0:
-        print("Model loaded")
-    engine = InferenceEngine(
-        pp_size=args.pp_size,
-        tp_size=args.tp_size,
-        dtype=args.dtype,
-        micro_batch_size=args.mb_size,
-        model=model,
-        verbose=args.verbose,
-        max_batch_size=args.batch_size,
-        max_input_len=args.seq_len,
-        max_output_len=args.output_len,
-    )
-    data = data_gen(args.batch_size, args.seq_len)
-
-    N_WARMUP_STEPS = 2
-
-    for _ in range(N_WARMUP_STEPS):
-        engine.generate(data)
-
-    torch.cuda.synchronize()
-    whole_end2end = time.time()
-    outputs = engine.generate(data)
-    torch.cuda.synchronize()
-    whole_end2end = time.time() - whole_end2end
-
-    print_details_info(outputs, model.config, args, whole_end2end)
+    with torch.no_grad():
+        config = CONFIG_MAP[args.model]
+        config.pad_token_id = config.eos_token_id
+        model = transformers.LlamaForCausalLM(config).cuda()
+        model = model.eval()
+        tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
+
+        if args.dtype == "fp16":
+            model = model.half()
+        elif args.dtype == "bf16":
+            model = model.to(torch.bfloat16)
+
+        # mbsz = args.mbsz
+        mbsz = args.batch_size
+        if args.mode == "caiinference":
+            inference_config = InferenceConfig(
+                dtype=args.dtype,
+                micro_batch_size=args.mb_size,
+                max_batch_size=mbsz,
+                max_input_len=args.seq_len,
+                max_output_len=args.output_len,
+                prefill_ratio=1.2,
+            )
+            engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        else:
+            engine = model
+
+        data = data_gen(mbsz, args.seq_len)
+        generation_config = GenerationConfig(
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=args.output_len,
+        )
+
+        N_WARMUP_STEPS = 2
+
+        ctx = (
+            torch.profiler.profile(
+                record_shapes=True,
+                with_stack=True,
+                with_modules=True,
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_" + args.mode),
+            )
+            if args.profile
+            else nullcontext()
+        )
+
+        with ctx:
+            for _ in range(N_WARMUP_STEPS):
+                if args.mode == "caiinference":
+                    engine.add_request(prompts_token_ids=data)
+                    engine.generate(generation_config)
+                else:
+                    engine.generate(data, generation_config=generation_config)
+                if args.profile:
+                    ctx.step()
+
+            if args.nsys:
+                torch.cuda.cudart().cudaProfilerStart()
+            torch.cuda.synchronize()
+            whole_end2end = time.perf_counter()
+            if args.mode == "caiinference":
+                for _ in range(args.batch_size // mbsz):
+                    engine.add_request(prompts_token_ids=data)
+                    engine.generate(generation_config)
+            else:
+                for _ in range(args.batch_size // mbsz):
+                    engine.generate(data, generation_config=generation_config)
+            whole_end2end = time.perf_counter() - whole_end2end
+            if args.nsys:
+                torch.cuda.cudart().cudaProfilerStop()
+            if args.profile:
+                ctx.step()
+
+        print_details_info(model.config, args, whole_end2end)

 def hybrid_inference(rank, world_size, port, args):
@@ -157,12 +190,21 @@ if __name__ == "__main__":
         choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
     )
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
-    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+    parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
+    parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
     parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
     parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
     parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
     parser.add_argument("--output_len", type=int, default=128, help="Output length")
-    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
+    parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
     parser.add_argument("-v", "--verbose", default=False, action="store_true")
+    parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
+    parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
+    parser.add_argument(
+        "--mode",
+        default="caiinference",
+        choices=["caiinference", "transformers"],
+        help="decide which inference framework to run",
+    )
     args = parser.parse_args()
     benchmark(args)
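The benchmark now toggles profiling by building either a torch.profiler.profile context or a contextlib.nullcontext, selected by --profile. A cut-down sketch of that pattern with a placeholder workload; the trace directory and step counts are illustrative:

from contextlib import nullcontext

import torch

N_WARMUP_STEPS = 2
profile_enabled = True  # stands in for args.profile

ctx = (
    torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_example"),
    )
    if profile_enabled
    else nullcontext()
)

with ctx:
    for _ in range(N_WARMUP_STEPS + 1):
        torch.randn(256, 256) @ torch.randn(256, 256)  # placeholder workload
        if profile_enabled:
            ctx.step()  # advances the schedule: warmup steps first, then one active (traced) step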

View File

@@ -1,15 +1,33 @@
 ROOT=$(realpath $(dirname $0))
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
+mode=$1

 mkdir -p logs

+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+        | tail -n +2 \
+        | nl -v 0 \
+        | tee /dev/tty \
+        | sort -g -k 2 \
+        | awk '{print $1}' \
+        | head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
+
 # benchmark llama2-7b one single GPU
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
 done

-for bsz in 4 8 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
 done
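With this change the script expects the inference backend as its first positional argument, for example "bash run_benchmark.sh caiinference" or "bash run_benchmark.sh transformers"; the chosen mode is forwarded to benchmark_llama.py via --mode and used to prefix the log file names.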