mirror of https://github.com/hpcaitech/ColossalAI
[fix] pytest and fix dyn grid bug
parent 633e95b301
commit 1821a6dab0
@@ -10,6 +10,8 @@ import torch
 import torch.distributed as dist
 from transformers.generation import GenerationConfig
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+
 GibiByte = 1024**3
 
 logger = logging.Logger(__name__)
@@ -45,13 +47,16 @@ class InputMetaData:
     block_tables: torch.Tensor = None
     sequence_lengths: torch.Tensor = None
-    fd_inter_tensor: torch.Tensor = None
+    fd_inter_tensor: FDIntermTensors = None
     batch_size: int = 64  # current_batch_size
     is_prompts: bool = False
     use_cuda_graph: bool = False
     kv_seq_len: int = 512
     head_dim: int = 32
 
+    def __repr__(self) -> str:
+        return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
+
 
 
 @dataclass
 class InferenceConfig:
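Note: the hunk above tightens the type hint of fd_inter_tensor to FDIntermTensors and adds a __repr__ for easier debugging. A minimal sketch of how the dataclass might be instantiated, assuming InputMetaData lives in colossalai.inference.config next to InferenceConfig; every value below is an illustrative placeholder rather than anything taken from the engine.

import torch

from colossalai.inference.config import InputMetaData  # assumed module path

# Hypothetical decoding-step metadata (placeholder values; requires a CUDA device).
meta = InputMetaData(
    block_tables=torch.zeros((8, 32), dtype=torch.int32, device="cuda"),
    sequence_lengths=torch.ones(8, dtype=torch.int32, device="cuda"),
    fd_inter_tensor=None,  # the engine would pass its shared FDIntermTensors here
    batch_size=8,
    is_prompts=False,
    use_cuda_graph=True,
    kv_seq_len=512,
    head_dim=32,
)
print(meta)  # the new __repr__ prints every field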
@@ -117,9 +122,10 @@ class InferenceConfig:
 
     # cuda_graph
     use_cuda_graph: bool = False
-    max_context_len_to_capture: int = max_input_len * max_output_len
+    max_context_len_to_capture: int = 512
 
     def __post_init__(self):
+        self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()
 
     def _verify_config(self) -> None:
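Note: with this change the class-level default (512) is only a placeholder; __post_init__ recomputes the capture length per instance as max_input_len + max_output_len instead of the old class-level product. A small sketch of the resulting behavior, assuming InferenceConfig can be constructed with just these keyword arguments (as the new test at the end of this diff does):

from colossalai.inference.config import InferenceConfig

# Illustrative values; other fields keep their defaults.
cfg = InferenceConfig(max_batch_size=32, max_input_len=1024, max_output_len=128, block_size=16)
assert cfg.max_context_len_to_capture == 1024 + 128  # recomputed in __post_init__, not the default 512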
@@ -118,6 +118,10 @@ class InferenceEngine:
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
         input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+        self.graph_block_tables[0, :] = np.arange(
+            0, max_num_blocks
+        )  # NOTE: this is a hack to ensure the CUDA graph captures the fixed kernel grid used in flash decoding, by making the first sequence length equal to max_seq_len
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         output_tensor = torch.zeros(
             (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
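Note: the two added lines build a dummy block table whose first row spans every block and whose first column gives each row a distinct block id, so the capture pass sees the largest flash-decoding kernel grid it will ever need. A scaled-down standalone sketch of the resulting layout (the sizes are assumptions, not the engine's):

import numpy as np

# Assumed sizes: 4 captured batch sizes, 8 blocks per sequence; the real code uses
# max(_BATCH_SIZES_TO_CAPTURE) rows and max_num_blocks columns.
max_num_blocks, max_bs = 8, 4
graph_block_tables = np.zeros((max_bs, max_num_blocks), dtype=np.int32)
graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max_bs)  # distinct first block per row
graph_block_tables[0, :] = np.arange(0, max_num_blocks)  # first row uses every block slot
print(graph_block_tables)
# [[ 0  1  2  3  4  5  6  7]   <- row 0 looks like a sequence at max length,
#  [ 9  0  0  0  0  0  0  0]      so the flash-decoding kernel grid is sized
#  [10  0  0  0  0  0  0  0]      for the worst case during capture
#  [11  0  0  0  0  0  0  0]]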
@@ -127,6 +131,10 @@ class InferenceEngine:
         max_num_seqs = self.inference_config.max_batch_size
         batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
         sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE: this is a hack to ensure the CUDA graph captures the fixed kernel grid used in flash decoding, by making the first sequence length equal to max_seq_len
+        sequence_lengths[0] = torch.tensor(
+            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+        ).cuda()
 
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
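Note: the companion fix primes the capture-time sequence_lengths buffer so slot 0 reports almost the maximum context length, matching the oversized block table above. A scaled-down sketch with assumed sizes:

import torch

# Assumed max_batch_size=4 and max_context_len_to_capture=1152
# (the value from the config example above); CPU tensors for illustration.
max_batch_size, max_context_len_to_capture = 4, 1152
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int)
sequence_lengths[0] = max_context_len_to_capture - 1  # slot 0 pretends to be a max-length sequence
print(sequence_lengths)  # tensor([1151,    1,    1,    1], dtype=torch.int32)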
@@ -385,6 +393,13 @@ class InferenceEngine:
             head_dim=batch.head_dim,
         )
 
+        # if not batch.is_prompts:
+        #     self.logger.info(f"decoding")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+        # else:
+        #     self.logger.info(f"prefill")
+        #     self.logger.info(f"input metadata is: {input_meta_data}")
+
         return input_ids, output_tensor, input_meta_data
 
     def step(self) -> List[str]:
@@ -414,6 +429,9 @@ class InferenceEngine:
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+
+        # logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        # assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"
 
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
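Note: the commented-out lines keep a sanity check that reruns the eager model and compares it bit-for-bit against the CUDA-graph output. Purely as an illustration (not part of this commit), a tolerance-based variant of that check could look like the following; the helper name and thresholds are assumptions:

import torch

def check_graph_against_eager(logits_graph: torch.Tensor, logits_eager: torch.Tensor) -> None:
    """Illustrative helper: compare CUDA-graph logits with eager logits.

    Exact equality (torch.all(a == b)) can be brittle in fp16, so a small
    tolerance is used here; the thresholds are assumptions, not repo values.
    """
    if not torch.allclose(logits_graph, logits_eager, rtol=1e-3, atol=1e-3):
        max_diff = (logits_graph - logits_eager).abs().max().item()
        raise AssertionError(f"CUDA graph and eager outputs diverge, max abs diff {max_diff}")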
@@ -27,8 +27,7 @@ class CUDAGraphRunner:
         assert self.graph is None
 
         # run kernel once to cache the kernel, avoid stream capture error
-        hidden_states = self.model(
-            # batch,
+        hidden_states_origin_model = self.model(
             input_tokens_ids,
             output_tensor,
             inputmetadata,
@@ -41,7 +40,7 @@ class CUDAGraphRunner:
         # self.logger.info(f"begin capture model...")
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
+            hidden_states_cuda_graph = self.model(
                 input_tokens_ids,
                 output_tensor,
                 inputmetadata,
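Note: for readers unfamiliar with the API used in this hunk, the runner follows PyTorch's standard warm-up / capture / replay pattern. A self-contained toy sketch of that pattern on a stand-in linear model (generic PyTorch, not ColossalAI code):

import torch

model = torch.nn.Linear(64, 64).cuda().half().eval()
static_input = torch.randn(8, 64, device="cuda", dtype=torch.half)

with torch.no_grad():
    # 1) Warm up: run once outside capture so kernels are compiled and cached,
    #    avoiding stream-capture errors (same reason as the comment above).
    static_output = model(static_input)
    torch.cuda.synchronize()

    # 2) Capture: record the kernel launches into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # 3) Replay: refresh the captured input buffer in place, then replay;
    #    results appear in the captured output buffer.
    static_input.copy_(torch.randn(8, 64, device="cuda", dtype=torch.half))
    graph.replay()
    print(static_output[0, :4])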
@@ -52,15 +51,16 @@ class CUDAGraphRunner:
 
         # Save the input and output buffers, because replay always uses the same virtual memory space
         self.input_buffers = {
-            # "batch": batch,
             "input_tokens_ids": input_tokens_ids,
             "output_tensor": output_tensor,
             "block_tables": inputmetadata.block_tables,
             "sequence_lengths": inputmetadata.sequence_lengths,
+            # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
+            # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
            "k_caches": k_caches,
            "v_caches": v_caches,
         }
-        self.output_buffers = {"logits": hidden_states}
+        self.output_buffers = {"logits": hidden_states_cuda_graph}
         return
 
     def forward(
@@ -74,9 +74,18 @@ class CUDAGraphRunner:
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
         self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
-        self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
+
+        # for flexible block_table
+        self.input_buffers["block_tables"].fill_(-1)
+        M, N = inputmetadata.block_tables.shape
+        self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
+
         self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
 
+        # we only have a global fd_inter_tensor so we don't need to copy them
+        # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
+        # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
+
         # KV caches are fixed tensors, so we don't need to copy them.
         # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
         # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)
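Note: the replay-time fix resets the fixed-size captured block-table buffer to -1 (apparently a sentinel for unused slots) and copies the live, possibly smaller, block table into its top-left corner, so batches whose tables are smaller than the captured shape can reuse the same graph. A scaled-down standalone sketch of that copy, with assumed sizes and plain tensors instead of the runner's buffers:

import torch

# Captured buffer: fixed shape chosen at capture time (assumed 4 x 8 here).
captured_block_tables = torch.zeros((4, 8), dtype=torch.int32)

# Live block table for the current step: smaller batch / fewer blocks (assumed 2 x 3).
live_block_tables = torch.tensor([[5, 6, 7], [9, 10, 11]], dtype=torch.int32)

# Same dynamic-grid fix as in CUDAGraphRunner.forward: reset the buffer, then
# copy the live table into its top-left corner.
captured_block_tables.fill_(-1)
M, N = live_block_tables.shape
captured_block_tables[:M, :N].copy_(live_block_tables)
print(captured_block_tables)
# tensor([[ 5,  6,  7, -1, -1, -1, -1, -1],
#         [ 9, 10, 11, -1, -1, -1, -1, -1],
#         [-1, -1, -1, -1, -1, -1, -1, -1],
#         [-1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.int32)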
@@ -0,0 +1,94 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def check_inference_engine(use_cuda_graph=False, batch_size=32):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
+        )
+        .cuda()
+        .half()
+    )
+    model = model.eval()
+
+    prompts_token_ids = []
+    for i in range(batch_size):
+        prompts_token_ids.append(np.random.randint(low=0, high=100, size=random.randint(1, 1024)).tolist())
+
+    input_len = 1024
+    output_len = 128
+    do_sample = True
+    top_p = 0.5
+    top_k = 50
+
+    if use_cuda_graph:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_graph=True,
+            block_size=16,
+        )
+    else:
+        inference_config = InferenceConfig(
+            max_batch_size=batch_size,
+            max_input_len=input_len,
+            max_output_len=output_len,
+            use_cuda_graph=False,
+            block_size=16,
+        )
+
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == output_len
+    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
+
+    # print(f"outputs, use_cuda_graph is {use_cuda_graph}, output: {outputs}")
+
+    return outputs
+
+
+def check_output_consistency(batch_size):
+    cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
+    naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)
+
+    for s1, s2 in zip(cuda_graph_output, naive_model_output):
+        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_output_consistency(32)
+    check_output_consistency(64)
+    check_output_consistency(128)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_cuda_graph_infer():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_cuda_graph_infer()