[fix] pytest and fix dyn grid bug

pull/5434/head
Runyu Lu 2024-03-13 17:28:32 +08:00
parent 633e95b301
commit 1821a6dab0
4 changed files with 135 additions and 8 deletions

View File

@@ -10,6 +10,8 @@ import torch
 import torch.distributed as dist
 from transformers.generation import GenerationConfig
+from colossalai.inference.flash_decoding_utils import FDIntermTensors

 GibiByte = 1024**3

 logger = logging.Logger(__name__)
@@ -45,13 +47,16 @@ class InputMetaData:
     block_tables: torch.Tensor = None
     sequence_lengths: torch.Tensor = None
-    fd_inter_tensor: torch.Tensor = None
+    fd_inter_tensor: FDIntermTensors = None
     batch_size: int = 64  # current_batch_size
     is_prompts: bool = False
     use_cuda_graph: bool = False
     kv_seq_len: int = 512
     head_dim: int = 32

+    def __repr__(self) -> str:
+        return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
+

 @dataclass
 class InferenceConfig:
@@ -117,9 +122,10 @@ class InferenceConfig:
     # cuda_graph
     use_cuda_graph: bool = False
-    max_context_len_to_capture: int = max_input_len * max_output_len
+    max_context_len_to_capture: int = 512

     def __post_init__(self):
+        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()

     def _verify_config(self) -> None:
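Note on the config change above: the new literal default of 512 for max_context_len_to_capture is only a placeholder, since __post_init__ now overwrites it with max_input_len + max_output_len (previously it was the product of the two class-level defaults). A minimal sketch of the resulting behaviour, using a simplified stand-in dataclass rather than the real InferenceConfig:

from dataclasses import dataclass

@dataclass
class _ConfigSketch:
    # Simplified stand-in for InferenceConfig, keeping only the fields used here.
    max_input_len: int = 256
    max_output_len: int = 256
    max_context_len_to_capture: int = 512  # placeholder, recomputed below

    def __post_init__(self):
        # Mirrors the change above: derive the capture length from input + output lengths.
        self.max_context_len_to_capture = self.max_input_len + self.max_output_len

cfg = _ConfigSketch(max_input_len=1024, max_output_len=128)
assert cfg.max_context_len_to_capture == 1152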

View File

@@ -118,6 +118,10 @@ class InferenceEngine:
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
         input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+        self.graph_block_tables[0, :] = np.arange(
+            0, max_num_blocks
+        )  # NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used by flash decoding, by making the first seqlen the max_seq_len
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         output_tensor = torch.zeros(
             (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
@@ -127,6 +131,10 @@ class InferenceEngine:
         max_num_seqs = self.inference_config.max_batch_size
         batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
         sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used by flash decoding, by making the first seqlen the max_seq_len
+        sequence_lengths[0] = torch.tensor(
+            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+        ).cuda()

         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
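The two NOTE hacks above address the dynamic-grid bug named in the commit title: per the comments, flash decoding fixes its kernel grid from the sequence lengths and block tables seen at capture time, so the warm-up batch makes sequence 0 look like a maximum-length sequence whose block table spans all max_num_blocks slots. A small CPU-only sketch of how those warm-up inputs are laid out, with hypothetical sizes in place of the engine's real configuration:

import numpy as np
import torch

# Hypothetical capture configuration, for illustration only.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8]
max_context_len_to_capture = 512
block_size = 16

max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size  # 32
graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)

# Each row gets a distinct dummy first block, and row 0 is filled with every
# block id so it resembles a sequence of maximum length.
graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
graph_block_tables[0, :] = np.arange(0, max_num_blocks)

# Sequence 0 is given (close to) the maximum length so the captured kernel grid
# is sized for the worst case and stays valid for shorter sequences on replay.
sequence_lengths = torch.ones(max(_BATCH_SIZES_TO_CAPTURE), dtype=torch.int)
sequence_lengths[0] = max_context_len_to_capture - 1

assert graph_block_tables.shape == (8, 32)
assert int(sequence_lengths[0]) == 511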
@@ -385,6 +393,13 @@ class InferenceEngine:
             head_dim=batch.head_dim,
         )

+        # if not batch.is_prompts:
+        #     self.logger.info(f"decoding")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+        # else:
+        #     self.logger.info(f"prefill")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+
         return input_ids, output_tensor, input_meta_data

     def step(self) -> List[str]:
@@ -414,6 +429,9 @@ class InferenceEngine:
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        # logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        # assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
+
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
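The commented-out lines above keep a debugging check that runs the eager model alongside the captured graph and asserts bit-for-bit equality of the logits. If such a check is ever re-enabled, an exact comparison is brittle in half precision; a tolerance-based variant is sketched below (the helper name is hypothetical and not part of this commit):

import torch

def assert_logits_close(logits_graph: torch.Tensor, logits_eager: torch.Tensor, atol: float = 1e-3) -> None:
    # Compare CUDA-graph logits against eager-model logits with a tolerance,
    # instead of the exact torch.all(a == b) check kept in the comment above.
    if not torch.allclose(logits_graph, logits_eager, atol=atol):
        max_diff = (logits_graph - logits_eager).abs().max().item()
        raise AssertionError(f"CUDA graph vs eager logits mismatch, max abs diff = {max_diff}")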

View File

@@ -27,8 +27,7 @@ class CUDAGraphRunner:
         assert self.graph is None

         # run kernel once to cache the kernel, avoid stream capture error
-        hidden_states = self.model(
-            # batch,
+        hidden_states_origin_model = self.model(
             input_tokens_ids,
             output_tensor,
             inputmetadata,
@@ -41,7 +40,7 @@ class CUDAGraphRunner:
         # self.logger.info(f"begin capture model...")
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
+            hidden_states_cuda_graph = self.model(
                 input_tokens_ids,
                 output_tensor,
                 inputmetadata,
@@ -52,15 +51,16 @@ class CUDAGraphRunner:
         # Save the input and output buffers, because replay always uses the same virtual memory space
         self.input_buffers = {
-            # "batch": batch,
             "input_tokens_ids": input_tokens_ids,
             "output_tensor": output_tensor,
             "block_tables": inputmetadata.block_tables,
             "sequence_lengths": inputmetadata.sequence_lengths,
+            # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
+            # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
             "k_caches": k_caches,
             "v_caches": v_caches,
         }
-        self.output_buffers = {"logits": hidden_states}
+        self.output_buffers = {"logits": hidden_states_cuda_graph}
         return

     def forward(
@@ -74,9 +74,18 @@ class CUDAGraphRunner:
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
         self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
-        self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
+
+        # for flexible block_table
+        self.input_buffers["block_tables"].fill_(-1)
+        M, N = inputmetadata.block_tables.shape
+        self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
+
         self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
+        # we only have a global fd_inter_tensor so we don't need to copy them
+        # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
+        # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
+
         # KV caches are fixed tensors, so we don't need to copy them.
         # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
         # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)

View File

@@ -0,0 +1,94 @@
import random

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_inference_engine(use_cuda_graph=False, batch_size=32):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
            )
        )
        .cuda()
        .half()
    )
    model = model.eval()

    prompts_token_ids = []
    for i in range(batch_size):
        prompts_token_ids.append(np.random.randint(low=0, high=100, size=random.randint(1, 1024)).tolist())

    input_len = 1024
    output_len = 128
    do_sample = True
    top_p = 0.5
    top_k = 50

    if use_cuda_graph:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_graph=True,
            block_size=16,
        )
    else:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_graph=False,
            block_size=16,
        )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
    # print(f"outputs, use_cuda_graph is {use_cuda_graph}, output: {outputs}")

    return outputs


def check_output_consistency(batch_size):
    cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
    naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)

    for s1, s2 in zip(cuda_graph_output, naive_model_output):
        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_output_consistency(32)
    check_output_consistency(64)
    check_output_consistency(128)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_cuda_graph_infer()