diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 33131f5f1..6131dacc3 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -86,7 +86,7 @@ colossalai.launch_from_torch(config={})
 # Step 1: create a model in "transformers" way
 model_path = "lmsys/vicuna-7b-v1.3"
 model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
-tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
 
 # Step 2: create an inference_config
 inference_config = InferenceConfig(
@@ -100,13 +100,8 @@ inference_config = InferenceConfig(
 engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
 
 # Step 4: try inference
-generation_config = transformers.GenerationConfig(
-    pad_token_id=tokenizer.pad_token_id,
-    max_new_tokens=512,
-)
 prompts = ['Who is the best player in the history of NBA?']
-engine.add_request(prompts=prompts)
-response = engine.generate(generation_config)
+response = engine.generate(prompts)
 pprint(response)
 ```
 
@@ -150,13 +145,16 @@ Notations:
 - [x] Paged Attention
 - [x] High-Performance Kernels
 - [x] Llama Modelling
+- [x] User Documentation
+- [ ] Speculative Decoding
 - [ ] Tensor Parallelism
 - [ ] Beam Search
-- [ ] Speculative Decoding
+- [ ] Early Stopping
+- [ ] Logger System
+- [ ] SplitFuse
 - [ ] Continuous Batching
 - [ ] Online Inference
 - [ ] Benchmarking
-- [ ] User Documentation
 
 ## 🌟 Acknowledgement
 
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 613afcacd..a87cbaa70 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -8,6 +8,7 @@ from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
+from transformers.generation import GenerationConfig
 
 GibiByte = 1024**3
 
@@ -60,15 +61,22 @@ class InferenceConfig:
     max_input_len: int = 256
     block_size: int = 16
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
+    do_sample: bool = False
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     prefill_ratio: Optional[float] = 1.2
     pad_input: bool = False
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
+    early_stopping: Optional[bool] = False
+
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    min_p: Optional[float] = None
 
     prompt_template: Optional[str] = None
 
     def __post_init__(self):
@@ -93,7 +101,6 @@
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
-
         # check prompt template
         if self.prompt_template is None:
             return
@@ -105,3 +112,20 @@
         assert (
             "{input_text}" in self.prompt_template
         ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
+
+    def to_generation_config(self, model_config) -> GenerationConfig:
+        meta_config = {
+            "max_length": self.max_input_len + self.max_output_len,
+            "max_new_tokens": self.max_output_len,
+            "early_stopping": self.early_stopping,
+            "do_sample": self.do_sample,
+            "num_beams": self.beam_width,
+        }
+        for type in ["top_k", "top_p", "min_p"]:
+            if hasattr(self, type):
+                meta_config[type] = getattr(self, type)
+        for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
+            if hasattr(model_config, type):
+                meta_config[type] = getattr(model_config, type)
+
+        return GenerationConfig.from_dict(meta_config)
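
Annotation (not part of the patch): the new `InferenceConfig.to_generation_config` helper simply maps inference-config fields onto a Hugging Face `GenerationConfig`. The sketch below spells that mapping out using hypothetical `SimpleNamespace` stand-ins for the inference config and the model config, so it runs with only `transformers` installed.

```python
from types import SimpleNamespace

from transformers.generation import GenerationConfig

# Hypothetical stand-ins for an InferenceConfig and a model.config, used only to show the mapping.
inference_config = SimpleNamespace(
    max_input_len=256, max_output_len=128, early_stopping=False,
    do_sample=True, beam_width=1, top_k=50, top_p=0.9, min_p=None,
)
model_config = SimpleNamespace(pad_token_id=0, bos_token_id=1, eos_token_id=2)

meta_config = {
    "max_length": inference_config.max_input_len + inference_config.max_output_len,
    "max_new_tokens": inference_config.max_output_len,
    "early_stopping": inference_config.early_stopping,
    "do_sample": inference_config.do_sample,
    "num_beams": inference_config.beam_width,  # beam_width maps onto num_beams
}
for attr in ("top_k", "top_p", "min_p"):  # sampling knobs come from the inference config
    meta_config[attr] = getattr(inference_config, attr)
for attr in ("pad_token_id", "bos_token_id", "eos_token_id"):  # special-token ids come from the model config
    meta_config[attr] = getattr(model_config, attr)

generation_config = GenerationConfig.from_dict(meta_config)
print(generation_config.max_new_tokens, generation_config.max_length)  # 128 384
```

Sampling knobs come from the user-facing config while the token ids are taken from the loaded checkpoint's config, which keeps the derived `GenerationConfig` consistent with the model.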
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index d97d70ad5..765fd9f04 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -33,7 +33,7 @@ class InferenceEngine:
 
     Args:
         model (nn.Module): Path or nn.Module of this model.
-        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use for inference.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
         model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
@@ -42,19 +42,20 @@
     def __init__(
         self,
         model: nn.Module,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: Optional["InferenceConfig"] = None,
+        inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        assert tokenizer, "Please provide a tokenizer, either a defined tokenizer object or a string path."
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
         self.dtype = inference_config.dtype
-
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.generation_config = inference_config.to_generation_config(self.model_config)
         model = model.eval()
         model.to(self.dtype)
 
@@ -80,6 +81,8 @@
 
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        # DISCUSS: maybe move this into batch info?
+        self.counter = count()
 
     def _verify_config(self) -> None:
@@ -137,7 +140,7 @@
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        generation_config: GenerationConfig = None,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -158,6 +161,10 @@
         output_seqs_list = []
         output_tokens_list = []
 
+        # Intuition: if the user provides a generation config, it should replace the engine's default one.
+        if generation_config is not None:
+            self.generation_config = generation_config
+
         while self.request_handler.check_unfinished_seqs():
             output_seqs_list += self.step()
 
@@ -285,8 +292,8 @@
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-
         self.request_handler.search_tokens(self.generation_config, logits)
+
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
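
Annotation (not part of the patch): with the engine now deriving a default `GenerationConfig` from the `InferenceConfig` at construction time, the user-facing flow collapses to the README example. The sketch below assumes a CUDA GPU, a `torchrun`-launched process (for `colossalai.launch_from_torch`), and the `lmsys/vicuna-7b-v1.3` checkpoint from the README; the explicit module paths match the files touched by this patch.

```python
import torch
import transformers

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch(config={})

# Same checkpoint as the README example.
model_path = "lmsys/vicuna-7b-v1.3"
model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

inference_config = InferenceConfig(max_input_len=256, max_output_len=64, dtype=torch.float16)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# The engine builds a default GenerationConfig from the InferenceConfig,
# so a plain prompt list is enough.
prompts = ["Who is the best player in the history of NBA?"]
print(engine.generate(prompts))

# An explicit GenerationConfig passed to generate() replaces the engine's default.
greedy = transformers.GenerationConfig(do_sample=False, max_new_tokens=32)
print(engine.generate(prompts, generation_config=greedy))
```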
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 85e41ea73..7e66cfe31 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -2,6 +2,7 @@ from typing import List
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
+from transformers.generation import GenerationConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -94,6 +95,10 @@
         head_dim = model_config.hidden_size // model_config.num_attention_heads
 
         fd_inter_tensor = FDIntermTensors()
+
+        if fd_inter_tensor._tensors_initialized:
+            fd_inter_tensor._reset()
+
         fd_inter_tensor.initialize(
             max_batch_size=self.max_batch_size,
             num_attn_heads=model_config.num_attention_heads,
@@ -170,6 +175,7 @@
                     self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
             for seq in remove_list:
                 lst.remove(seq)
+
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -229,7 +235,7 @@
 
         return None
 
-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
+    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
         if generation_config.num_beams == 1:
             if generation_config.do_sample:
                 sample_tokens = multinomial_sample(generation_config, probs)
@@ -240,7 +246,7 @@
 
         return sample_tokens
 
-    def mark_finished(self, sequence: Sequence, generation_config):
+    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_id
             or sequence.output_len >= generation_config.max_output_len
@@ -250,7 +256,7 @@
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def search_tokens(self, generation_config, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
         """
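
Annotation (not part of the patch): `_sample`, `mark_finished`, and `search_tokens` are now typed against `transformers.GenerationConfig`, and the sampling strategy is chosen from its fields (`num_beams`, `do_sample`). The standalone sketch below illustrates that branch with a hypothetical `pick_next_tokens` helper; top-k/top-p filtering and the real `multinomial_sample`/`greedy_sample` kernels are deliberately omitted.

```python
import torch
from transformers.generation import GenerationConfig


def pick_next_tokens(logits: torch.Tensor, generation_config: GenerationConfig) -> torch.Tensor:
    """Toy analogue of the _sample branch: greedy unless do_sample is set (num_beams == 1 only)."""
    assert generation_config.num_beams == 1, "beam search is still on the roadmap"
    probs = torch.softmax(logits, dim=-1)
    if generation_config.do_sample:
        # stands in for multinomial_sample(generation_config, probs)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)
    # stands in for greedy_sample(generation_config, logprobs)
    return torch.argmax(probs, dim=-1)


logits = torch.randn(4, 32000)  # (batch_size, vocab_size)
greedy_ids = pick_next_tokens(logits, GenerationConfig(do_sample=False))
sampled_ids = pick_next_tokens(logits, GenerationConfig(do_sample=True))
print(greedy_ids.shape, sampled_ids.shape)  # torch.Size([4]) torch.Size([4])
```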
""" diff --git a/colossalai/inference/flash_decoding_utils.py b/colossalai/inference/flash_decoding_utils.py index a91524815..7563d1e4e 100644 --- a/colossalai/inference/flash_decoding_utils.py +++ b/colossalai/inference/flash_decoding_utils.py @@ -12,6 +12,11 @@ class FDIntermTensors(metaclass=SingletonMeta): def __init__(self): self._tensors_initialized = False + def _reset(self): + self._tensors_initialized = False + del self._mid_output + del self._mid_output_lse + @property def is_initialized(self): return self._tensors_initialized diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 9de3f040d..a1db4ecfa 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -72,7 +72,6 @@ def llama_model_forward( """ input_ids = batch.get_1D_inputs() block_tables = batch.get_block_table_tensor() - sequence_lengths = batch.get_sequence_lengths() batch_size = len(sequence_lengths) kv_seq_len = sequence_lengths.max().item() diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 2bc6d5436..edd92bb96 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -31,7 +31,6 @@ def check_inference_engine(use_engine=False, prompt_template=None): .cuda() .half() ) - model = model.eval() inputs = [ @@ -47,6 +46,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): if use_engine: inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)