mirror of https://github.com/hpcaitech/ColossalAI
[feat] cuda graph support and refactor non-functional api
parent 593a72e4d5
commit cefaeb5fdd

@@ -14,7 +14,6 @@ GibiByte = 1024**3
 logger = logging.Logger(__name__)
 
 _DTYPE_MAPPING = {
     "fp16": torch.float16,
     "bf16": torch.bfloat16,

@@ -23,13 +22,37 @@ _DTYPE_MAPPING = {
 _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
 
 _DEFAULT_PROMPT_TEMPLATES = {
     "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
     "vicuna": "USER: {input_text}\n\nASSISTANT: ",
 }
 
 
+@dataclass
+class InputMetaData:
+    """The input info for a single step.
+
+    Args:
+        block_tables (torch.Tensor, optional): Sequences' block tables. Defaults to None.
+        sequence_lengths (torch.Tensor): A tensor containing sequence lengths.
+        fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None.
+        batch_size (int, optional): The current batch size. Defaults to 64.
+        is_prompts (bool, optional): Indicates whether this step is prefill or decoding. Defaults to False (decoding).
+        use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False.
+        kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
+        head_dim (int, optional): Head dimension. Defaults to 32.
+    """
+
+    block_tables: torch.Tensor = None
+    sequence_lengths: torch.Tensor = None
+    fd_inter_tensor: torch.Tensor = None
+    batch_size: int = 64  # current_batch_size
+    is_prompts: bool = False
+    use_cuda_graph: bool = False
+    kv_seq_len: int = 512
+    head_dim: int = 32
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.

@@ -55,6 +78,8 @@ class InferenceConfig:
         pp_size (int): Pipeline parallel size, defaults to 1.
         micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, CUDA graph is used where a captured graph exists and eager execution is used otherwise (hybrid mode).
+        max_context_len_to_capture (int): The maximum context length used when capturing CUDA graphs.
 
     """
 

@@ -90,6 +115,10 @@ class InferenceConfig:
     micro_batch_size: int = 1
     micro_batch_buffer_size: int = None
 
+    # cuda_graph
+    use_cuda_graph: bool = False
+    max_context_len_to_capture: int = max_input_len * max_output_len
+
     def __post_init__(self):
         self._verify_config()
 
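
For orientation, a minimal sketch of how the two pieces added above (the CUDA-graph fields on InferenceConfig and the new InputMetaData dataclass) might be populated for a decoding step. Field names come from the hunks above; the concrete values, and any InferenceConfig fields not shown in this diff, are illustrative assumptions.

import torch
from colossalai.inference.config import InferenceConfig, InputMetaData

# Enable hybrid CUDA graph execution (other field values are hypothetical).
config = InferenceConfig(
    max_batch_size=8,
    use_cuda_graph=True,             # added in this commit
    max_context_len_to_capture=512,  # added in this commit; overrides the default product
)

# Metadata describing one decoding step for a batch of 4 sequences.
meta = InputMetaData(
    block_tables=torch.zeros(4, 16, dtype=torch.int32, device="cuda"),
    sequence_lengths=torch.tensor([12, 7, 30, 5], device="cuda"),
    batch_size=4,
    is_prompts=False,       # decoding, so a captured CUDA graph may be reused
    use_cuda_graph=True,
    kv_seq_len=30,
    head_dim=128,
)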
@@ -1,5 +1,7 @@
+import copy
+import time
 from itertools import count
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -7,7 +9,9 @@ import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.config import InferenceConfig
+from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.struct import Sequence
 from colossalai.logging import get_dist_logger

@@ -81,11 +85,89 @@ class InferenceEngine:
         self.logger = get_dist_logger(__name__)
 
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
-        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?
 
         self.counter = count()
 
+        self.use_cuda_graph = self.inference_config.use_cuda_graph
+        if self.use_cuda_graph:
+            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+            self.graph_memory_pool = None  # Set during graph capture.
+            if verbose:
+                self.logger.info("Colossal AI CUDA Graph Capture on")
+
+            self.capture_model(self.k_cache, self.v_cache)
+
+    @torch.inference_mode()
+    def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
+        assert self.use_cuda_graph, "please turn on the cuda graph"
+
+        if self.verbose:
+            self.logger.info("Colossal AI CUDA Graph Capture begin")
+
+        t_capture_begin = time.perf_counter()
+
+        _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+
+        block_size = self.inference_config.block_size
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+
+        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
+        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
+        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
+        self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        max_num_seqs = self.inference_config.max_batch_size
+        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
+
+        # NOTE: Capturing the largest batch size first may help reduce the
+        # memory usage of CUDA graph.
+        for batch_size in reversed(batch_size_capture_list[-1:]):
+            batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
+            batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
+
+            if self.verbose:
+                self.logger.info(f"batch size {batch_size} graph capturing")
+
+            # Generate dummy input sequences for this batch size.
+            for i in range(batch_size):
+                sequence = Sequence(
+                    i,
+                    None,
+                    input_tokens[i],
+                    block_size,
+                    None,
+                    self.tokenizer.eos_token_id,
+                    self.tokenizer.pad_token_id,
+                    self.inference_config.max_output_len,
+                )
+                sequence.output_token_id = [0]  # only capture the graph of decoding
+                batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])
+
+            input_data = self.prepare_input(batch_bucket_for_capture)
+
+            input_tokens_ids, output_tensor, inputmetadata = input_data
+
+            graph_runner = CUDAGraphRunner(self.model)
+            graph_runner.capture(
+                input_tokens_ids,
+                output_tensor,
+                inputmetadata,
+                k_caches=k_cache,
+                v_caches=v_cache,
+                memory_pool=self.graph_memory_pool,
+            )
+            self.graph_memory_pool = graph_runner.graph.pool()
+            self.graph_runners[batch_size] = graph_runner
+
+        t_capture_end = time.perf_counter()
+
+        if self.verbose:
+            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
+
     def _verify_config(self) -> None:
         """
         Verify the input config

@@ -278,13 +360,47 @@ class InferenceEngine:
         )
         self.request_handler.add_sequence(sequence)
 
+    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
+        input_ids = batch.get_1D_inputs()
+
+        sequence_lengths = batch.get_sequence_lengths()
+        if batch.is_prompts:
+            output_tensor = torch.zeros(
+                (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
+                dtype=batch.dtype,
+                device=batch.device,
+            )
+        else:
+            output_tensor = torch.zeros(
+                (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+            )
+
+        # CUDA graph can only be used for decoding batches whose size has a captured graph.
+        use_cuda_graph = False
+        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
+            use_cuda_graph = True
+
+        input_meta_data = InputMetaData(
+            block_tables=batch.get_block_table_tensor(),
+            sequence_lengths=sequence_lengths,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            batch_size=batch.current_batch_size,
+            is_prompts=batch.is_prompts,
+            use_cuda_graph=use_cuda_graph,
+            kv_seq_len=sequence_lengths.max().item(),
+            head_dim=batch.head_dim,
+        )
+
+        return input_ids, output_tensor, input_meta_data
+
     def step(self) -> List[str]:
         """
         In each step, do the follows:
         1. Run RequestHandler.schedule() and get the batch used for inference.
-        2. Run model to generate the next token
-        3. Update waiting list and running list in RequestHandler and get finished sequences.
-        4. Decode and return finished sequences.
+        2. Get the input token ids, input metadata, and the output placeholder from the batch bucket.
+        3. Run the model to generate the next token.
+        4. Update the waiting list and running list in RequestHandler and get finished sequences.
+        5. Decode and return finished sequences.
 
         Returns:
             List[str]: Decoded finished sequences generated by one step.

@@ -292,12 +408,15 @@
 
         batch = self.request_handler.schedule()
 
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        logits = self.model(
-            batch,
-            self.k_cahce,
-            self.v_cache,
-        )
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
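
A hedged end-to-end sketch of how the new flag is consumed: the engine captures decoding graphs at construction time and step() transparently replays them for captured batch sizes. The InferenceEngine constructor arguments, module path, and request API below are assumptions not shown in this diff; only use_cuda_graph and the capture/dispatch behavior come from the hunks above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# use_cuda_graph triggers capture_model() inside __init__ (see the hunk above).
config = InferenceConfig(max_batch_size=8, use_cuda_graph=True)
engine = InferenceEngine(model, tokenizer, config, verbose=True)  # assumed constructor signature

engine.add_request(prompts=["What is the capital of France?"])  # assumed request API
finished = engine.step()  # decoding steps whose batch size has a captured graph replay it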
@@ -0,0 +1,92 @@
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+from colossalai.inference.config import InputMetaData
+from colossalai.logging import get_dist_logger
+
+
+class CUDAGraphRunner:
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.graph = None
+        self.input_buffers: Dict[str, torch.Tensor] = {}
+        self.output_buffers: Dict[str, torch.Tensor] = {}
+        self.logger = get_dist_logger(__name__)
+
+    def capture(
+        self,
+        input_tokens_ids: torch.Tensor,
+        output_tensor: torch.Tensor,
+        inputmetadata: InputMetaData,
+        k_caches: List[torch.Tensor] = None,
+        v_caches: List[torch.Tensor] = None,
+        memory_pool=None,
+    ) -> None:
+        assert self.graph is None
+
+        # Run the model once first to compile and cache the kernels, avoiding stream-capture errors.
+        hidden_states = self.model(
+            input_tokens_ids,
+            output_tensor,
+            inputmetadata,
+            k_caches,
+            v_caches,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):
+            hidden_states = self.model(
+                input_tokens_ids,
+                output_tensor,
+                inputmetadata,
+                k_caches,
+                v_caches,
+            )
+        torch.cuda.synchronize()
+
+        # Save the input and output buffers, because replay always uses the same virtual memory space.
+        self.input_buffers = {
+            "input_tokens_ids": input_tokens_ids,
+            "output_tensor": output_tensor,
+            "block_tables": inputmetadata.block_tables,
+            "sequence_lengths": inputmetadata.sequence_lengths,
+            "k_caches": k_caches,
+            "v_caches": v_caches,
+        }
+        self.output_buffers = {"logits": hidden_states}
+        return
+
+    def forward(
+        self,
+        input_tokens_ids: torch.Tensor,
+        output_tensor: torch.Tensor,
+        inputmetadata: InputMetaData,
+        k_caches: List[torch.Tensor] = None,
+        v_caches: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Copy the input tensors to the input buffers.
+        self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
+        self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
+        self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
+        self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
+
+        # KV caches are fixed tensors, so we don't need to copy them.
+        # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
+        # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)
+
+        # Run the graph.
+        self.graph.replay()
+
+        # Return the output tensor.
+        return self.output_buffers["logits"]
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
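
The runner above is a thin wrapper over PyTorch's stock capture/replay mechanism: inputs live in fixed buffers, new data is copy_()-ed into those buffers, and graph.replay() reruns the recorded kernels against the same memory addresses. A self-contained sketch of that underlying pattern; the toy model and shapes are illustrative, and a CUDA device is required.

import torch

static_input = torch.zeros(8, 16, device="cuda")  # fixed address, reused on every replay
weight = torch.randn(16, 16, device="cuda")

def toy_model(x):
    return x @ weight

# Warm up once outside the graph so kernels are compiled before capture.
toy_model(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = toy_model(static_input)  # output buffer is also fixed

# Replay: copy new data into the same input buffer, then rerun the recorded kernels.
new_batch = torch.randn(8, 16, device="cuda")
static_input.copy_(new_batch)
graph.replay()
result = static_output.clone()  # read the result out of the fixed output buffer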
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaModel,
 )
 
-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (

@@ -36,10 +36,12 @@ except ImportError:
 
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaForCausalLM.
 
     Args:

@@ -51,7 +53,9 @@ def llama_causal_lm_forward(
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
-        batch=batch,
+        input_tokens_ids=input_tokens_ids,
+        output_tensor=output_tensor,
+        inputmetadata=inputmetadata,
         k_caches=k_caches,
         v_caches=v_caches,
     )

@@ -61,10 +65,12 @@ def llama_causal_lm_forward(
 
 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaModel.
 
     Args:

@@ -72,11 +78,10 @@ def llama_model_forward(
         k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
         v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
     """
-    input_ids = batch.get_1D_inputs()
-    block_tables = batch.get_block_table_tensor()
-    sequence_lengths = batch.get_sequence_lengths()
-    batch_size = batch.current_batch_size
-    kv_seq_len = sequence_lengths.max().item()
+    block_tables = inputmetadata.block_tables
+    sequence_lengths = inputmetadata.sequence_lengths
+    batch_size = inputmetadata.batch_size
+    kv_seq_len = inputmetadata.kv_seq_len
     use_cuda_kernel = True
     # NOTE: After testing, the performance of this configuration is relatively good. With updates
     # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's

@@ -84,21 +89,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
 
-    hidden_states = self.embed_tokens(input_ids)
+    hidden_states = self.embed_tokens(input_tokens_ids)
 
-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
 
-    if batch.is_prompts:
-        output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    else:
-        output_tensor = torch.zeros(
-            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    sm_scale = 1.0 / (batch.head_dim**0.5)
+    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
 
-    norm_output = torch.empty_like(hidden_states)
+    norm_output = None
     residual = None
 
     for layer_id, decoder_layer in enumerate(self.layers):

@@ -108,22 +105,22 @@
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
-            is_prompts=batch.is_prompts,
+            is_prompts=inputmetadata.is_prompts,
             sequence_lengths=sequence_lengths,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
-            fd_inter_tensor=batch.fd_inter_tensor,
+            fd_inter_tensor=inputmetadata.fd_inter_tensor,
             output_tensor=output_tensor,
             norm_output=norm_output,
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
         )
 
-    if batch.is_prompts:
+    if inputmetadata.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
-        norm_output = torch.empty_like(hidden_states)
+        norm_output = torch.empty_like(hidden_states)  # NOTE: still allocates, but the CUDA graph only captures decoding, so this prefill-only branch runs in eager mode
     hidden_states, _ = self.norm(hidden_states, norm_output, residual)
 
     return hidden_states
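
The prefill branch above gathers the last hidden state of each sequence from the packed (no-padding) layout. A tiny worked example of that indexing, with illustrative shapes:

import torch

sequence_lengths = torch.tensor([3, 2, 4])             # three prompts of lengths 3, 2, 4
hidden_states = torch.arange(9).unsqueeze(-1).float()  # 9 packed token positions, hidden size 1

last_token_indexs = sequence_lengths.cumsum(dim=-1)    # tensor([3, 5, 9])
last_hidden = hidden_states[last_token_indexs - 1]     # rows 2, 4, 8 -> last token of each prompt
print(last_hidden.squeeze(-1))                         # tensor([2., 4., 8.])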
@@ -1,5 +1,3 @@
-import torch
-
 try:
     import triton
     import triton.language as tl

@@ -94,7 +92,10 @@ if HAS_TRITON:
 
     def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
-        y = torch.empty_like(x) if norm_output is None else norm_output
+        # y = torch.empty_like(x) if norm_output is None else norm_output
+        y = (
+            x * 0 if norm_output is None else norm_output
+        )  # to make the operation non-functional, store y as the intermediate activation
         M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
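
The rms_layernorm change follows the commit's "non-functional" theme: the output buffer is supplied by the caller (norm_output) rather than allocated inside the op, which is friendlier to CUDA graph capture. A sketch of that out-buffer style in plain PyTorch; rms_norm_out is a hypothetical helper for illustration, not the Triton kernel being patched:

import torch

def rms_norm_out(x: torch.Tensor, weight: torch.Tensor, eps: float, out: torch.Tensor) -> torch.Tensor:
    # Write the result into a pre-allocated buffer instead of allocating a new tensor,
    # so the same memory address can be baked into a captured CUDA graph.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out.copy_(x * torch.rsqrt(variance + eps) * weight)
    return out

x = torch.randn(4, 16)
weight = torch.ones(16)
out = torch.empty_like(x)  # allocated once, outside any graph capture
y = rms_norm_out(x, weight, 1e-6, out)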