You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/inference/core/rpc_engine.py

298 lines
12 KiB

import asyncio
from itertools import count
from time import sleep
from typing import List, Tuple, Union
import rpyc
import torch
import torch.nn as nn
from rpyc.utils.server import ThreadedServer
from torch import multiprocessing as mp
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.executor.rpc_worker import rpcWorkerService
from colossalai.inference.utils import find_available_ports
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .engine import InferenceEngine
from .request_handler import RPCRequestHandler
# Public API of this module: only the RPC-based engine is exported.
__all__ = ["RPCInferenceEngine"]
def run_server(host, port, event: mp.Event = None):
    """
    Serve `rpcWorkerService` over rpyc on `host:port`; used as the target of a worker process.

    Args:
        host (str): Interface to bind the server to (e.g. "localhost").
        port (int): TCP port to listen on.
        event (mp.Event, optional): If given, it is set once the server object has been
            constructed, right before `server.start()` blocks serving requests.
    """
    # Bind explicitly to `host`. Without `hostname`, rpyc binds to the wildcard address
    # (all interfaces), which would expose a service configured with
    # allow_all_attrs to the network.
    server = ThreadedServer(
        rpcWorkerService,
        hostname=host,
        port=port,
        protocol_config={"allow_public_attrs": True, "allow_all_attrs": True},
    )
    if event:
        event.set()
    server.start()
class RPCInferenceEngine(InferenceEngine):
    """
    InferenceEngine which manages the inference process through a set of RPC worker processes.

    NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving.
    Original `InferenceEngine` is designed for single card and offline service, though it supports multi-card offline inference.

    On construction the engine spawns `inference_config.tp_size` RPC server processes on
    localhost, connects to each of them, initializes the distributed environment, the model and
    the KV cache on every worker, and then drives generation from this process via `step()`.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model. Currently only a
            transformers-style path (str) is supported; an `nn.Module` is rejected.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
        inference_config (InferenceConfig): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy (Policy, optional): the policy to shardformer model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        """
        Spawn the RPC workers, load the model on them and set up scheduling and the KV cache.

        If you input a real model loaded by transformers, the init will take quite a long time.
        Currently we don't support model(nn.Module) format as the param.

        Raises:
            TypeError: If `inference_config` or `tokenizer` has an unsupported type, or if
                `model_or_path` is neither a str nor an nn.Module.
            NotImplementedError: If `model_or_path` is an nn.Module.
            Exception: Any error raised while loading the model config is logged and re-raised.
        """
        torch.multiprocessing.set_start_method("spawn", force=True)

        self.inference_config = inference_config
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.logger = get_dist_logger(__name__)

        # Validate the cheap-to-check arguments before spawning any worker process.
        self._verify_args()
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # `dtype` must be set before it is used to load the model config below.
        self.dtype = inference_config.dtype
        self.model_config = self._load_model_config(model_or_path)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        # Plain-dict form of the generation config; it is sent to the workers on every step.
        self.generation_config_dict = self.generation_config.to_dict()
        self.tp_size = inference_config.tp_size
        self.events = [mp.Event() for _ in range(self.tp_size)]

        # This operation will init the dist env and models
        self.workers: List[rpcWorkerService] = []
        self.init_workers()
        asyncio.run(self.init_model(model_or_path, model_policy))

        # init the scheduler and logic block manager
        self.request_handler = self.init_scheduler(self.inference_config, self.model_config)

        # init the physical cache
        alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape()
        self.init_device_cache(alloc_shape)

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        self.high_precision = inference_config.high_precision

        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self.counter = count()
        self.logger.info("engine init over ")

    def _load_model_config(self, model_or_path: Union[nn.Module, str]) -> PretrainedConfig:
        """
        Load the transformers config for `model_or_path`.

        Only a transformers-like path (str) is supported. Failures are logged and raised here
        instead of being swallowed, which would otherwise surface later as an unrelated
        AttributeError on `self.model_config`.
        """
        cls_name = type(self).__name__
        if isinstance(model_or_path, nn.Module):
            self.logger.error(
                f"An exception occurred during loading model Config: For {cls_name}, we don't support param like nn.Module currently\n"
            )
            raise NotImplementedError(f"{cls_name} does not support an nn.Module as `model_or_path` yet.")
        if not isinstance(model_or_path, str):
            self.logger.error(
                f"An exception occurred during loading model Config: Please pass right param for {cls_name}\n"
            )
            raise TypeError(f"`model_or_path` must be a str path, but got {type(model_or_path)}")
        try:
            return AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
        except Exception as e:
            self.logger.error(
                f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n"
            )
            raise

    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )

    def init_workers(self):
        """
        Spawn one RPC server process per TP rank, connect to each, then init the dist env.

        Fills `self.worker_processes` with the spawned processes and `self.workers` with the
        root objects of the rpyc connections, in port order.
        """
        rpc_ports = find_available_ports(self.tp_size)
        self.worker_processes = []
        for event, rpc_port in zip(self.events, rpc_ports):
            p = mp.Process(target=run_server, args=("localhost", rpc_port, event))
            p.start()
            self.worker_processes.append(p)
            self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...")

        # Wait for all servers to start
        for event in self.events:
            event.wait()
            event.clear()
        sleep(0.05)
        self.logger.info(f"init rpc server done.")

        for rpc_port in rpc_ports:
            conn = self._connect_worker(rpc_port)
            self.workers.append(conn.root)
        self.logger.info(f"Build RPC Connection Success! Begin to load model...")

        asyncio.run(self.init_worker_env())
        self.logger.info(f"init dist env over")

    def _connect_worker(self, rpc_port: int, max_retries: int = 20, retry_interval: float = 0.05):
        """
        Open an rpyc connection to the worker listening on localhost:`rpc_port`.

        `run_server` sets its event before calling `server.start()`, so the server may not yet
        be accepting connections on the first attempt; refused connections are retried a few
        times. Any other failure (or exhausting the retries) raises `Exception("conn error!")`
        chained to the underlying error.
        """
        last_err = None
        for _ in range(max_retries):
            try:
                return rpyc.connect(
                    "localhost",
                    rpc_port,
                    config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True},
                )
            except ConnectionRefusedError as e:
                last_err = e
                sleep(retry_interval)
            except Exception as e:
                raise Exception("conn error!") from e
        raise Exception("conn error!") from last_err

    async def async_parallel_wrapper(self, f, *args, **kwargs):
        """
        Invoke the remote callable `f` asynchronously and await its result.

        The blocking wait on the rpyc async result runs in a thread (`asyncio.to_thread`), so
        several workers can be awaited concurrently with `asyncio.gather`. Per rpyc's
        `AsyncResult.value`, a remote exception is presumably re-raised here.
        """
        async_res = rpyc.async_(f)(*args, **kwargs)
        await asyncio.to_thread(async_res.wait)
        assert async_res.ready
        return async_res.value

    async def init_worker_env(self):
        """Initialize the distributed environment on every worker (rank = connection index)."""
        assert len(self.workers) == self.tp_size, "init workers first"

        dist_group_port = find_available_ports(1)[0]
        init_tasks = [
            self.async_parallel_wrapper(
                worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port
            )
            for rank, worker in enumerate(self.workers)
        ]

        await asyncio.gather(*init_tasks)

    async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
        """Ask every worker to load the model, passing RPC-serializable config and policy params."""
        assert len(self.workers) == self.tp_size, "init workers first"

        inference_config_param = self.inference_config.to_rpc_param()
        model_path = model_or_path
        model_policy_param = model_policy.to_rpc_param() if model_policy else None

        init_tasks = [
            self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param)
            for worker in self.workers
        ]

        await asyncio.gather(*init_tasks)

    def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
        """Build the request handler (scheduler and logical block manager)."""
        return RPCRequestHandler(inference_config, model_config)

    async def _init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        """Ask every worker to allocate its physical KV cache with `alloc_shape`."""
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers]

        await asyncio.gather(*init_tasks)

    def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]):
        """Synchronous wrapper around `_init_device_cache`."""
        asyncio.run(self._init_device_cache(alloc_shape))

    def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]:
        """
        Turn a scheduled batch into the (input token ids, metadata) pair sent to the workers.

        Returns:
            The flattened input ids as a Python list, and the `InputMetaData` for this step.
        """
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if not batch.is_prompts and batch.use_spec_dec:
            # Decoding with speculation feeds `num_tokens_to_verify` drafted tokens plus one.
            assert batch.num_tokens_to_verify + 1 == input_ids.size(0)

        # process repetition_penalty, no_repeat_ngram_size: these need each sequence's token history
        batch_token_ids = None
        config_dict = self.generation_config_dict
        if any(config_dict.get(key) is not None for key in ("repetition_penalty", "no_repeat_ngram_size")):
            batch_token_ids = batch.batch_token_ids

        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference.
        # This engine does not capture graphs itself, so `graph_runners` may be absent.
        graph_runners = getattr(self, "graph_runners", None) or {}
        use_cuda_graph = bool(
            self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in graph_runners
        )

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=None,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids.tolist(), input_meta_data

    async def step_(self, input_token_ids, input_meta_data: InputMetaData):
        """Run one forward step on all workers concurrently and return rank 0's result."""
        assert len(self.workers) == self.tp_size, "init workers first"

        init_tasks = [
            self.async_parallel_wrapper(
                worker.execute_model_forward,
                input_token_ids,
                input_meta_data.to_rpc_param(),
                self.generation_config_dict,
            )
            for worker in self.workers
        ]
        ret = await asyncio.gather(*init_tasks)

        return ret[0]

    def step(self) -> List[str]:
        """
        Schedule a batch, run it on the workers and feed the next tokens back to the scheduler.

        Returns:
            Whatever `request_handler.update()` reports as finished sequences.
        """
        batch = self.request_handler.schedule()

        input_token_ids, input_meta_data = self.prepare_input(batch)
        next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data))

        # update the request_handler
        next_tokens = torch.tensor(next_tokens, dtype=torch.int)
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()
        return finished_sequences

    def kill_workers(self):
        """
        Kill and reap every spawned RPC worker process.

        Safe to call more than once and on a partially initialized engine: it only touches
        processes that were actually started and forgets them afterwards. It is only invoked
        implicitly from `__del__`, so callers should call it explicitly when serving ends.
        """
        for proc in getattr(self, "worker_processes", []):
            proc.kill()
            proc.join()
        self.worker_processes = []
        self.logger.info(f"worker killed, serving end")

    def __del__(self):
        # __init__ may have failed before any worker was spawned, or kill_workers() may
        # already have been called; only clean up processes that are still tracked.
        if getattr(self, "worker_processes", None):
            self.kill_workers()