[fix] multi graphs capture error

pull/5434/head
Runyu Lu 2024-03-11 10:49:31 +08:00
parent cefaeb5fdd
commit b2c0d9ff2b
4 changed files with 27 additions and 30 deletions


@@ -79,7 +79,7 @@ class InferenceConfig:
         micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will run with CUDA graph where a captured graph is available and fall back to eager execution otherwise (hybrid execution).
-        max_context_len_to_capture (int)
+        max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
     """


@@ -29,6 +29,8 @@ _supported_models = [
     "LlamaForCausalLM",
 ]
 
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+
 class InferenceEngine:
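
Hoisting the capture list to module level makes it visible outside capture_model as well. A list like [1, 2, 4, 8, 16, ..., 256] is typically paired with a round-up step at run time so a request batch can reuse the nearest captured graph; a sketch of that pairing (the helper name is hypothetical, not part of this commit):

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

def pad_to_captured_batch_size(batch_size: int) -> int:
    # Hypothetical helper: round a runtime batch size up to the nearest
    # captured size; the padded slots are computed and then discarded.
    for bs in _BATCH_SIZES_TO_CAPTURE:
        if bs >= batch_size:
            return bs
    raise ValueError(f"batch size {batch_size} exceeds the largest captured size")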
@@ -108,54 +110,49 @@ class InferenceEngine:
         t_capture_begin = time.perf_counter()
 
-        _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
-
         block_size = self.inference_config.block_size
         head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         max_context_len_to_capture = self.inference_config.max_context_len_to_capture
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
-        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
+        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         output_tensor = torch.zeros(
             (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
         )
         fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
 
         max_num_seqs = self.inference_config.max_batch_size
         batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
         sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
 
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(batch_size_capture_list[-1:]):
-            batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
-            batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
+        for batch_size in reversed(batch_size_capture_list):
             if self.verbose:
                 self.logger.info(f"batch size {batch_size} graph capturing")
 
-            # generate dummy input
-            for i in range(batch_size):
-                sequence = Sequence(
-                    i,
-                    None,
-                    input_tokens[i],
-                    block_size,
-                    None,
-                    self.tokenizer.eos_token_id,
-                    self.tokenizer.pad_token_id,
-                    self.inference_config.max_output_len,
-                )
-                sequence.output_token_id = [0]  # only capture the graph of decoding
-                batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])
-            input_data = self.prepare_input(batch_bucket_for_capture)
-            input_tokens_ids, output_tensor, inputmetadata = input_data
+            input_meta_data = InputMetaData(
+                block_tables=block_tables[:batch_size],
+                sequence_lengths=sequence_lengths[:batch_size],
+                fd_inter_tensor=fd_inter_tensor,
+                batch_size=batch_size,
+                is_prompts=False,
+                use_cuda_graph=True,
+                kv_seq_len=sequence_lengths[:batch_size].max().item(),
+                head_dim=head_dim,
+            )
 
             graph_runner = CUDAGraphRunner(self.model)
             graph_runner.capture(
-                input_tokens_ids,
-                output_tensor,
-                inputmetadata,
+                input_tokens_ids[:batch_size],
+                output_tensor[:batch_size],
+                input_meta_data,
                 k_caches=k_cache,
                 v_caches=v_cache,
                 memory_pool=self.graph_memory_pool,
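
The rewritten loop walks the capture list in reverse so the largest batch is captured first and its memory pool (self.graph_memory_pool) is shared with the smaller captures. Stripped of engine details, the pattern is roughly as follows (a sketch; model, static_inputs, and batch_size_capture_list are stand-ins, and the warm-up runs a real capture needs are omitted):

import torch

pool = None  # CUDA graph memory pool, shared across captures
graph_runners = {}
for bs in reversed(batch_size_capture_list):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        static_output = model(static_inputs[:bs])  # recorded, not run eagerly
    pool = g.pool()  # smaller batches reuse the pool of the largest capture
    graph_runners[bs] = g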
@@ -412,8 +409,10 @@ class InferenceEngine:
         if input_meta_data.use_cuda_graph:
             model_executable = self.graph_runners[input_meta_data.batch_size]
+            # self.logger.info("run cuda graph")
         else:
             model_executable = self.model
+            # self.logger.info("run original model")
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
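
At execution time the engine thus picks a captured runner keyed by batch size. A slightly more defensive version of the same dispatch might read (a sketch assuming self.graph_runners is a plain dict; not the committed code):

# Fall back to eager execution when no graph was captured for this batch size.
runner = self.graph_runners.get(input_meta_data.batch_size)
model_executable = runner if runner is not None else self.model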


@@ -42,7 +42,6 @@ class CUDAGraphRunner:
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):
             hidden_states = self.model(
-                # batch,
                 input_tokens_ids,
                 output_tensor,
                 inputmetadata,
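
For context, a runner of this kind records the model once into static buffers and replays the recording afterwards. A condensed, hypothetical stand-in for CUDAGraphRunner (not the class's actual implementation):

import torch

class GraphRunnerSketch:
    def __init__(self, model):
        self.model = model
        self.graph = None
        self.static_inputs = None
        self.static_output = None

    def capture(self, *inputs, memory_pool=None):
        # Record one forward pass; tensor arguments become static buffers.
        self.static_inputs = inputs
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            self.static_output = self.model(*inputs)

    def __call__(self, *inputs):
        # Copy live data into the captured buffers, then replay the graph.
        # Non-tensor arguments (e.g. metadata) must match the captured ones.
        for static, live in zip(self.static_inputs, inputs):
            if isinstance(static, torch.Tensor):
                static.copy_(live)
        self.graph.replay()
        return self.static_output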


@@ -92,7 +92,6 @@ if HAS_TRITON:
     def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
-        # y = torch.empty_like(x) if norm_output is None else norm_output
         y = (
             x * 0 if norm_output is None else norm_output
         )  # to make the operation non-functional, store y as the intermediate activation
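
Dropping the commented-out torch.empty_like line leaves x * 0 as the in-kernel allocation path; callers that care about allocation stability (for example when the op runs under CUDA graph capture) can pass a pre-allocated norm_output instead. A usage sketch (shapes, eps, and the assumption that the function returns the normalized tensor are mine):

import torch

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

# Reuse one output buffer across calls instead of allocating per call.
norm_out = torch.empty_like(x)
y = rms_layernorm(x, weight, eps=1e-5, norm_output=norm_out)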