mirror of https://github.com/hpcaitech/ColossalAI
Runyu Lu committed 8 months ago (committed by GitHub)
7 changed files with 413 additions and 52 deletions
@@ -0,0 +1,100 @@
from typing import Dict, List

import torch
from torch import nn

from colossalai.inference.config import InputMetaData
from colossalai.logging import get_dist_logger


class CUDAGraphRunner:
    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self.logger = get_dist_logger(__name__)

    def capture(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
        memory_pool=None,
    ) -> None:
        assert self.graph is None

        # Run the model once in eager mode to warm up and cache the kernels;
        # otherwise lazy kernel initialization can fail during stream capture.
        hidden_states_origin_model = self.model(
            input_tokens_ids,
            output_tensor,
            inputmetadata,
            k_caches,
            v_caches,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        # self.logger.info(f"begin capture model...")
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states_cuda_graph = self.model(
                input_tokens_ids,
                output_tensor,
                inputmetadata,
                k_caches,
                v_caches,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers: replay always reuses the same virtual memory
        # addresses, so later inputs must be copied into exactly these tensors.
        self.input_buffers = {
            "input_tokens_ids": input_tokens_ids,
            "output_tensor": output_tensor,
            "block_tables": inputmetadata.block_tables,
            "sequence_lengths": inputmetadata.sequence_lengths,
            # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
            # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
            "k_caches": k_caches,
            "v_caches": v_caches,
        }
        self.output_buffers = {"logits": hidden_states_cuda_graph}
        return

    def forward(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Copy the input tensors into the captured input buffers.
        self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
        self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)

        # Block tables can be smaller than the captured buffer, so reset the buffer to -1
        # and copy the current table into its top-left corner.
        self.input_buffers["block_tables"].fill_(-1)
        M, N = inputmetadata.block_tables.shape
        self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)

        self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)

        # There is only one global fd_inter_tensor, so it does not need to be copied.
        # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
        # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)

        # KV caches are fixed tensors, so they do not need to be copied either.
        # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
        # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)

        # Replay the captured graph.
        self.graph.replay()

        # Return the output tensor (written in place by the replay).
        return self.output_buffers["logits"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
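
For reference, the capture/replay contract used above can be exercised in isolation. The sketch below is not part of this PR: the model, shapes, and buffer names are made up for illustration, and it needs a CUDA device. It walks through the same three steps the runner performs: an eager warm-up, capture into a torch.cuda.CUDAGraph, and replay after copying fresh data into the captured (static) buffers, including the fill_(-1) padding trick that forward() uses for variable-sized block tables.

import torch
from torch import nn

# Toy stand-ins for the real model and metadata buffers (illustrative only).
model = nn.Linear(64, 64).cuda().half().eval()
static_input = torch.zeros(8, 64, device="cuda", dtype=torch.half)
static_block_tables = torch.full((8, 16), -1, device="cuda", dtype=torch.int32)

# 1) Warm-up run outside capture, so lazy kernel/handle initialization happens eagerly.
with torch.no_grad():
    model(static_input)
torch.cuda.synchronize()

# 2) Capture: every kernel launched inside the context is recorded into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_output = model(static_input)

# 3) Replay: copy fresh data into the same buffers, then re-launch the recorded kernels.
static_input.copy_(torch.randn(8, 64, device="cuda", dtype=torch.half), non_blocking=True)

# Padded copy, mirroring forward(): reset to -1, then write the (possibly smaller)
# live block table into the top-left corner of the fixed-shape buffer.
live_block_tables = torch.randint(0, 100, (5, 9), device="cuda", dtype=torch.int32)
static_block_tables.fill_(-1)
M, N = live_block_tables.shape
static_block_tables[:M, :N].copy_(live_block_tables, non_blocking=True)

graph.replay()
result = static_output  # results are always read from the captured output buffer

Because replay reuses the recorded tensor addresses and shapes, a captured graph only serves one input geometry; that is why the runner stores its buffers and why an engine typically keeps one CUDAGraphRunner per (padded) batch size.
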
@@ -0,0 +1,96 @@
import random

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def check_inference_engine(use_cuda_graph=False, batch_size=32):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = (
        LlamaForCausalLM(
            LlamaConfig(
                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
            )
        )
        .cuda()
        .half()
    )
    model = model.eval()

    prompts_token_ids = []
    for i in range(batch_size):
        prompts_token_ids.append(
            np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
        )

    input_len = 1024
    output_len = 128
    do_sample = True
    top_p = 0.5
    top_k = 50

    if use_cuda_graph:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_kernel=False,
            use_cuda_graph=True,
            block_size=16,
        )
    else:
        inference_config = InferenceConfig(
            max_batch_size=batch_size,
            max_input_len=input_len,
            max_output_len=output_len,
            use_cuda_kernel=False,
            use_cuda_graph=False,
            block_size=16,
        )

    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    assert inference_engine.generation_config.max_new_tokens == output_len
    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
    outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)

    return outputs


def check_output_consistency(batch_size):
    cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
    naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)

    for s1, s2 in zip(cuda_graph_output, naive_model_output):
        assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_output_consistency(32)
    check_output_consistency(64)
    check_output_consistency(128)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_cuda_graph_infer()
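
The test above asserts that generation with use_cuda_graph=True produces exactly the same token sequences as eager execution under the same seeds. The same consistency property can be checked at toy scale without ColossalAI's engine; the sketch below (not from this PR, illustrative model and shapes, CUDA device required) compares one eager forward pass against a graph replay fed the same batch.

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda().eval()

# Static input buffer for capture and replay.
static_x = torch.randn(4, 32, device="cuda")

# Warm up, then capture one forward pass.
with torch.no_grad():
    model(static_x)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    with torch.no_grad():
        static_y = model(static_x)

# Run the same fresh batch through both paths.
x = torch.randn(4, 32, device="cuda")
with torch.no_grad():
    eager_y = model(x)

static_x.copy_(x)
graph.replay()
torch.cuda.synchronize()

# Replay re-launches the kernels recorded at capture time on the same buffers,
# so the two outputs should agree (usually bit-for-bit for identical shapes/dtypes).
assert torch.allclose(static_y, eager_y)
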