mirror of https://github.com/hpcaitech/ColossalAI
[fix] multi graphs capture error
parent cefaeb5fdd
commit b2c0d9ff2b

@@ -79,7 +79,7 @@ class InferenceConfig:
        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, CUDA graph execution is used where possible, falling back to eager mode otherwise (hybrid execution).
        max_context_len_to_capture (int)
        max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
    """
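Note: `max_context_len_to_capture` bounds how many KV-cache blocks a captured decoding sequence may reference. A minimal sketch of that relationship, mirroring the ceil-division used later in this commit (the helper name is illustrative, not part of the library):

    # Sketch only: mirrors `max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size`
    # from the capture code below. `blocks_per_seq` is a hypothetical helper name.
    def blocks_per_seq(max_context_len_to_capture: int, block_size: int) -> int:
        # ceil(max_context_len_to_capture / block_size): the widest block table a
        # captured sequence can need
        return (max_context_len_to_capture + block_size - 1) // block_size

    assert blocks_per_seq(512, 16) == 32
    assert blocks_per_seq(513, 16) == 33  # one extra, partially filled block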
@@ -29,6 +29,8 @@ _supported_models = [
    "LlamaForCausalLM",
]

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class InferenceEngine:
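Note: `_BATCH_SIZES_TO_CAPTURE` enumerates the decoding batch sizes for which a graph may be captured (1, 2, 4, then multiples of 8 up to 256). A sketch of how a runtime batch size could be bucketed to the next captured size; whether and where the engine pads batches is not shown in this diff, and the helper below is hypothetical:

    # Sketch only: bucket a live batch size to the next captured size.
    # `pick_graph_batch_size` is a hypothetical helper, not engine API.
    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]  # 1, 2, 4, 8, 16, ..., 256

    def pick_graph_batch_size(batch_size: int) -> int:
        # smallest captured size that can hold the current batch
        for bs in _BATCH_SIZES_TO_CAPTURE:
            if bs >= batch_size:
                return bs
        raise ValueError(f"batch size {batch_size} exceeds the largest captured size")

    assert pick_graph_batch_size(3) == 4
    assert pick_graph_batch_size(9) == 16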
@@ -108,54 +110,49 @@ class InferenceEngine:

        t_capture_begin = time.perf_counter()

        _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

        block_size = self.inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)

        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        output_tensor = torch.zeros(
            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
        )
        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor

        max_num_seqs = self.inference_config.max_batch_size
        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list[-1:]):
            batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
            batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
        for batch_size in reversed(batch_size_capture_list):

            if self.verbose:
                self.logger.info(f"batch size {batch_size} graph capturing")

            # generate dummy input
            for i in range(batch_size):
                sequence = Sequence(
                    i,
                    None,
                    input_tokens[i],
                    block_size,
                    None,
                    self.tokenizer.eos_token_id,
                    self.tokenizer.pad_token_id,
                    self.inference_config.max_output_len,
                )
                sequence.output_token_id = [0]  # only capture the graph of decoding
                batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])

            input_data = self.prepare_input(batch_bucket_for_capture)

            input_tokens_ids, output_tensor, inputmetadata = input_data
            input_meta_data = InputMetaData(
                block_tables=block_tables[:batch_size],
                sequence_lengths=sequence_lengths[:batch_size],
                fd_inter_tensor=fd_inter_tensor,
                batch_size=batch_size,
                is_prompts=False,
                use_cuda_graph=True,
                kv_seq_len=sequence_lengths[:batch_size].max().item(),
                head_dim=head_dim,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens_ids,
                output_tensor,
                inputmetadata,
                input_tokens_ids[:batch_size],
                output_tensor[:batch_size],
                input_meta_data,
                k_caches=k_cache,
                v_caches=v_cache,
                memory_pool=self.graph_memory_pool,
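Note: the capture loop above follows the common per-batch-size CUDA graph recipe: dummy buffers sized for the largest batch, sliced per batch size, the largest size captured first, and one memory pool shared across graphs. A standalone minimal sketch of that recipe with a toy model (model and shapes are assumptions, not the engine's code):

    # Standalone sketch (not the engine code): one graph per batch size, static
    # buffers sized for the largest batch, a shared memory pool, largest size first.
    import torch

    model = torch.nn.Linear(128, 128).cuda().eval()
    batch_sizes = [1, 2, 4, 8]
    max_bs = max(batch_sizes)

    # Slices of these buffers give every captured graph fixed memory addresses.
    static_input = torch.zeros(max_bs, 128, device="cuda")
    static_output = torch.zeros(max_bs, 128, device="cuda")

    graphs = {}
    memory_pool = None

    with torch.inference_mode():
        # Warm-up on a side stream is required before CUDA graph capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # Capture the largest batch size first; later graphs reuse the same pool.
        for bs in reversed(batch_sizes):
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g, pool=memory_pool):
                static_output[:bs] = model(static_input[:bs])
            memory_pool = g.pool()
            graphs[bs] = g

    # Replay a decode step with batch size 4: copy live data into the static slice,
    # replay the graph, then read the matching output slice.
    static_input[:4].copy_(torch.randn(4, 128, device="cuda"))
    graphs[4].replay()
    result = static_output[:4].clone()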
@@ -412,8 +409,10 @@ class InferenceEngine:
        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
            # self.logger.info("run cuda graph")
        else:
            model_executable = self.model
            # self.logger.info("run original model")

        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
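Note: the hunk above dispatches between a captured graph runner and the eager model by batch size. A small sketch of that dispatch; the dictionary-membership fallback is an extra assumption added for safety and is not shown in the diff:

    # Sketch only: the dispatch above, written as a small helper. `graph_runners`
    # maps captured batch sizes to runner objects; falling back to the eager model
    # when a batch size was never captured is an assumption, not repository code.
    def select_executable(graph_runners: dict, model, input_meta_data):
        if input_meta_data.use_cuda_graph and input_meta_data.batch_size in graph_runners:
            return graph_runners[input_meta_data.batch_size]  # replay path
        return model  # eager path

    # usage (hypothetical):
    # model_executable = select_executable(self.graph_runners, self.model, input_meta_data)
    # logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)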
@@ -42,7 +42,6 @@ class CUDAGraphRunner:
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                # batch,
                input_tokens_ids,
                output_tensor,
                inputmetadata,
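Note: only the capture block of `CUDAGraphRunner` appears in this hunk. A sketch of the surrounding runner pattern, assuming warm-up has already been done by the caller and that replay works by copying live inputs into the tensors remembered from capture time (the `__call__` body is an assumption, not the repository's implementation):

    # Sketch of the runner pattern around the captured call above (not the repo's code).
    # Warm-up of `model` is assumed to have been done by the caller before capture.
    import torch

    class MiniGraphRunner:
        def __init__(self, model):
            self.model = model
            self.graph = None
            self.static_inputs = None
            self.static_output = None

        def capture(self, input_tokens_ids, output_tensor, inputmetadata, memory_pool=None, **kwargs):
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph, pool=memory_pool):
                self.static_output = self.model(input_tokens_ids, output_tensor, inputmetadata, **kwargs)
            # keep references: every replay reads from / writes to these exact addresses
            self.static_inputs = (input_tokens_ids, output_tensor)
            return self.graph.pool()

        def __call__(self, input_tokens_ids, output_tensor, inputmetadata, **kwargs):
            self.static_inputs[0].copy_(input_tokens_ids)
            self.static_inputs[1].copy_(output_tensor)
            self.graph.replay()
            return self.static_output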
@@ -92,7 +92,6 @@ if HAS_TRITON:
    def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
        # allocate output
        # y = torch.empty_like(x) if norm_output is None else norm_output
        y = (
            x * 0 if norm_output is None else norm_output
        )  # to make the operation non-functional, store y as the intermediate activation
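Note: the change above writes the normalized result into a caller-provided `norm_output` buffer instead of always allocating a fresh tensor with `torch.empty_like(x)`; `x * 0` only serves as a fallback buffer with `x`'s shape and dtype when no buffer is passed, which keeps the output address under the caller's control (useful, presumably, when the op runs inside a captured CUDA graph). A plain-PyTorch reference of that write-into-a-buffer pattern (standard RMSNorm formula assumed; this is not the Triton kernel):

    # Reference-only sketch in plain PyTorch (not the Triton kernel): an RMSNorm
    # that writes into a caller-provided output buffer, so the output address is
    # controlled by the caller. `out` plays the role of `norm_output` above.
    import torch

    def rms_norm_into(x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor) -> torch.Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        out.copy_(x * inv_rms * weight)  # write into the buffer instead of allocating
        return out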