From 18d67d0e8e79c22bded0745c7d3daf8ca40d445c Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Tue, 14 May 2024 10:00:55 +0800 Subject: [PATCH] [Feat]Inference RPC Server Support (#5705) * rpc support source * kv cache logical/physical disaggregation * sampler refactor * colossalai launch built in * Unitest * Rpyc support --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/inference/config.py | 115 ++++++- colossalai/inference/core/engine.py | 17 +- colossalai/inference/core/request_handler.py | 95 +++--- colossalai/inference/core/rpc_engine.py | 291 +++++++++++++++++ colossalai/inference/executor/rpc_worker.py | 300 ++++++++++++++++++ colossalai/inference/kv_cache/__init__.py | 4 +- .../inference/kv_cache/kvcache_manager.py | 77 +++++ colossalai/inference/logit_processors.py | 9 +- .../modeling/policy/nopadding_baichuan.py | 10 +- .../modeling/policy/nopadding_llama.py | 10 +- colossalai/inference/sampler.py | 49 ++- colossalai/inference/utils.py | 11 + requirements/requirements-test.txt | 1 + requirements/requirements.txt | 1 + tests/test_infer/test_rpc_engine.py | 105 ++++++ 15 files changed, 1032 insertions(+), 63 deletions(-) create mode 100644 colossalai/inference/core/rpc_engine.py create mode 100644 colossalai/inference/executor/rpc_worker.py create mode 100644 tests/test_infer/test_rpc_engine.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 8bd2394ad..70faf34e3 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -2,11 +2,11 @@ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -import torch.distributed as dist from transformers.generation import GenerationConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors @@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = { } +class RPC_PARAM(ABC): + """ + NOTE(lry89757) We use rpyc to transport param between client and server. + Rpyc only support the type of `POD` in python as the param, so we should take some smart ways to transport the data like tensor or some sophisticated classes. + Drawing on the logic of `__setstate__`, `__getstate__`, we will let some classes(will be rpc param later) inherit this base class, and rewrite the to_rpc_param and from_rpc_param. We will invoke `to_rpc_param` in client to pass the params and recover the param in server side by `from_rpc_param`. + """ + + @abstractmethod + def to_rpc_param(self): + return NotImplementedError + + @staticmethod + @abstractmethod + def from_rpc_param(): + return NotImplementedError + + @dataclass -class InputMetaData: +class InputMetaData(RPC_PARAM): """The input info for a single step Args: @@ -48,6 +65,7 @@ class InputMetaData: dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. use_spec_dec (bool): Indicate whether to use speculative decoding. num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. + batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process. 
""" block_tables: torch.Tensor = None @@ -63,6 +81,54 @@ class InputMetaData: dtype: torch.dtype = torch.float32 use_spec_dec: bool = False num_tokens_to_verify: int = 0 + batch_token_ids: Optional[ + List[List[int]] + ] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + + def to_rpc_param(self) -> Dict[str, any]: + return { + "block_tables": self.block_tables.tolist(), + "sequence_lengths": self.sequence_lengths.tolist(), + "batch_size": self.batch_size, + "is_prompts": self.is_prompts, + "use_cuda_kernel": self.use_cuda_kernel, + "use_cuda_graph": self.use_cuda_graph, + "kv_seq_len": self.kv_seq_len, + "head_dim": self.head_dim, + "high_precision": self.high_precision, + "dtype": str(self.dtype).split(".")[-1], + "use_spec_dec": self.use_spec_dec, + "num_tokens_to_verify": self.num_tokens_to_verify, + "batch_token_ids": self.batch_token_ids, + } + + @staticmethod + def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": + """ + We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message + """ + from colossalai.accelerator import get_accelerator + + dtype = getattr(torch, rpc_dict["dtype"]) + return InputMetaData( + block_tables=torch.tensor( + rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() + ), + sequence_lengths=torch.tensor( + rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() + ), + batch_size=rpc_dict["batch_size"], + is_prompts=rpc_dict["is_prompts"], + use_cuda_kernel=rpc_dict["use_cuda_kernel"], + use_cuda_graph=rpc_dict["use_cuda_graph"], + kv_seq_len=rpc_dict["kv_seq_len"], + head_dim=rpc_dict["head_dim"], + high_precision=rpc_dict["high_precision"], + dtype=dtype, + use_spec_dec=rpc_dict["use_spec_dec"], + num_tokens_to_verify=rpc_dict["num_tokens_to_verify"], + batch_token_ids=rpc_dict["batch_token_ids"], + ) def __repr__(self) -> str: return ( @@ -80,7 +146,7 @@ class InputMetaData: @dataclass -class InferenceConfig: +class InferenceConfig(RPC_PARAM): """The inference configuration. 
Args: @@ -193,10 +259,6 @@ class InferenceConfig: if self.dtype == torch.float32: self.high_precision = False - # check distributed - assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or ( - self.tp_size * self.pp_size == dist.get_world_size() - ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})" # check prompt template if self.prompt_template is None: return @@ -226,6 +288,43 @@ class InferenceConfig: return GenerationConfig.from_dict(meta_config) + def to_rpc_param(self) -> dict: + kwargs = { + "dtype": str(self.dtype).split(".")[-1], + "max_n_spec_tokens": self.max_n_spec_tokens, + "max_batch_size": self.max_batch_size, + "max_input_len": self.max_input_len, + "max_output_len": self.max_output_len, + "tp_size": self.tp_size, + "pp_size": self.pp_size, + "pad_input": self.pad_input, + "early_stopping": self.early_stopping, + "do_sample": self.do_sample, + "beam_width": self.beam_width, + "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1], + } + return kwargs + + @staticmethod + def from_rpc_param(rpc_dict: dict) -> "InferenceConfig": + """ + We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message + """ + return InferenceConfig( + dtype=getattr(torch, rpc_dict["dtype"]), + max_n_spec_tokens=rpc_dict["max_n_spec_tokens"], + max_batch_size=rpc_dict["max_batch_size"], + max_input_len=rpc_dict["max_input_len"], + max_output_len=rpc_dict["max_output_len"], + tp_size=rpc_dict["tp_size"], + pp_size=rpc_dict["pp_size"], + pad_input=rpc_dict["pad_input"], + early_stopping=rpc_dict["early_stopping"], + do_sample=rpc_dict["do_sample"], + beam_width=rpc_dict["beam_width"], + kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None), + ) + @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": # Get the list of attributes of this dataclass. diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 44f2c8f47..7b456b8be 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.sampler import search_tokens from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.struct import Sequence from colossalai.inference.utils import get_model_size, has_index_file @@ -424,7 +425,7 @@ class InferenceEngine: # 2. 
Prefill main model (Verifier) - fill past kv cache for main model logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) # append new inputs to the batch, temporarily batch.append_batch_tokens(next_tokens) self.request_handler.allocate_batch_spec_dec(batch, 1) @@ -472,7 +473,7 @@ class InferenceEngine: input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) # 5. Compare and process the results diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) @@ -689,6 +690,13 @@ class InferenceEngine: (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device ) + batch_token_ids = None + config_dict = self.generation_config.to_dict() + # process repetition_penalty, no_repeat_ngram_size + for type in ["repetition_penalty", "no_repeat_ngram_size"]: + if type in config_dict and config_dict[type] is not None: + batch_token_ids = batch.batch_token_ids + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph = False if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): @@ -708,6 +716,7 @@ class InferenceEngine: dtype=batch.dtype, use_spec_dec=batch.use_spec_dec, num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, ) return input_ids, output_tensor, input_meta_data @@ -738,7 +747,9 @@ class InferenceEngine: logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) if self.inference_config.pad_input: logits = logits[:, -1, :] - next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) + next_tokens = search_tokens( + self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids + ) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index c514eeccf..5085c5555 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.kv_cache import KVCacheManager -from colossalai.inference.logit_processors import logit_processor -from colossalai.inference.sampler import * +from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager from colossalai.inference.struct import RequestStatus, Sequence +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) __all__ = ["RunningList", "RequestHandler"] @@ -295,17 +296,6 @@ class RequestHandler: return None - def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig): - 
if generation_config.num_beams == 1: - if generation_config.do_sample: - sample_tokens = multinomial_sample(generation_config, probs) - else: - sample_tokens = greedy_sample(generation_config, logprobs) - else: - sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty) - - return sample_tokens - def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig): if ( sequence.output_token_id[-1] == generation_config.eos_token_id @@ -328,33 +318,6 @@ class RequestHandler: def total_requests_in_batch_bucket(self) -> int: return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size - def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket): - """ - Sample tokens for finished requests. - """ - - # NOTE: need to decide the granularity to process logits (sequence or batch) - config_dict = generation_config.to_dict() - # process repetition_penalty, no_repeat_ngram_size - for type in ["repetition_penalty", "no_repeat_ngram_size"]: - if type in config_dict and config_dict[type] is not None: - logits = logit_processor(type, logits, config_dict[type], cur_batch) - - # do logit processor - if generation_config.do_sample: - # process temperature, top_k, top_p - for type in ["temperature", "top_k", "top_p"]: - if type in config_dict and config_dict[type] is not None: - logits = logit_processor(type, logits, config_dict[type]) - - # calculate probs - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # sample the next tokens - sample_tokens = self._sample(probs, logprobs, generation_config) - return sample_tokens - def append_next_tokens(self, sample_tokens: torch.Tensor): assert sample_tokens.dim() == 1 n_elements = sample_tokens.size(0) @@ -386,3 +349,53 @@ class RequestHandler: self.done_list.extend(finished_seqs) return finished_seqs + + +class RPCRequestHandler(RequestHandler): + """ + RPC Version of request handler + """ + + def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None: + self.inference_config = inference_config + self.running_list: RunningList = RunningList(inference_config.prefill_ratio) + self.waiting_list: List[List] = [[], [], []] + self.done_list: List[Sequence] = [] + self.dtype = inference_config.dtype + self.max_batch_size = inference_config.max_batch_size + + # initialize cache + self._init_cache(model_config) + + # initialize batch + torch.cuda.current_device() + kv_max_split_num = ( + inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 + ) // inference_config.block_size + head_dim = model_config.hidden_size // model_config.num_attention_heads + + # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, + # which may cause bugs and this issue should be fixed later. 
+ self.running_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=None, + dtype=self.dtype, + ) + self.prefill_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=None, + dtype=self.dtype, + ) + + def _init_cache(self, model_config): + self.cache_manager = RPCKVCacheManager(self.inference_config, model_config) diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py new file mode 100644 index 000000000..9602147f5 --- /dev/null +++ b/colossalai/inference/core/rpc_engine.py @@ -0,0 +1,291 @@ +import asyncio +from itertools import count +from time import sleep +from typing import List, Tuple, Union + +import rpyc +import torch +import torch.nn as nn +from rpyc.utils.server import ThreadedServer +from torch import multiprocessing as mp +from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.configuration_utils import PretrainedConfig + +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.executor.rpc_worker import rpcWorkerService +from colossalai.inference.utils import find_available_ports +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .engine import InferenceEngine +from .request_handler import RPCRequestHandler + +__all__ = ["RPCInferenceEngine"] + + +def run_server(host, port, event: mp.Event = None): + server = ThreadedServer( + rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True} + ) + if event: + event.set() + server.start() + + +class RPCInferenceEngine(InferenceEngine): + + """ + InferenceEngine which manages the inference process.. + + NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving. + Original `InferenceEngine` is designed for single card and offline service, though it supports multi-card offline inference. + + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model, Currently we don't support `nn.Module` Format + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. + """ + + def __init__( + self, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + inference_config: InferenceConfig, + verbose: bool = False, + model_policy: Policy = None, + ) -> None: + """ + If you input a real model loaded by transformers, the init will take quite a long time + Currently we don't support model(nn.Module) format as the param. 
+ """ + + torch.multiprocessing.set_start_method("spawn", force=True) + + self.inference_config = inference_config + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + + try: + if isinstance(model_or_path, str): + self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + elif isinstance(model_or_path, nn.Module): + self.logger.error( + f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n" + ) + # self.model_config = model_or_path.config + else: + self.logger.error( + f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n" + ) + except Exception as e: + self.logger.error( + f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n" + ) + self.generation_config = inference_config.to_generation_config(self.model_config) + + self.tp_size = inference_config.tp_size + self.events = [mp.Event() for _ in range(self.tp_size)] + + # This operation will init the dist env and models + self.workers: List[rpcWorkerService] = [] + self.init_workers() + + asyncio.run(self.init_model(model_or_path, model_policy)) + + # init the scheduler and logic block manager + self.request_handler = self.init_scheduler(self.inference_config, self.model_config) + + # init the physical cache + alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape() + self.init_device_cache(alloc_shape) + + self.use_cuda_graph = self.inference_config.use_cuda_graph + self.high_precision = inference_config.high_precision + self.dtype = inference_config.dtype + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = False + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self.counter = count() + self._verify_args() + + self.logger.info("engine init over ") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + + def init_workers(self): + rpc_ports = find_available_ports(self.tp_size) + self.worker_processes = [] + # mp.set_start_method('spawn') + for event, rpc_port in zip(self.events, rpc_ports): + p = mp.Process(target=run_server, args=("localhost", rpc_port, event)) + p.start() + self.worker_processes.append(p) + self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...") + + # Wait for all servers to start + for event in self.events: + event.wait() + event.clear() + + sleep(0.05) + + self.logger.info(f"init rpc server done.") + + for rpc_port in rpc_ports: + try: + conn = rpyc.connect( + "localhost", + rpc_port, + config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True}, + ) + self.workers.append(conn.root) + except: + raise Exception("conn error!") + self.logger.info(f"Build RPC Connection Success! 
Begin to load model...") + asyncio.run(self.init_worker_env()) + self.logger.info(f"init dist env over") + + async def async_parallel_wrapper(self, f, *args, **kwargs): + async_res = rpyc.async_(f)(*args, **kwargs) + await asyncio.to_thread(async_res.wait) + assert async_res.ready + return async_res.value + + async def init_worker_env(self): + assert len(self.workers) == self.tp_size, "init workers first" + + dist_group_port = find_available_ports(1)[0] + init_tasks = [ + self.async_parallel_wrapper( + worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port + ) + for rank, worker in enumerate(self.workers) + ] + + await asyncio.gather(*init_tasks) + + async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + assert len(self.workers) == self.tp_size, "init workers first" + + inference_config_param = self.inference_config.to_rpc_param() + model_path = model_or_path + model_policy_param = model_policy.to_rpc_param() if model_policy else None + + init_tasks = [ + self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param) + for rank, worker in enumerate(self.workers) + ] + + await asyncio.gather(*init_tasks) + + def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler: + return RPCRequestHandler(inference_config, model_config) + + async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]): + assert len(self.workers) == self.tp_size, "init workers first" + + init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers] + + await asyncio.gather(*init_tasks) + + def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): + asyncio.run(self._init_device_cache(alloc_shape)) + + def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + + batch_token_ids = None + config_dict = self.generation_config.to_dict() + # process repetition_penalty, no_repeat_ngram_size + for type in ["repetition_penalty", "no_repeat_ngram_size"]: + if type in config_dict and config_dict[type] is not None: + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=None, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids.tolist(), input_meta_data + + async def step_(self, input_token_ids, input_meta_data: InputMetaData): + assert len(self.workers) == 
self.tp_size, "init workers first" + + init_tasks = [ + self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param()) + for worker in self.workers + ] + ret = await asyncio.gather(*init_tasks) + + return ret[0] + + def step(self) -> List[str]: + batch = self.request_handler.schedule() + + input_token_ids, input_meta_data = self.prepare_input(batch) + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) + + # update the request_handler + next_tokens = torch.tensor(next_tokens, dtype=torch.int) + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() + return finished_sequences + + def kill_workers(self): + """ + I don't find a good way to implicit invoke self.kill_workers + """ + assert len(self.workers) != 0 + for proc in self.worker_processes: + proc.kill() + proc.join() + self.logger.info(f"worker killed, serving end") + + def __del__(self): + self.kill_workers() diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py new file mode 100644 index 000000000..4b84dcc85 --- /dev/null +++ b/colossalai/inference/executor/rpc_worker.py @@ -0,0 +1,300 @@ +import os +from typing import List, Tuple, Union + +import rpyc +import torch +import torch.distributed as dist +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.policy import ( + NoPaddingBaichuanModelInferPolicy, + NoPaddingLlamaModelInferPolicy, + model_policy_map, +) +from colossalai.inference.sampler import search_tokens +from colossalai.inference.utils import get_model_size, has_index_file +from colossalai.interface import ModelWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +PP_AXIS, TP_AXIS = 0, 1 + +_SUPPORTED_MODELS = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_SUPPORTED_MODEL_POLICIES = { + "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy, + "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy, +} + +logger = get_dist_logger(__name__) + + +class rpcWorkerService(rpyc.Service): + + """ + Execute the computation tasks and manage its own kv cache + + Func with prefix `exposed_` will be invoked by client. + """ + + def exposed_init_dist_env(self, rank, world_size, master_address, master_port): + logger.info(f"init process group for rank {rank}") + colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) + logger.info(f"init process group done for rank {rank}") + + def exposed_init_model( + self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None + ): + assert dist.is_initialized(), "invoke init_dist_env first please!" 
+ + self.inference_config = InferenceConfig.from_rpc_param(inference_config_param) + model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None + + self.dtype = self.inference_config.dtype + self.verbose = True + + self._init_model(model_or_path, model_policy) + self._init_fd_tensor() + self._init_output_tensor() + logger.info(f"init model done for rank {dist.get_rank()}") + + def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): + """Initialize the physical cache on the device. + + For each layer of the model, we allocate two tensors for key and value respectively, + with shape of [num_blocks, num_kv_heads, block_size, head_size] + """ + kalloc_shape, valloc_shape = alloc_shape + num_layers = self.model_config.num_hidden_layers + + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + for _ in range(num_layers): + self.k_cache.append( + torch.zeros( + kalloc_shape, + dtype=self.inference_config.kv_cache_dtype, + device=get_accelerator().get_current_device(), + ) + ) + self.v_cache.append( + torch.zeros( + valloc_shape, + dtype=self.inference_config.kv_cache_dtype, + device=get_accelerator().get_current_device(), + ) + ) + logger.info("physical cache init over") + + def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict): + # prepare the data for model forward + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + if input_meta_data.is_prompts: + n_tokens = input_meta_data.sequence_lengths.sum().item() + else: + n_tokens = input_meta_data.batch_size + input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) + + # execute the model + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) + + # sampler + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + self.inference_config.to_generation_config(self.model_config), + logits, + input_meta_data.is_prompts, + input_meta_data.batch_token_ids, + ) + + # return the tokens generated to scheduler + return next_tokens.tolist() + + def _init_output_tensor(self): + alloc_shape = ( + self.inference_config.max_batch_size + * (self.inference_config.max_input_len + self.inference_config.max_output_len), + self.model_config.hidden_size // self.inference_config.tp_size, + ) + self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device) + + def _init_fd_tensor(self): + fd_inter_tensor = FDIntermTensors() + + if fd_inter_tensor._tensors_initialized: + fd_inter_tensor._reset() + + # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq + max_n_tokens = self.inference_config.max_batch_size + max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 + + inference_config = self.inference_config + kv_max_split_num = ( + inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 + ) // inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + fd_inter_tensor.initialize( + max_batch_size=max_n_tokens, + num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size, + kv_max_split_num=kv_max_split_num, + head_dim=head_dim, + dtype=self.dtype, + device=get_accelerator().get_current_device(), + ) + + self.fd_inter_tensor = 
fd_inter_tensor + + def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + """ + Shard model or/and Load weight + + Shard model: When we set tp_size > 1, we will shard the model by given model_policy. + Load Weight: If we pass a local model path, we will load the model weight by checkpoint_io. If it is a remote-transformer url, we will use `AutoModel.from_pretrained` api of transformers lib + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model + """ + + if isinstance(model_or_path, str): + is_local = os.path.isdir(model_or_path) + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + arch = getattr(hf_config, "architectures")[0] + if is_local: + model = _SUPPORTED_MODELS[arch](hf_config) + else: + # load the real checkpoint + model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) + except Exception as e: + logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + torch.cuda.set_device(self.device) + if self.verbose: + logger.info(f"the device is {self.device}") + + model = model.to(dtype=self.dtype, non_blocking=False).eval() + + if self.verbose: + logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + if self.inference_config.pad_input: + model_type = "padding_" + self.model_config.model_type + else: + model_type = "nopadding_" + self.model_config.model_type + model_policy = model_policy_map[model_type]() + + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device()) + + if self.verbose: + logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + if isinstance(model_or_path, str) and is_local: + from colossalai.inference.core.plugin import InferCheckpoint_io + + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(model_or_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy to shardformer model which is determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. 
+ tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer. + """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model + + def exposed_compute_only_for_test(self): + dist_rank = dist.get_rank() + + # Dummy data for each worker + data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank) + dist.barrier() + + # Perform distributed all_reduce + dist.all_reduce(data, op=dist.ReduceOp.SUM) + + dist.barrier() + logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") + + return data.item() diff --git a/colossalai/inference/kv_cache/__init__.py b/colossalai/inference/kv_cache/__init__.py index c3beb5545..b232db936 100644 --- a/colossalai/inference/kv_cache/__init__.py +++ b/colossalai/inference/kv_cache/__init__.py @@ -1,4 +1,4 @@ from .block_cache import CacheBlock -from .kvcache_manager import KVCacheManager +from .kvcache_manager import KVCacheManager, RPCKVCacheManager -__all__ = ["CacheBlock", "KVCacheManager"] +__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"] diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 1b9532a3c..a20bd8ee7 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -497,3 +497,80 @@ class KVCacheManager: k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device)) v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) return k_cache, v_cache + + +class RPCKVCacheManager(KVCacheManager): + def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: + self.logger = get_dist_logger(__name__) + self.device = get_current_device() + self.config = config + + # Parallel settings + self.tp_size = config.tp_size + # Model settings + self.dtype = config.dtype + self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() + self.num_layers = model_config.num_hidden_layers + self.head_num = model_config.num_attention_heads + self.head_size = model_config.hidden_size // self.head_num + if hasattr(model_config, "num_key_value_heads"): + self.kv_head_num = model_config.num_key_value_heads + else: + self.kv_head_num = self.head_num + + if config.kv_cache_dtype is None: + self.kv_cache_dtype = config.dtype + else: + self.kv_cache_dtype = config.kv_cache_dtype + + assert ( + self.kv_head_num % self.tp_size == 0 + ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}" + self.kv_head_num //= self.tp_size + self.beam_width = config.beam_width + self.max_batch_size = config.max_batch_size + self.max_input_length = config.max_input_len + self.max_output_length = config.max_output_len + # Cache block settings + self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + 
self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + + # Logical cache blocks allocation + self._available_blocks = self.num_blocks + self._cache_blocks = tuple(self._init_logical_caches()) + # block availablity state 0->allocated, 1->free + self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool) + self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64) + self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64) + + def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + # Physical cache allocation + if self.config.use_cuda_kernel: + x = 16 // torch.tensor([], dtype=self.config.dtype).element_size() + kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x) + valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info( + f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks." + ) + else: + alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + kalloc_shape = alloc_shape + valloc_shape = alloc_shape + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + return kalloc_shape, valloc_shape + + def get_kv_cache(self): + """Get k_cache and v_cache""" + return NotImplementedError + + def _init_logical_caches(self): + """Initialize the logical cache blocks.""" + blocks = [] + for i in range(self.num_blocks): + cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None) + blocks.append(cache_block) + return blocks diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py index b7119a221..8e4b29ae6 100644 --- a/colossalai/inference/logit_processors.py +++ b/colossalai/inference/logit_processors.py @@ -1,10 +1,9 @@ # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py +from typing import List import torch import torch.nn.functional as F -from colossalai.inference.batch_bucket import BatchBucket - _LOGIT_PROCESSOR_MAP = {} @@ -22,7 +21,7 @@ def register_logit_processor(process_type): @register_logit_processor("no_repeat_ngram_size") -def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket): +def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]): """ enforces no repetition of n-grams to avoid repetitions of word sequences. """ @@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.") if ngram_size != 0: - batch_token_ids = batch.batch_token_ids batch_size = len(batch_token_ids) for batch_id in range(batch_size): @@ -55,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck @register_logit_processor("repetition_penalty") -def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket): +def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]): """ apply the penalty to the tokens present in the prompt. 
""" @@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket) # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. if penalty != 1.0: - batch_token_ids = batch.batch_token_ids for batch_id in range(len(batch_token_ids)): current_logit = logits[batch_id] current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device) diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 2134eff59..78268d6e7 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -1,3 +1,4 @@ +from colossalai.inference.config import RPC_PARAM from colossalai.inference.modeling.layers.baichuan_tp_linear import ( BaichuanLMHeadLinear1D_Col, BaichuanWpackLinear1D_Col, @@ -18,7 +19,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy): +class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): def __init__(self) -> None: super().__init__() @@ -100,3 +101,10 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy): def postprocess(self): init_to_get_rotary(self.model.model) return self.model + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "NoPaddingBaichuanModelInferPolicy": + return NoPaddingBaichuanModelInferPolicy() diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 59a3a4e51..24cf7c740 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,5 +1,6 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm +from colossalai.inference.config import RPC_PARAM from colossalai.inference.modeling.models.nopadding_llama import ( NopadLlamaAttention, NopadLlamaMLP, @@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): +class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): def __init__(self) -> None: super().__init__() @@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def postprocess(self): init_to_get_rotary(self.model.model, self.model.config.rope_theta) return self.model + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy": + return NoPaddingLlamaModelInferPolicy() diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py index 7547c32b0..d3857a3bd 100644 --- a/colossalai/inference/sampler.py +++ b/colossalai/inference/sampler.py @@ -1,6 +1,9 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import torch +from transformers.generation import GenerationConfig + +from colossalai.inference.logit_processors import logit_processor def greedy_sample( @@ -59,3 +62,47 @@ def beam_search_sample( results.append((next_token_ids, parent_ids)) return 
results + + +def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False): + if generation_config.num_beams == 1: + if generation_config.do_sample: + sample_tokens = multinomial_sample(generation_config, probs) + else: + sample_tokens = greedy_sample(generation_config, logprobs) + else: + sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt) + + return sample_tokens + + +def search_tokens( + generation_config: GenerationConfig, + logits, + is_prompt: bool = False, + batch_token_ids: Optional[List[List[int]]] = None, +): + """ + Sample tokens for finished requests. + """ + # NOTE: need to decide the granularity to process logits (sequence or batch) + config_dict = generation_config.to_dict() + # process repetition_penalty, no_repeat_ngram_size + for type in ["repetition_penalty", "no_repeat_ngram_size"]: + if type in config_dict and config_dict[type] is not None: + logits = logit_processor(type, logits, config_dict[type], batch_token_ids) + + # do logit processor + if generation_config.do_sample: + # process temperature, top_k, top_p + for type in ["temperature", "top_k", "top_p"]: + if type in config_dict and config_dict[type] is not None: + logits = logit_processor(type, logits, config_dict[type]) + + # calculate probs + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # sample the next tokens + sample_tokens = _sample(probs, logprobs, generation_config, is_prompt) + return sample_tokens diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 9e0d72586..072bedec3 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -9,6 +9,8 @@ from typing import Optional, Tuple import torch from torch import nn +from colossalai.testing import free_port + def init_to_get_rotary(self, base=10000, use_elem=False): """ @@ -102,3 +104,12 @@ def get_model_size(model: nn.Module): for key, param in model.named_parameters(): total_size += param.element_size() * param.numel() return total_size / (1024**3) + + +def find_available_ports(num: int): + try: + free_ports = [free_port() for i in range(num)] + except OSError as e: + print(f"An OS error occurred: {e}") + raise RuntimeError("Error finding available ports") + return free_ports diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 58c7f780f..652ddff04 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -19,4 +19,5 @@ datasets pydantic ray peft>=0.7.1 +rpyc==6.0.0 #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8ab13c0ad..297b057c1 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -19,3 +19,4 @@ protobuf transformers==4.36.2 peft>=0.7.1 bitsandbytes>=0.39.0 +rpyc==6.0.0 diff --git a/tests/test_infer/test_rpc_engine.py b/tests/test_infer/test_rpc_engine.py new file mode 100644 index 000000000..12479b49c --- /dev/null +++ b/tests/test_infer/test_rpc_engine.py @@ -0,0 +1,105 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.rpc_engine import RPCInferenceEngine +from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy +from 
colossalai.testing import parameterize, rerun_if_address_is_in_use + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + model = "meta-llama/Llama-2-7b-hf" # remote mode path + inputs = [ + "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", + "介绍一下武汉,", + ] + + output_len = 38 + top_p = 0.5 + top_k = 50 + + if use_engine: + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + dtype="fp32", + use_cuda_kernel=True, + tp_size=tp_size, + ) + inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts=inputs) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig( + max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k + ) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + model = AutoModelForCausalLM.from_pretrained(model).cuda() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + dtype="fp32", + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return outputs + + +def run_engine(tp_size, **kwargs): + return check_inference_engine(tp_size=tp_size, **kwargs) + + +@pytest.mark.largedist +@parameterize("prompt_template", [None, "llama"]) +@parameterize("do_sample", [False]) +@rerun_if_address_is_in_use() +def test_tp_engine(prompt_template, do_sample): + if torch.multiprocessing.get_start_method(allow_none=True) is None: + torch.multiprocessing.set_start_method("spawn") + kwargs1 = { + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": NoPaddingLlamaModelInferPolicy(), + } + + kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None} + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess + test_tp_engine()
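
As a rough illustration of the client/worker plumbing this patch introduces (run_server, init_workers, and async_parallel_wrapper in rpc_engine.py, plus the exposed_* methods of rpcWorkerService in rpc_worker.py), the sketch below shows the underlying rpyc pattern in isolation: a ThreadedServer hosting a service in a child process, a client connecting with rpyc.connect, and a non-blocking call made through rpyc.async_. It is not part of the patch; EchoWorkerService, the port number, and the echoed payload are placeholders, and it uses the standard-library multiprocessing for brevity where the engine uses torch.multiprocessing with the spawn start method.

# Minimal sketch of the rpyc worker pattern used by RPCInferenceEngine (hypothetical names).
import multiprocessing as mp
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoWorkerService(rpyc.Service):
    """Toy stand-in for rpcWorkerService: only `exposed_*` methods are reachable from the client."""

    def exposed_execute_model_forward(self, token_ids: list) -> list:
        # A real worker runs the sharded model forward and samples tokens here; this one just
        # transforms a plain Python list, the kind of POD payload rpyc passes by value.
        return [t + 1 for t in token_ids]


def run_server(port: int, started) -> None:
    server = ThreadedServer(EchoWorkerService, port=port, protocol_config={"allow_public_attrs": True})
    started.set()   # signal the client that the worker process is up, as the engine does with mp.Event
    server.start()  # blocks until the process is killed


if __name__ == "__main__":
    port, started = 18861, mp.Event()
    proc = mp.Process(target=run_server, args=(port, started), daemon=True)
    proc.start()
    started.wait()
    time.sleep(0.2)  # give the ThreadedServer a moment to bind, mirroring the short sleep in init_workers

    conn = rpyc.connect("localhost", port, config={"allow_public_attrs": True})
    async_call = rpyc.async_(conn.root.execute_model_forward)  # non-blocking, like async_parallel_wrapper
    result = async_call([1, 2, 3])
    result.wait()
    print(result.value)  # -> [2, 3, 4]

    proc.kill()
    proc.join()

Because rpyc only moves plain-old-data values like the list above between processes, tensors and configuration objects cannot cross the connection directly; that is the reason the patch adds the RPC_PARAM base class, whose to_rpc_param/from_rpc_param pair converts InferenceConfig and InputMetaData to and from POD dictionaries on either side of the call.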