From fa4fbdbffb6996e8aa1f65bddce5844f2bbbfdf1 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Tue, 9 Jan 2024 13:52:53 +0800
Subject: [PATCH] adapted to pad_context_forward

---
 colossalai/inference/config.py                | 14 ++++++-----
 colossalai/inference/core/engine.py           |  6 +++--
 colossalai/inference/core/request_handler.py  | 16 +++++++++----
 .../inference/kv_cache/kvcache_manager.py     |  2 +-
 colossalai/inference/modeling/models/llama.py | 23 ++-----------------
 colossalai/inference/sampler.py               |  2 +-
 colossalai/inference/struct.py                |  2 +-
 .../legacy/inference/hybridengine/engine.py   |  2 +-
 tests/test_infer/test_inference_engine.py     | 16 +++++++++----
 9 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index f88120965..8ce4ce967 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -1,6 +1,5 @@
 """
-Our config consists of one part:
-    1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
+Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
 """
 
 import logging
@@ -94,9 +93,12 @@ class InferenceConfig:
             torch.float32,
             torch.float16,
             torch.bfloat16,
-        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
-        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
+        assert self.quant_mode in [
+            "smoothquant",
+            "gptq",
+            None,
+        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
         assert (
             self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
+        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must not exceed max_seq_len {self.max_seq_len}."
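
The tightened assertions above fail fast when the config is built. A rough usage sketch follows (illustrative only, not part of the patch; it assumes InferenceConfig keeps the max_input_len/max_output_len/max_seq_len fields shown above and that these checks run during construction):

    from colossalai.inference.config import InferenceConfig

    try:
        # 1024 + 2048 > 2048, so the sum check above is expected to raise.
        cfg = InferenceConfig(max_input_len=1024, max_output_len=2048, max_seq_len=2048)
    except AssertionError as err:
        print(err)  # message now reports the offending max_input_len/max_output_len/max_seq_len values
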
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index a94120a20..6f582c619 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -51,6 +51,8 @@ class InferenceEngine:
         self.model_config = model.config
         self.device = torch.device("cuda")
 
+        model = model.eval()
+
         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
         elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
@@ -85,12 +87,12 @@ class InferenceEngine:
         Verify the input config
         """
         if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
         if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
             self.tokenizer, PreTrainedTokenizer
         ):
             raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
             )
         assert (
             self.model.__class__.__name__ in _supported_models
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 7c2752a0d..7fad20211 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)
 
 
 class RunningList:
@@ -93,17 +96,23 @@ class RequestHandler:
             # Try to allocate cache blocks for the sequence using a priority of prompt length.
             for lst in reversed(self.waiting_list):
                 if lst:
+                    remove_list = []
                     for seq in lst:
                         if seq.input_len > self.inference_config.max_input_len:
                             # If the prompt length is longer than max_input_len, abort the sequence.
+                            logger.warning(
+                                f"The prompt (Request id = {seq.request_id}) is longer than max_input_len; aborting this sequence."
+                            )
                             self.abort_sequence(seq.request_id)
-                            break
+                            remove_list.append(seq)
                         # Try to allocate cache blocks for the sequence.
                         if self.cache_manager.check_allocation(seq):
                             # If succeed, add the sequence to running list.
+                            remove_list.append(seq)
                             self.running_list.append(seq)
                             self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
-                    lst.clear()
+                    for seq in remove_list:
+                        lst.remove(seq)
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -130,10 +139,9 @@ class RequestHandler:
         """
         Abort the request.
""" - seq, priority = self._find_sequence(request_id) + seq, _ = self._find_sequence(request_id) if seq.status.is_waiting: seq.mark_aborted() - self.waiting_list[priority].remove(seq) elif seq.status.is_running(): self.cache_manager.free_block_table(seq.block_table) self.running_list.remove(seq) diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 1fee4958d..419fef3fb 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -112,7 +112,7 @@ class KVCacheManager: def get_kv_cache(self): """Get k_cache and v_cache""" - return self._kv_caches[0], self._kv_caches[1] + return self._kv_caches def get_max_blocks_per_sequence(self) -> int: """Get the maximum number of blocks that can be allocated for a single sequence.""" diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index b17ced6e6..44c07b7c6 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import ( from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache from colossalai.inference.struct import BatchInfo -from flash_attn.bert_padding import index_first_axis, pad_input # noqa +from flash_attn.bert_padding import index_first_axis # noqa def rotate_half(x): @@ -167,20 +167,8 @@ def llama_attn_forward( query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask ) else: - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = pad_decoding_forward( - query_states, - key_states, - value_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - attention_mask, - self.layer_idx, - self.attention_dropout, - self.training, + query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask ) attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) @@ -215,9 +203,6 @@ def pad_decoding_forward( lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] attn_mask: torch.Tensor = None, - layer_id: int = 0, - attention_dropout: float = None, - training: bool = False, ): bsz, query_length, num_heads, head_size = query.shape seq_len = max(lengths) @@ -247,9 +232,7 @@ def pad_decoding_forward( attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min) attn_weights += attn_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training) attn_output = torch.matmul(attn_weights, value) if attn_output.size() != (bsz, num_heads, 1, head_size): @@ -277,8 +260,6 @@ def pad_context_forward( block_size = k_cache.shape[-1] assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] block_tables.shape[-1] * block_size - shape = (bsz, seq_len, num_heads, head_size) - input_shape = shape[:2] # Copy kv to memory(rotary embedded) copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py index d3a10ede7..93e55fcf3 100644 --- a/colossalai/inference/sampler.py +++ b/colossalai/inference/sampler.py @@ -42,7 
+42,7 @@ def beam_search_sample( # NOTE: this beam search sample function is wrong now. """ - + beam_width = generation_config.num_beams results = [] if is_prompt: diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index ef07b7ff9..a62089fc9 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -268,7 +268,7 @@ class BatchInfo: for seq, token in zip(self.sequences_set, tokens): if not isinstance(token, list): if not isinstance(token, int): - raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.") + raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.") token = [token] seq.output_token_id += token seq.check_finish() diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py index bb0b4c77a..48a368fc0 100644 --- a/colossalai/legacy/inference/hybridengine/engine.py +++ b/colossalai/legacy/inference/hybridengine/engine.py @@ -133,7 +133,7 @@ class CaiInferEngine: """ assert isinstance( input_list, (BatchEncoding, dict) - ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}." if isinstance(input_list, BatchEncoding): input_list = input_list.data out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 5fab016e5..4992fdfc7 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -28,20 +28,24 @@ def check_inference_engine(test_cai=False): ) ).cuda() + model = model.eval() + inputs = [ - "介绍一下今天的北京,", + "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", "介绍一下武汉,", ] - output_len = 16 + output_len = 128 do_sample = True + top_p = 0.5 + top_k = 50 if test_cai: inference_config = InferenceConfig(max_output_len=output_len) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) outputs = inference_engine.generate(generation_config) else: tokenizer.pad_token = tokenizer.eos_token @@ -49,7 +53,11 @@ def check_inference_engine(test_cai=False): inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] inputs = inputs.cuda() generation_config = GenerationConfig( - do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, ) outputs = model.generate(inputs, generation_config=generation_config) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
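
The RequestHandler change above replaces the old lst.clear() with a collect-then-remove pass over the waiting list: sequences to drop are gathered in remove_list while iterating and removed afterwards, so requests that were neither aborted nor scheduled stay queued. A self-contained sketch of that idiom follows (illustrative only; Seq is a hypothetical stand-in for colossalai.inference.struct.Sequence and the control flow is simplified):

    from dataclasses import dataclass

    @dataclass
    class Seq:  # hypothetical stand-in for the real Sequence class
        request_id: int
        input_len: int

    def schedule_waiting(waiting, max_input_len, can_allocate, running):
        # Collect sequences to drop while iterating and remove them afterwards:
        # mutating `waiting` inside the loop would skip or double-handle entries.
        remove_list = []
        for seq in waiting:
            if seq.input_len > max_input_len:
                remove_list.append(seq)  # aborted: prompt is too long for the engine
                continue
            if can_allocate(seq):
                remove_list.append(seq)  # scheduled: moved to the running list
                running.append(seq)
        for seq in remove_list:
            waiting.remove(seq)  # whatever is left stays queued for the next round

    waiting = [Seq(0, 8), Seq(1, 4096), Seq(2, 16)]
    running = []
    schedule_waiting(waiting, max_input_len=2048, can_allocate=lambda s: s.input_len <= 8, running=running)
    print([s.request_id for s in waiting], [s.request_id for s in running])  # [2] [0]
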