from typing import Dict, List

import torch
from torch import nn

from colossalai.inference.config import InputMetaData
from colossalai.logging import get_dist_logger

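# CUDAGraphRunner records one decode-step forward pass of `model` into a
# torch.cuda.CUDAGraph (`capture`) and afterwards re-executes that recording
# with `graph.replay()` (`forward`/`__call__`), copying each new batch of
# inputs into the fixed buffers that were live during capture.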
class CUDAGraphRunner:
    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self.logger = get_dist_logger(__name__)

    def capture(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
        memory_pool=None,
    ) -> None:
        assert self.graph is None

        # Run the model once before capture to warm up and cache the kernels;
        # otherwise lazy kernel initialization during stream capture can raise errors.
        hidden_states_origin_model = self.model(
            input_tokens_ids,
            output_tensor,
            inputmetadata,
            k_caches,
            v_caches,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        # self.logger.info(f"begin capture model...")
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states_cuda_graph = self.model(
                input_tokens_ids,
                output_tensor,
                inputmetadata,
                k_caches,
                v_caches,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers: replay always reads from and writes to
        # the same virtual memory addresses that were live during capture.
        self.input_buffers = {
            "input_tokens_ids": input_tokens_ids,
            "output_tensor": output_tensor,
            "block_tables": inputmetadata.block_tables,
            "sequence_lengths": inputmetadata.sequence_lengths,
            # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
            # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
            "k_caches": k_caches,
            "v_caches": v_caches,
        }
        self.output_buffers = {"logits": hidden_states_cuda_graph}
        return

    def forward(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Copy the input tensors into the captured input buffers.
        self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
        self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)

        # The block table may be smaller than the captured buffer, so reset the buffer
        # and copy the new table into its top-left corner.
        self.input_buffers["block_tables"].fill_(-1)
        M, N = inputmetadata.block_tables.shape
        self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)

        self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)

        # There is only one global fd_inter_tensor, so it does not need to be copied.
        # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
        # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)

        # KV caches are fixed tensors, so they do not need to be copied either.
        # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
        # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)

        # Replay the captured graph on the new inputs.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["logits"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
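
# Usage sketch (illustrative only, not part of the original file). The names
# `model`, `input_ids`, `out_buf`, `meta`, `k_caches`, and `v_caches` below are
# hypothetical placeholders for objects normally prepared by ColossalAI's
# inference engine; an optional shared memory pool could be obtained with
# torch.cuda.graph_pool_handle() if several graphs should share one pool.
#
#   runner = CUDAGraphRunner(model)
#   runner.capture(input_ids, out_buf, meta, k_caches, v_caches, memory_pool=None)
#   # Later decoding steps with the same buffer shapes just replay the graph:
#   logits = runner(input_ids, out_buf, meta, k_caches, v_caches)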