mirror of https://github.com/hpcaitech/ColossalAI
commit cefaeb5fdd (parent 593a72e4d5)

from typing import Dict, List, Optional

import torch
from torch import nn

from colossalai.inference.config import InputMetaData
from colossalai.logging import get_dist_logger


class CUDAGraphRunner:
    """Captures a model's decode step into a CUDA graph and replays it on later calls."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self.logger = get_dist_logger(__name__)

    def capture(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
        memory_pool=None,
    ) -> None:
        assert self.graph is None, "the CUDA graph has already been captured for this runner"

        # Run the model once to warm up and cache its kernels; launching an uncached
        # kernel during stream capture would abort the capture.
        hidden_states = self.model(
            input_tokens_ids,
            output_tensor,
            inputmetadata,
            k_caches,
            v_caches,
        )
        torch.cuda.synchronize()

        # Capture the graph. An optional memory_pool (e.g. from torch.cuda.graph_pool_handle())
        # lets several graphs share one capture pool.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                input_tokens_ids,
                output_tensor,
                inputmetadata,
                k_caches,
                v_caches,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers: replay always reads from and writes to
        # the same virtual memory addresses that were used during capture.
        self.input_buffers = {
            "input_tokens_ids": input_tokens_ids,
            "output_tensor": output_tensor,
            "block_tables": inputmetadata.block_tables,
            "sequence_lengths": inputmetadata.sequence_lengths,
            "k_caches": k_caches,
            "v_caches": v_caches,
        }
        self.output_buffers = {"logits": hidden_states}

    def forward(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: InputMetaData,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Copy the new inputs into the captured input buffers; the graph only ever
        # sees these fixed tensors.
        self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
        self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
        self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)

        # The KV caches are fixed, preallocated tensors that the model updates in place,
        # so they do not need to be copied here.

        # Replay the captured graph on the current inputs.
        self.graph.replay()

        # Return the output buffer that the graph writes its logits into.
        return self.output_buffers["logits"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
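
For context, here is a minimal usage sketch of the runner above. The `model`, `input_ids`, `output_tensor`, `meta` (an `InputMetaData`), `k_caches`, and `v_caches` objects, and likewise `new_input_ids` and `new_meta` for a later step, are assumed to be built elsewhere by the inference engine; they are placeholders, not part of this file.

    runner = CUDAGraphRunner(model)

    # One-time capture with representative decode-step inputs. The tensors passed here
    # become the static buffers that every later replay reads from and writes into.
    runner.capture(input_ids, output_tensor, meta, k_caches=k_caches, v_caches=v_caches)

    # Later decode steps: forward() copies the fresh inputs into the captured buffers
    # and replays the graph, so shapes must match those used at capture time.
    logits = runner(new_input_ids, output_tensor, new_meta)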