Browse Source

Merge pull request #5434 from LRY89757/colossal-infer-cuda-graph

[feat] cuda graph support and refactor non-functional api
pull/5512/head
Runyu Lu 8 months ago committed by GitHub
parent
commit
1d626233ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      colossalai/inference/README.md
  2. 45
      colossalai/inference/config.py
  3. 151
      colossalai/inference/core/engine.py
  4. 100
      colossalai/inference/graph_runner.py
  5. 65
      colossalai/inference/modeling/models/nopadding_llama.py
  6. 6
      colossalai/kernel/triton/rms_layernorm.py
  7. 96
      tests/test_infer/test_cuda_graph.py

2
colossalai/inference/README.md

@ -94,6 +94,8 @@ inference_config = InferenceConfig(
max_batch_size=4,
max_input_len=1024,
max_output_len=512,
use_cuda_kernel=True,
use_cuda_graph=False, # Turn on if you want to use CUDA Graph to accelerate inference
)
# Step 3: create an engine with model and config

45
colossalai/inference/config.py

@ -10,11 +10,12 @@ import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
GibiByte = 1024**3
logger = logging.Logger(__name__)
_DTYPE_MAPPING = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
@ -23,13 +24,42 @@ _DTYPE_MAPPING = {
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
}
@dataclass
class InputMetaData:
"""The input info for a single step
Args:
block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None.
sequence_lengths (torch.Tensor): A tensor containing sequence lengths.
fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None.
batch_size (int, optional): The current batch size. Defaults to 64.
is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding).
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False.
kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
head_dim (int, optional): Head dimension. Defaults to 32.
"""
block_tables: torch.Tensor = None
sequence_lengths: torch.Tensor = None
fd_inter_tensor: FDIntermTensors = None
batch_size: int = 64 # current_batch_size
is_prompts: bool = False
use_cuda_kernel: bool = False
use_cuda_graph: bool = False
kv_seq_len: int = 512
head_dim: int = 32
def __repr__(self) -> str:
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
@dataclass
class InferenceConfig:
"""The inference configuration.
@ -55,6 +85,9 @@ class InferenceConfig:
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
@ -91,7 +124,15 @@ class InferenceConfig:
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option
use_cuda_kernel: bool = False
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
def _verify_config(self) -> None:

151
colossalai/inference/core/engine.py

@ -1,5 +1,6 @@
import time
from itertools import count
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -7,7 +8,9 @@ import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
@ -25,6 +28,8 @@ _supported_models = [
"LlamaForCausalLM",
]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
@ -82,11 +87,93 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_config(self) -> None:
"""
Verify the input config
@ -279,13 +366,50 @@ class InferenceEngine:
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
dtype=batch.dtype,
device=batch.device,
)
else:
output_tensor = torch.zeros(
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
@ -293,14 +417,15 @@ class InferenceEngine:
batch = self.request_handler.schedule()
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = self.model(
batch,
self.k_cahce,
self.v_cache,
self.high_precision,
)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)

100
colossalai/inference/graph_runner.py

@ -0,0 +1,100 @@
from typing import Dict, List
import torch
from torch import nn
from colossalai.inference.config import InputMetaData
from colossalai.logging import get_dist_logger
class CUDAGraphRunner:
def __init__(self, model: nn.Module):
self.model = model
self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self.logger = get_dist_logger(__name__)
def capture(
self,
input_tokens_ids: torch.Tensor,
output_tensor: torch.Tensor,
inputmetadata: InputMetaData,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
memory_pool=None,
) -> None:
assert self.graph is None
# run kernel once to cache the kernel, avoid stream capture error
hidden_states_origin_model = self.model(
input_tokens_ids,
output_tensor,
inputmetadata,
k_caches,
v_caches,
)
torch.cuda.synchronize()
# Capture the graph.
# self.logger.info(f"begin capture model...")
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool):
hidden_states_cuda_graph = self.model(
input_tokens_ids,
output_tensor,
inputmetadata,
k_caches,
v_caches,
)
torch.cuda.synchronize()
# Save the input and output buffers, because replay always uses the same virtual memory space
self.input_buffers = {
"input_tokens_ids": input_tokens_ids,
"output_tensor": output_tensor,
"block_tables": inputmetadata.block_tables,
"sequence_lengths": inputmetadata.sequence_lengths,
# "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
# "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
"k_caches": k_caches,
"v_caches": v_caches,
}
self.output_buffers = {"logits": hidden_states_cuda_graph}
return
def forward(
self,
input_tokens_ids: torch.Tensor,
output_tensor: torch.Tensor,
inputmetadata: InputMetaData,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
) -> torch.Tensor:
# Copy the input tensors to the input buffers.
self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
# for flexible block_table
self.input_buffers["block_tables"].fill_(-1)
M, N = inputmetadata.block_tables.shape
self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
# we only have a global fd_inter_tensor so we don't need to copy them
# self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
# self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
# KV caches are fixed tensors, so we don't need to copy them.
# self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
# self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)
# Run the graph.
self.graph.replay()
# Return the output tensor.
return self.output_buffers["logits"]
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

65
colossalai/inference/modeling/models/nopadding_llama.py

@ -13,7 +13,7 @@ from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
)
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
@ -41,11 +41,12 @@ except ImportError:
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchBucket,
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
high_precision: bool = False,
):
input_tokens_ids: torch.Tensor,
output_tensor: torch.Tensor,
inputmetadata: InputMetaData,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
) -> torch.Tensor:
"""This function will replace the forward function of LlamaForCausalLM.
Args:
@ -58,10 +59,13 @@ def llama_causal_lm_forward(
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
batch=batch,
input_tokens_ids=input_tokens_ids,
output_tensor=output_tensor,
inputmetadata=inputmetadata,
k_caches=k_caches,
v_caches=v_caches,
high_precision=high_precision,
use_cuda_kernel=inputmetadata.use_cuda_kernel, # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unitest but triton kernel could
high_precision=inputmetadata.high_precision,
)
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@ -69,11 +73,14 @@ def llama_causal_lm_forward(
def llama_model_forward(
self: LlamaModel,
batch: BatchBucket,
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
input_tokens_ids: torch.Tensor,
output_tensor: torch.Tensor,
inputmetadata: InputMetaData,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
use_cuda_kernel: Optional[bool] = True,
high_precision: bool = False,
):
) -> torch.Tensor:
"""This function will replace the forward function of LlamaModel.
Args:
@ -82,36 +89,26 @@ def llama_model_forward(
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
sequence_lengths = batch.get_sequence_lengths()
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
use_cuda_kernel = True
block_tables = inputmetadata.block_tables
sequence_lengths = inputmetadata.sequence_lengths
batch_size = inputmetadata.batch_size
kv_seq_len = inputmetadata.kv_seq_len
# NOTE: After testing, the performance of this configuration is relatively good. With updates
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
# selection should be conducted.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2:
hidden_states = self.embed_tokens(input_tokens_ids)
if use_cuda_kernel and inputmetadata != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
else:
cu_seqlens = None
hidden_states = self.embed_tokens(input_ids)
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
norm_output = torch.empty_like(hidden_states)
residual = None
@ -123,10 +120,10 @@ def llama_model_forward(
block_tables=block_tables,
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
is_prompts=batch.is_prompts,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
norm_output=norm_output,
@ -136,7 +133,7 @@ def llama_model_forward(
high_precision=high_precision,
)
if batch.is_prompts:
if inputmetadata.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
residual = residual[last_token_indexs - 1].contiguous()

6
colossalai/kernel/triton/rms_layernorm.py

@ -1,5 +1,3 @@
import torch
try:
import triton
import triton.language as tl
@ -94,7 +92,9 @@ if HAS_TRITON:
def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
# allocate output
y = torch.empty_like(x) if norm_output is None else norm_output
y = (
x * 0 if norm_output is None else norm_output
) # to make the operation non-functional, store y as the intermediate activation
M, N = x.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()

96
tests/test_infer/test_cuda_graph.py

@ -0,0 +1,96 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_cuda_graph=False, batch_size=32):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
)
.cuda()
.half()
)
model = model.eval()
prompts_token_ids = []
for i in range(batch_size):
prompts_token_ids.append(
np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
)
input_len = 1024
output_len = 128
do_sample = True
top_p = 0.5
top_k = 50
if use_cuda_graph:
inference_config = InferenceConfig(
max_batch_size=batch_size,
max_input_len=input_len,
max_output_len=output_len,
use_cuda_kernel=False,
use_cuda_graph=True,
block_size=16,
)
else:
inference_config = InferenceConfig(
max_batch_size=batch_size,
max_input_len=input_len,
max_output_len=output_len,
use_cuda_kernel=False,
use_cuda_graph=False,
block_size=16,
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config)
return outputs
def check_output_consistency(batch_size):
cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size)
naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size)
for s1, s2 in zip(cuda_graph_output, naive_model_output):
assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}"
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency(32)
check_output_consistency(64)
check_output_consistency(128)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cuda_graph_infer():
spawn(run_dist, 1)
if __name__ == "__main__":
test_cuda_graph_infer()
Loading…
Cancel
Save