import asyncio from itertools import count from time import sleep from typing import List, Tuple, Union import rpyc import torch import torch.nn as nn from rpyc.utils.server import ThreadedServer from torch import multiprocessing as mp from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.configuration_utils import PretrainedConfig from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.executor.rpc_worker import rpcWorkerService from colossalai.inference.utils import find_available_ports from colossalai.logging import get_dist_logger from colossalai.shardformer.policies.base_policy import Policy from .engine import InferenceEngine from .request_handler import RPCRequestHandler __all__ = ["RPCInferenceEngine"] def run_server(host, port, event: mp.Event = None): server = ThreadedServer( rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True} ) if event: event.set() server.start() class RPCInferenceEngine(InferenceEngine): """ InferenceEngine which manages the inference process.. NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving. Original `InferenceEngine` is designed for single card and offline service, though it supports multi-card offline inference. Args: model_or_path (nn.Module or str): Path or nn.Module of this model, Currently we don't support `nn.Module` Format tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. """ def __init__( self, model_or_path: Union[nn.Module, str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], inference_config: InferenceConfig, verbose: bool = False, model_policy: Policy = None, ) -> None: """ If you input a real model loaded by transformers, the init will take quite a long time Currently we don't support model(nn.Module) format as the param. """ torch.multiprocessing.set_start_method("spawn", force=True) self.inference_config = inference_config self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token self.verbose = verbose self.logger = get_dist_logger(__name__) try: if isinstance(model_or_path, str): self.model_config = AutoConfig.from_pretrained( model_or_path, trust_remote_code=True, torch_dtype=self.dtype ) elif isinstance(model_or_path, nn.Module): self.logger.error( f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n" ) # self.model_config = model_or_path.config else: self.logger.error( f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n" ) except Exception as e: self.logger.error( f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n" ) self.generation_config = inference_config.to_generation_config(self.model_config) self.tp_size = inference_config.tp_size self.events = [mp.Event() for _ in range(self.tp_size)] # This operation will init the dist env and models self.workers: List[rpcWorkerService] = [] self.init_workers() asyncio.run(self.init_model(model_or_path, model_policy)) # init the scheduler and logic block manager self.request_handler = self.init_scheduler(self.inference_config, self.model_config) # init the physical cache alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape() self.init_device_cache(alloc_shape) self.use_cuda_graph = self.inference_config.use_cuda_graph self.high_precision = inference_config.high_precision self.dtype = inference_config.dtype # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` self.use_spec_dec = False self.drafter_model = None self.drafter = None self.use_glide = False self.n_spec_tokens = self.inference_config.max_n_spec_tokens self.counter = count() self._verify_args() self.logger.info("engine init over ") def _verify_args(self) -> None: """Verify the input args""" if not isinstance(self.inference_config, InferenceConfig): raise TypeError("Invalid type of inference config provided.") if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): raise TypeError( f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" ) def init_workers(self): rpc_ports = find_available_ports(self.tp_size) self.worker_processes = [] # mp.set_start_method('spawn') for event, rpc_port in zip(self.events, rpc_ports): p = mp.Process(target=run_server, args=("localhost", rpc_port, event)) p.start() self.worker_processes.append(p) self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...") # Wait for all servers to start for event in self.events: event.wait() event.clear() sleep(0.05) self.logger.info(f"init rpc server done.") for rpc_port in rpc_ports: try: conn = rpyc.connect( "localhost", rpc_port, config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True}, ) self.workers.append(conn.root) except: raise Exception("conn error!") self.logger.info(f"Build RPC Connection Success! Begin to load model...") asyncio.run(self.init_worker_env()) self.logger.info(f"init dist env over") async def async_parallel_wrapper(self, f, *args, **kwargs): async_res = rpyc.async_(f)(*args, **kwargs) await asyncio.to_thread(async_res.wait) assert async_res.ready return async_res.value async def init_worker_env(self): assert len(self.workers) == self.tp_size, "init workers first" dist_group_port = find_available_ports(1)[0] init_tasks = [ self.async_parallel_wrapper( worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port ) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*init_tasks) async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): assert len(self.workers) == self.tp_size, "init workers first" inference_config_param = self.inference_config.to_rpc_param() model_path = model_or_path model_policy_param = model_policy.to_rpc_param() if model_policy else None init_tasks = [ self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*init_tasks) def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler: return RPCRequestHandler(inference_config, model_config) async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]): assert len(self.workers) == self.tp_size, "init workers first" init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers] await asyncio.gather(*init_tasks) def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): asyncio.run(self._init_device_cache(alloc_shape)) def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() if batch.is_prompts: n_tokens = sequence_lengths.sum().item() else: n_tokens = batch.current_batch_size if batch.use_spec_dec: n_tokens = batch.num_tokens_to_verify + 1 assert n_tokens == input_ids.size(0) n_tokens = n_tokens * batch.current_batch_size batch_token_ids = None config_dict = self.generation_config.to_dict() # process repetition_penalty, no_repeat_ngram_size for type in ["repetition_penalty", "no_repeat_ngram_size"]: if type in config_dict and config_dict[type] is not None: batch_token_ids = batch.batch_token_ids # only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph = False if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): use_cuda_graph = True input_meta_data = InputMetaData( block_tables=batch.get_block_table_tensor(), sequence_lengths=sequence_lengths, fd_inter_tensor=None, batch_size=batch.current_batch_size, is_prompts=batch.is_prompts, use_cuda_kernel=self.inference_config.use_cuda_kernel, use_cuda_graph=use_cuda_graph, high_precision=self.high_precision, kv_seq_len=sequence_lengths.max().item(), head_dim=batch.head_dim, dtype=batch.dtype, use_spec_dec=batch.use_spec_dec, num_tokens_to_verify=batch.num_tokens_to_verify, batch_token_ids=batch_token_ids, ) return input_ids.tolist(), input_meta_data async def step_(self, input_token_ids, input_meta_data: InputMetaData): assert len(self.workers) == self.tp_size, "init workers first" init_tasks = [ self.async_parallel_wrapper( worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param(), self.generation_config_dict, ) for worker in self.workers ] ret = await asyncio.gather(*init_tasks) return ret[0] def step(self) -> List[str]: batch = self.request_handler.schedule() input_token_ids, input_meta_data = self.prepare_input(batch) # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) # update the request_handler next_tokens = torch.tensor(next_tokens, dtype=torch.int) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() return finished_sequences def kill_workers(self): """ I don't find a good way to implicit invoke self.kill_workers """ assert len(self.workers) != 0 for proc in self.worker_processes: proc.kill() proc.join() self.logger.info(f"worker killed, serving end") def __del__(self): self.kill_workers()