diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index ea06335b7..1c159f203 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
-import torch.nn as nn
+import torch.distributed as dist
 
 GibiByte = 1024**3
 
@@ -15,44 +15,44 @@ class InferenceConfig:
     """The inference configuration.
 
     Args:
-        model: Path or nn.Module of this model.
-        tokenizer: Path of the tokenizer to use.
-        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
-        trust_remote_code: Whether to trust remote code from huggingface.
-        max_batch_size: Maximum batch size.
-        max_output_len: Maximum output length.
-        max_input_len: Maximum input length.
-        block_size: The number of blocks in a logical block.
-        dtype: The data type for weights and activations.
-        tp_size: Tensor parallel size.
-        pp_size: Pipeline parallel size.
-        max_seq_len: Maximum length of input sentence.
-        quant_mode: Quantization mode.
-        revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        beam_width: The maximum beam width used to initialize KV Cache.
+        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        max_batch_size (int): Maximum batch size.
+        max_output_len (int): Maximum output length.
+        max_input_len (int): Maximum input length.
+        block_size (int): The number of blocks in a logical block.
+        dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        tp_size (int): Tensor parallel size.
+        pp_size (int): Pipeline parallel size.
+        max_seq_len (int): Maximum length of input sentence.
+        beam_width (int): The maximum beam width used to initialize KV Cache.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in the running list; we will do a step of prefill
            when the actual value exceeds this ratio.
+        quant_mode (Optional[str]): Quantization mode.
+        revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
""" - model: Union[str, nn.Module] - tokenizer: str = None - tokenizer_mode: str = "auto" - trust_remote_code: bool = False - max_batch_size: int = None + micro_batch_size: int = 1 + micro_batch_buffer_size: int = None + max_batch_size: int = 8 max_output_len: int = 256 max_input_len: int = 256 block_size: int = 16 dtype: Union[str, torch.dtype] = torch.float32 tp_size: int = 1 pp_size: int = 1 - max_seq_len: Optional[int] = None - quant_mode: Optional[str] = None - revision: Optional[str] = None - beam_width: int = 1 + max_seq_len: int = 512 # TODO: beam search is not support for now - prefill_ratio: Optional[float] = 1.2 + beam_width: int = 1 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + prefill_ratio: Optional[float] = 1.2 + quant_mode: Optional[str] = None + revision: Optional[str] = None + + def __post_init__(self): + self._init_batch_size() + self._verify_config() def _init_batch_size(self): """ @@ -75,10 +75,20 @@ class InferenceConfig: f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user." ) - def __post_init__(self): - self._init_batch_size() - self._verify_args() - - def _verify_args(self): - if self.tokenizer_mode not in ["auto", "slow"]: - raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}") + def _verify_config(self) -> None: + """ + Verify the input config + """ + assert ( + self.tp_size * self.pp_size == dist.get_world_size() + ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})" + assert self.dtype in [ + "fp16", + "fp32", + "bf16", + torch.float32, + torch.float16, + torch.bfloat16, + ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16" + assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" + assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'" diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 232bfb188..3aad5ad97 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,65 +1,232 @@ -from logging import Logger -from typing import Optional +from itertools import count +from typing import List, Optional, Union -from transformers import AutoConfig +import torch +import torch.nn as nn +from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from colossalai.cluster import ProcessGroupMesh from colossalai.inference.config import InferenceConfig +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.struct import Sequence +from colossalai.logging import get_dist_logger +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +from .request_handler import RequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = [ + "LlamaForCausalLM", +] class InferenceEngine: - """ - InferenceEngine is the core component for Inference. - It is responsible for launch the inference process, including: - - Initialize model and distributed training environment(if needed) - - Launch request_handler and corresponding kv cache manager - - Receive requests and generate texts. 
-    - Log the generation process
+    """
+    InferenceEngine which manages the inference process.
 
     Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
+        model (nn.Module): The model for inference.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use for inference.
+        inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
+        model_policy ("Policy"): The policy used by ShardFormer to shard the model. It will be determined by the model type if not provided.
     """
 
     def __init__(
         self,
-        tokenizer: str = None,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: Optional["InferenceConfig"] = None,
         verbose: bool = False,
+        model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."
-
-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
         self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
+        self.inference_config = inference_config
+        self.model_config = model.config
+
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
         )
+
+        self.verbose = verbose
         if verbose:
-            self.logger = Logger()
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()
+
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): The model to be sharded.
+            model_policy (Policy): The policy used to shard the model, determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: The sharded model.
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()
 
-    def _init_model(self):
+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
         """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-        1. 用户自定义(from local path)
-        2. 从checkpoint加载(hugging face)
+        Executes the inference loop until all added requests are finished.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference result returned by one generation.
         """
 
-    def _verify_config(self):
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list
+
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[List[int]] = None,
+    ) -> None:
         """
-        Verify the configuration to avoid potential bugs.
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request IDs. Defaults to None.
+            prompts (List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): Token ids of input prompts. Defaults to None.
         """
 
-    def generate(self):
-        pass
+        block_size = self.inference_config.block_size
 
-    def step(self):
+        if prompts_token_ids is None:
+            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts is None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def step(self) -> List[str]:
         """
         In each step, do the follows:
-        1. Run request_handler to update the kv cache and running input_ids
+        1. Run RequestHandler.schedule() and get the batch used for inference.
         2. Run model to generate the next token
-        3. Check whether there is finied request and decode
+        3. Update waiting list and running list in RequestHandler and get finished sequences.
+        4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
         """
+
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment once the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sentences.
+ for seq in finished_sequences: + if seq.prompt: + output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True) + output_list.append(seq.prompt + output_str) + else: + output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True) + output_list.append(output_str) + + return output_list diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index e7898879a..bfa26de7c 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -1,5 +1,7 @@ from typing import List +from colossalai.inference.struct import BatchInfo, Sequence + class RequestHandler: """ @@ -7,14 +9,17 @@ class RequestHandler: During generation process, we call schedule function each iteration to update current batch. Args: - cache_config: Configuration for initialize and manage kv cache. + inference_config: Store the configuration information related to inference. + model_config: The huggingface model config. """ - def __init__(self, cache_config) -> None: - self.cache_config = cache_config + def __init__(self, inference_config, model_config) -> None: + self.inference_config = inference_config + self.model_config = model_config self._init_cache() - self.waiting_list: List["Reqseq"] = [] - self.running_list: List["Reqseq"] = [] + self.waiting_list: List["Sequence"] = [] + self.running_list: List["Sequence"] = [] + self.batch = BatchInfo.init_batch() def _init_cache(self): """ @@ -25,12 +30,17 @@ class RequestHandler: """ The main logic of request handler. """ + # The code below is only used for testing engine and will be modified. + if self.waiting_list: + self.running_list = self.waiting_list + self.batch.add_seqs(self.running_list) + return self.batch - def add_sequence(self, reqseq: "Reqseq"): + def add_sequence(self, req_seq: "Sequence"): """ Add the request to waiting list. """ - self.waiting_list.append(reqseq) + self.waiting_list.append(req_seq) def abort_sequence(self, seq_id: str): """ @@ -39,10 +49,23 @@ class RequestHandler: self._find_sequence(seq_id) return - def _find_sequence(self, seq_id: str) -> "Reqseq": + def _find_sequence(self, seq_id: str) -> "Sequence": """ Find the request by seq_id. """ def check_unfinished_seqs(self) -> bool: - return self.waiting_list or self.running_list + return len(self.waiting_list) != 0 or len(self.running_list) != 0 + + def update(self): + """ + Update the waiting list and running list. + """ + + # The code below is only used for testing engine and will be modified. + self.waiting_list = [] + self.running_list = [] + finished_sequences = list(self.batch.sequences_set) + + self.batch.clear_batch() + return finished_sequences diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 493613d68..8c3b207e1 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -135,7 +135,7 @@ class KVCacheManager: and updates the provided block table with the allocated block ids. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. context_len: The length of the processing sequnece. 
""" assert block_table.dim() == 1 @@ -185,7 +185,7 @@ class KVCacheManager: and updates the provided block table if a new cache block is needed. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. context_len: The length of the processing sequnece (already-allocated length). """ assert block_table.dim() == 1 @@ -199,7 +199,7 @@ class KVCacheManager: and updates the provided block table with the allocated block. Args: - block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id. + block_table: A 1D tensor of shape [max_blocks_per_sequence] holded by a sequence, storing mapping of token_position_id -> block_id. block_local_idx: The index of the block in the block table. space_asked: i.e. The number of tokens to be assigned space for. Returns: diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py new file mode 100644 index 000000000..100993941 --- /dev/null +++ b/colossalai/inference/modeling/policy/__init__.py @@ -0,0 +1,7 @@ +from .llama import LlamaModelInferPolicy + +model_policy_map = { + "llama": LlamaModelInferPolicy, +} + +__all__ = ["LlamaModelInferPolicy", "model_polic_map"] diff --git a/colossalai/inference/modeling/policy/llama.py b/colossalai/inference/modeling/policy/llama.py new file mode 100644 index 000000000..f747eedef --- /dev/null +++ b/colossalai/inference/modeling/policy/llama.py @@ -0,0 +1,7 @@ +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + # The code here just for test and will be modified later. + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index a5201d787..3a9064dcf 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -1,68 +1,82 @@ import enum from dataclasses import dataclass -from typing import Dict, List, Set +from typing import List, Union + +import torch +from ordered_set import OrderedSet + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) """ The abstraction of request and sequence are defined here. 
""" -class RequsetStatus(enum.Enum): - """The status of Sentences""" +class RequestStatus(enum.Enum): + """ + The status of Sentences + """ + # running status WAITING = enum.auto() - RUNNING = enum.auto() + PREFILL = enum.auto() + TOKEN = enum.auto() ABORTED = enum.auto() + + # completion status OVERLENGTH = enum.auto() COMPLETED = enum.auto() LENGTH_CAPPED = enum.auto() @staticmethod - def is_finished(status: "RequsetStatus") -> bool: + def is_finished(status: "RequestStatus") -> bool: return status in [ - RequsetStatus.OVERLENGTH, - RequsetStatus.COMPLETED, - RequsetStatus.LENGTH_CAPPED, + RequestStatus.OVERLENGTH, + RequestStatus.COMPLETED, + RequestStatus.LENGTH_CAPPED, ] @staticmethod - def is_running(status: "RequsetStatus") -> bool: - return status == RequsetStatus.RUNNING + def is_running(status: "RequestStatus") -> bool: + return status in [ + RequestStatus.PREFILL, + RequestStatus.TOKEN, + ] @staticmethod - def is_waiting(status: "RequsetStatus") -> bool: - return status == RequsetStatus.WAITING + def is_waiting(status: "RequestStatus") -> bool: + return status == RequestStatus.WAITING +@dataclass class Sequence: """Store information of input sequence. Args: - request_id: The ID of input sequence. - prompt: The prompt of input sequence. - token_id: The tokens ID of input sequence. - block_size: The block size of input sequence. - sample_params: The sample_params of input sequence. - block_table_index: The index of input sequence in block_table. + request_id (int): The ID of input sequence. + prompt (str): The prompt of input sequence. + input_token_id (List[int]): The tokens ID of input sequence. + block_size (int): The block size of input sequence. + sample_params (SampleParams): The sample_params of input sequence. + block_table (torch.Tensor): The index of input sequence in block_table. + eos_token_id (int): The eos token id for this inference process. + max_output_len (int): Maximum output length. """ - def __init__( - self, - request_id: int, - prompt: str, - token_id: List[int], - block_size: int, - sample_params, # SampleParams needs to be imported later. - block_table_index: int, - ): - self.request_id = request_id - self.prompt = prompt - self.input_token_id = token_id - self.blokc_size = block_size - self.sample_params = sample_params + request_id: int + prompt: str + input_token_id: List[int] + block_size: int + sample_params: any # SampleParams needs to be imported later. + block_table: torch.Tensor + eos_token_id: int + max_output_len: int = 256 + + def __post_init__(self): self.output_token_id = [] - self.status = RequsetStatus.WAITING - self.block_table_index = block_table_index + self.status = RequestStatus.WAITING def get_sentence_len(self) -> None: """ @@ -84,17 +98,30 @@ class Sequence: def check_finish(self) -> bool: """ - Check whether inference is over. + Check whether the inference is finished. + + Returns: + bool: Whether the inference is finished. 
""" - return RequsetStatus.is_finished(self.status) + if RequestStatus.is_finished(self.status): + return True + + if self.output_token_id: + if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len: + self.status = RequestStatus.COMPLETED + return True + + return False + + def __hash__(self): + return hash(self.request_id) def __repr__(self) -> str: return ( f"Request ID(request_id={self.request_id}, " f"prompt={self.prompt}, " f"status={self.status.name}, " - f"sample_params={self.sample_params}, " - f"logical block number={len(self._logical_blocks)}" + f"sample_params={self.sample_params}" ) @@ -104,34 +131,38 @@ class BatchInfo: Information to be passed and used for a batch of sequences. """ - sequences_set: Set[Sequence] - block_table: Dict[int, int] = None + sequences_set: OrderedSet["Sequence"] @classmethod - def init_batch(cls, seqs: List[Sequence]) -> "BatchInfo": + def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo": """ Initializes inference batches by input sentence list. Args: - seqs (List[Sequence]): List of input sequence. + seqs (List["Sequence"]): List of input sequence. """ - sequences_set = set() - block_table = {} - for seq in seqs: - if seq in sequences_set: - assert ( - seq.request_id in block_table.keys() - ), "The sequence has been added to sequences_set, but it has not been added to block_table." - continue - assert ( - seq.request_id not in block_table.keys() - ), "The sequence has not been added to sequences_set, but it is already in block_table." + sequences_set = OrderedSet() + + if seqs is not None: + if not isinstance(seqs, list): + seqs = [seqs] + for seq in seqs: + if seq in sequences_set: + logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") + continue - sequences_set.add(seq) - block_table[seq.request_id] = seq.block_table_index + sequences_set.add(seq) - return cls(sequences_set=sequences_set, block_table=block_table) + return cls(sequences_set=sequences_set) + + def get_block_table_tensor(self): + tesnor_list = [] + for seq in self.sequences_set: + block_table = seq.block_table + assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table." + tesnor_list.append(seq.block_table) + return torch.concat(tesnor_list) def clear_batch(self) -> None: """ @@ -139,35 +170,76 @@ class BatchInfo: """ for seq in self.sequences_set: if not seq.check_finish(): - seq.status = RequsetStatus.ABORTED + seq.status = RequestStatus.ABORTED self.sequences_set.clear() - self.block_table.clear() - def fliter_batch(self) -> None: + def fliter_batch(self) -> List["Sequence"]: """ Remove completed sentences from a batch. + + Returns: + List["Sequence"]: List of finished sequences. """ - for seq in self.sequences_set.copy(): + finish_seqs = [] + for seq in self.sequences_set: if seq.check_finish(): - self.sequences_set.remove(seq) - del self.block_table[seq.request_id] + finish_seqs.append(seq) + for finish_seq in finish_seqs: + self.sequences_set.discard(finish_seq) + return finish_seqs - def add_seqs(self, seqs: List[Sequence]) -> None: + def abort_seq(self, seq: "Sequence") -> "Sequence": + """ + Remove sequence from the batch. + """ + if not seq.check_finish(): + seq.status = RequestStatus.ABORTED + self.sequences_set.discard(seq) + return seq + + def add_seqs(self, seqs: List["Sequence"]) -> None: """ Add new sequence to batch Args: - seqs (List[Sequence]): The list of new sequences. + seqs (List["Sequence"]): The list of new sequences. 
""" + + if not isinstance(seqs, list): + seqs = [seqs] + for seq in seqs: if seq in self.sequences_set: - print("The sequence is already in sequences_set.") - assert ( - seq.request_id in self.block_table - ), "The sequence has been added to sequences_set, but it has not been added to block_table." + logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") continue - assert ( - seq.request_id not in self.block_table - ), "The sequence has not been added to sequences_set, but it is already in block_table." self.sequences_set.add(seq) - self.block_table[seq.request_id] = seq.block_table_index + + def is_empty(self) -> None: + """ + Check whether sequences_set is empty. + """ + return not self.sequences_set + + def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None: + """ + Add an output token for each sentence in the batch. + + Args: + tokens (List[int]): A batch of tokens + """ + + assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size." + + for seq, token in zip(self.sequences_set, tokens): + if not isinstance(token, list): + if not isinstance(token, int): + raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.") + token = [token] + seq.output_token_id += token + seq.check_finish() + + def get_batch_size(self) -> int: + """ + Get batch_size of this batch + """ + return len(self.sequences_set) diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt index f85f9d88e..2d85300c3 100644 --- a/requirements/requirements-infer.txt +++ b/requirements/requirements-infer.txt @@ -1,4 +1,5 @@ +ordered_set transformers==4.34.0 auto-gptq==0.5.0 git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8 -git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9 +git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9 \ No newline at end of file diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 4136cefc3..a9d8b2363 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,6 @@ diffusers +fbgemm-gpu==0.2.0 +ordered_set pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py old mode 100644 new mode 100755 diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py old mode 100644 new mode 100755 index 329165025..c5302c206 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -1,26 +1,45 @@ +import pytest + +import colossalai from colossalai.inference.config import InferenceConfig -from colossalai.inference.struct import BatchInfo, RequsetStatus, Sequence +from colossalai.inference.struct import BatchInfo, Sequence +from colossalai.testing import spawn -def test_config_and_inferenceData(): - config = InferenceConfig("/llama") - assert config.max_batch_size +def check_config_and_inference(): + config = InferenceConfig() + assert config.max_batch_size == 8 sequence = Sequence( request_id=1, prompt="abc", - token_id=[1, 2, 3], + input_token_id=[1, 2, 3], block_size=16, sample_params=None, - block_table_index=1, + block_table=None, + eos_token_id=2, + max_output_len=256, ) sequence2 = Sequence( request_id=2, prompt="bcd", - token_id=[4, 5, 6], + input_token_id=[4, 5, 6], + block_size=16, + sample_params=None, + 
block_table=None, + eos_token_id=2, + max_output_len=256, + ) + + sequence3 = Sequence( + request_id=3, + prompt="efg", + input_token_id=[7, 8, 9], block_size=16, sample_params=None, - block_table_index=2, + block_table=None, + eos_token_id=2, + max_output_len=256, ) assert sequence.get_sentence_len() == 3 @@ -29,15 +48,34 @@ def test_config_and_inferenceData(): assert sequence.check_finish() == False batch = BatchInfo.init_batch([sequence]) - assert batch.block_table[sequence.request_id] == sequence.block_table_index - sequence.status = RequsetStatus.COMPLETED - batch.fliter_batch() - assert batch.block_table == {} - batch.add_seqs([sequence2]) - assert batch.block_table[sequence2.request_id] == sequence2.block_table_index + batch.add_seqs([sequence2, sequence3]) + batch.add_seqs([sequence]) + + assert batch.is_empty() == False + assert batch.get_batch_size() == 3 + batch.update_batch_tokens([1, 2, 3]) + seq = batch.abort_seq(sequence) + seq2 = batch.fliter_batch()[0] + + assert batch.get_batch_size() == 1 + assert seq.get_output_len() == 1 + assert seq.output_token_id == [1] + assert seq2.get_output_len() == 1 + assert seq2.output_token_id == [2] + batch.clear_batch() - assert batch.block_table == {} + assert batch.is_empty() == True + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_config_and_inference() + + +@pytest.mark.dist +def test_config_and_inference(): + spawn(run_dist, 1) if __name__ == "__main__": - test_config_and_inferenceData() + test_config_and_inference() diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py new file mode 100755 index 000000000..ec1f85b4c --- /dev/null +++ b/tests/test_infer/test_inference_engine.py @@ -0,0 +1,44 @@ +import pytest +import transformers +from transformers import AutoTokenizer + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import spawn + + +def check_inference_engine(): + model = transformers.LlamaForCausalLM( + transformers.LlamaConfig( + vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 + ) + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + inference_config = InferenceConfig() + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + + inputs = [ + "介绍一下北京", + "介绍一下武汉", + ] + + inference_engine.add_request(prompts=inputs) + outputs = inference_engine.generate(None) + + for s1, s2 in zip(inputs, outputs): + assert s1 == s2 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_inference_engine() + + +@pytest.mark.dist +def test_inference_engine(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_inference_engine() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py old mode 100644 new mode 100755 index 5187727f1..c5868a30e --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -1,12 +1,14 @@ import random +import pytest import torch from transformers.models.llama import LlamaConfig +import colossalai from colossalai.inference.config import InferenceConfig from colossalai.inference.kv_cache import CacheBlock, KVCacheManager from colossalai.logging import disable_existing_loggers -from 
colossalai.testing import parameterize +from colossalai.testing import parameterize, spawn @parameterize( @@ -64,7 +66,7 @@ def test_logical_blocks(test_config): }, ], ) -def test_cache_manager(test_config): +def check_cache_manager(test_config): disable_existing_loggers() assert test_config["max_batch_size"] > 1 @@ -78,7 +80,7 @@ def test_cache_manager(test_config): max_input_length = test_config["max_input_len"] max_output_length = test_config["max_output_len"] - inference_config = InferenceConfig(model="", **test_config) + inference_config = InferenceConfig(**test_config) model_config = LlamaConfig( hidden_size=hidden_size, num_hidden_layers=num_layers, @@ -147,6 +149,16 @@ def test_cache_manager(test_config): assert cache_manager.get_num_available_blocks() == num_blocks +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_cache_manager() + + +@pytest.mark.dist +def test_cache_manager(): + spawn(run_dist, 1) + + if __name__ == "__main__": test_logical_blocks() test_cache_manager()
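
Usage sketch (not part of the diff): the snippet below distills how the new InferenceEngine API introduced by this patch is driven, based on the added tests/test_infer/test_inference_engine.py. The tiny random LlamaConfig and the "hf-internal-testing/llama-tokenizer" checkpoint are the ones used in that test and are illustrative only; the helper name run_engine is hypothetical. InferenceConfig.__post_init__ checks tp_size * pp_size against the global world size, so the engine must be built inside a launched Colossal-AI process, and _shardformer() moves the sharded model to CUDA, so a GPU is required. At this stage of the PR the model forward in step() is still commented out, so generate() simply returns the decoded prompts, which is exactly what the test asserts.

import transformers
from transformers import AutoTokenizer

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import spawn


def run_engine(rank, world_size, port):
    # Initialize the (single-process) distributed context expected by InferenceConfig._verify_config().
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")

    # A tiny, randomly initialized Llama model is enough to exercise the engine plumbing.
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
        )
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    inference_config = InferenceConfig()  # defaults: max_batch_size=8, dtype=torch.float32, tp_size=pp_size=1
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    # Prompts are tokenized, wrapped into Sequence objects, and queued in the RequestHandler.
    engine.add_request(prompts=["Introduce Beijing", "Introduce Wuhan"])

    # generate() loops over step() until the RequestHandler reports no unfinished sequences.
    outputs = engine.generate(None)
    print(outputs)


if __name__ == "__main__":
    spawn(run_engine, 1)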