diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 210c3c618..1c4d4e3aa 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -10,6 +10,8 @@ import torch
 import torch.distributed as dist
 from transformers.generation import GenerationConfig
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+
 GibiByte = 1024**3
 
 logger = logging.Logger(__name__)
@@ -45,13 +47,16 @@ class InputMetaData:
 
     block_tables: torch.Tensor = None
     sequence_lengths: torch.Tensor = None
-    fd_inter_tensor: torch.Tensor = None
+    fd_inter_tensor: FDIntermTensors = None
     batch_size: int = 64  # current_batch_size
     is_prompts: bool = False
    use_cuda_graph: bool = False
     kv_seq_len: int = 512
     head_dim: int = 32
 
+    def __repr__(self) -> str:
+        return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
+
 
 @dataclass
 class InferenceConfig:
@@ -117,9 +122,10 @@ class InferenceConfig:
 
     # cuda_graph
     use_cuda_graph: bool = False
-    max_context_len_to_capture: int = max_input_len * max_output_len
+    max_context_len_to_capture: int = 512
 
     def __post_init__(self):
+        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()
 
     def _verify_config(self) -> None:
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 742f53f76..e096956d3 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -118,6 +118,10 @@ class InferenceEngine:
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
         input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+        self.graph_block_tables[0, :] = np.arange(
+            0, max_num_blocks
+        )  # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         output_tensor = torch.zeros(
             (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
@@ -127,6 +131,10 @@ class InferenceEngine:
         max_num_seqs = self.inference_config.max_batch_size
         batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
         sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
+        sequence_lengths[0] = torch.tensor(
+            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+        ).cuda()
 
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
@@ -385,6 +393,13 @@ class InferenceEngine:
             head_dim=batch.head_dim,
         )
 
+        # if not batch.is_prompts:
+        #     self.logger.info(f"decoding")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+        # else:
+        #     self.logger.info(f"prefill")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+
         return input_ids, output_tensor, input_meta_data
 
     def step(self) -> List[str]:
@@ -414,6 +429,9 @@ class InferenceEngine:
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
+        # logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        # assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
+
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
diff --git a/colossalai/inference/graph_runner.py b/colossalai/inference/graph_runner.py
index 7e63cfce2..e8b805574 100644
--- a/colossalai/inference/graph_runner.py
+++ b/colossalai/inference/graph_runner.py
@@ -27,8 +27,7 @@ class CUDAGraphRunner:
         assert self.graph is None
 
         # run kernel once to cache the kernel, avoid stream capture error
-        hidden_states = self.model(
-            # batch,
+        hidden_states_origin_model = self.model(
             input_tokens_ids,
             output_tensor,
             inputmetadata,
@@ -41,7 +40,7 @@
         # self.logger.info(f"begin capture model...")
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
+            hidden_states_cuda_graph = self.model(
                 input_tokens_ids,
                 output_tensor,
                 inputmetadata,
@@ -52,15 +51,16 @@
         # Save the input and output buffers, because replay always uses the same virtual memory space
         self.input_buffers = {
-            # "batch": batch,
             "input_tokens_ids": input_tokens_ids,
             "output_tensor": output_tensor,
             "block_tables": inputmetadata.block_tables,
             "sequence_lengths": inputmetadata.sequence_lengths,
+            # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
+            # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
             "k_caches": k_caches,
             "v_caches": v_caches,
         }
-        self.output_buffers = {"logits": hidden_states}
+        self.output_buffers = {"logits": hidden_states_cuda_graph}
 
         return
 
     def forward(
@@ -74,9 +74,18 @@
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
         self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
-        self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
+
+        # for flexible block_table
+        self.input_buffers["block_tables"].fill_(-1)
+        M, N = inputmetadata.block_tables.shape
+        self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
+
         self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
+        # we only have a global fd_inter_tensor so we don't need to copy them
+        # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
+        # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
+
         # KV caches are fixed tensors, so we don't need to copy them.
         # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
         # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)
diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py
new file mode 100644
index 000000000..0810c356a
--- /dev/null
+++ b/tests/test_infer/test_cuda_graph.py
@@ -0,0 +1,94 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def check_inference_engine(use_cuda_graph=False, batch_size=32):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
+        )
+        .cuda()
+        .half()
+    )
+    model = model.eval()
+
+    prompts_token_ids = []
+    for i in range(batch_size):
+        prompts_token_ids.append(np.random.randint(low=0, high=100, size=random.randint(1, 1024)).tolist())
+
+    input_len = 1024
+    output_len = 128
+    do_sample = True
+    top_p = 0.5
+    top_k = 50
+
+    if use_cuda_graph:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_graph=True,
+            block_size=16,
+        )
+    else:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_graph=False,
+            block_size=16,
+        )
+
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == output_len
+    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
+
+    # print(f"outputs, use_cuda_graph is {use_cuda_graph}, output: {outputs}")
+
+    return outputs
+
+
+def check_output_consistency(batch_size):
+    cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
+    naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)
+
+    for s1, s2 in zip(cuda_graph_output, naive_model_output):
+        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_output_consistency(32)
+    check_output_consistency(64)
+    check_output_consistency(128)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_cuda_graph_infer():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_cuda_graph_infer()
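
For context on the capture/replay pattern that the graph_runner.py changes above rely on, the following is a minimal, self-contained sketch (not part of the patch). It mirrors the same steps CUDAGraphRunner performs: a warm-up forward pass, capture into static input/output buffers, then copying fresh inputs into those buffers before replay. The Linear model and the static_input/static_output names are placeholders chosen for illustration; only PyTorch's public torch.cuda.CUDAGraph API is used.

import torch

# Assumes a CUDA device is available; the Linear model stands in for the real inference model.
device = torch.device("cuda")
model = torch.nn.Linear(512, 512).to(device=device, dtype=torch.float16).eval()

# Static buffer: capture records fixed device addresses, so replay must reuse this tensor.
static_input = torch.zeros(8, 512, dtype=torch.float16, device=device)

with torch.no_grad():
    # Warm-up before capture (CUDAGraphRunner does this with a plain forward call;
    # the PyTorch docs recommend running the warm-up on a side stream).
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # Record the kernels launched by one forward pass into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=torch.cuda.graph_pool_handle()):
        static_output = model(static_input)

# Replay: copy fresh data into the captured input buffer, then rerun the recorded kernels.
new_input = torch.randn(8, 512, dtype=torch.float16, device=device)
static_input.copy_(new_input)
graph.replay()
result = static_output.clone()  # static_output now holds the output for new_input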