diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 70faf34e3..61bc7c8ab 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -202,11 +202,12 @@ class InferenceConfig(RPC_PARAM):
     ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
     early_stopping: Optional[bool] = False
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 1.0
     temperature: Optional[float] = 1.0
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    forced_eos_token_id: int = None

     # speculative decoding configs
     max_n_spec_tokens: int = 5
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 73ba08750..646b3cede 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -76,6 +76,7 @@ class InferenceEngine:
         self.init_model(model_or_path, model_policy)

         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()

         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -524,12 +525,13 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
+
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
         with torch.inference_mode():
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
-                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
                     request_ids=request_ids,
                     prompts=prompts,
@@ -543,6 +545,7 @@ class InferenceEngine:
             # intuition: If user provide a generation config, we should replace the existing one.
             if generation_config is not None:
                 self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict

             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
@@ -688,11 +691,12 @@ class InferenceEngine:
         )

         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids

         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py
index 9602147f5..439c4b0b5 100644
--- a/colossalai/inference/core/rpc_engine.py
+++ b/colossalai/inference/core/rpc_engine.py
@@ -257,7 +257,12 @@ class RPCInferenceEngine(InferenceEngine):
         assert len(self.workers) == self.tp_size, "init workers first"

         init_tasks = [
-            self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
+            self.async_parallel_wrapper(
+                worker.execute_model_forward,
+                input_token_ids,
+                input_meta_data.to_rpc_param(),
+                self.generation_config_dict,
+            )
             for worker in self.workers
         ]
         ret = await asyncio.gather(*init_tasks)
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index 7d8350ac0..913b8667d 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -97,7 +97,9 @@ class rpcWorkerService(rpyc.Service):
         )
         logger.info("physical cache init over")

-    def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
+    def exposed_execute_model_forward(
+        self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
+    ):
         # prepare the data for model forward
         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
         input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ class rpcWorkerService(rpyc.Service):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         next_tokens = search_tokens(
-            self.inference_config.to_generation_config(self.model_config),
+            generation_config_param,
             logits,
             input_meta_data.is_prompts,
             input_meta_data.batch_token_ids,
diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py
index 8e4b29ae6..ea73f8332 100644
--- a/colossalai/inference/logit_processors.py
+++ b/colossalai/inference/logit_processors.py
@@ -1,27 +1,28 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union

 import torch
 import torch.nn.functional as F

-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}


-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
""" def register(func): - global _LOGIT_PROCESSOR_MAP - _LOGIT_PROCESSOR_MAP[process_type] = func + global _LOGITS_PROCESSOR_MAP + _LOGITS_PROCESSOR_MAP[process_type] = func return func return register -@register_logit_processor("no_repeat_ngram_size") -def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]): +@register_logits_processor("no_repeat_ngram_size") +def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]): """ enforces no repetition of n-grams to avoid repetitions of word sequences. """ @@ -52,8 +53,8 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: return logits -@register_logit_processor("repetition_penalty") -def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]): +@register_logits_processor("repetition_penalty") +def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]): """ apply the penalty to the tokens present in the prompt. """ @@ -61,7 +62,7 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.") - logit_list = [] + logits_list = [] # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. if penalty != 1.0: @@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li curretn_socre = torch.gather(current_logit, 0, current_token) curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty) - logit_list.append(current_logit.scatter(0, current_token, curretn_socre)) + logits_list.append(current_logit.scatter(0, current_token, curretn_socre)) - logits = torch.stack(logit_list) + logits = torch.stack(logits_list) return logits -@register_logit_processor("temperature") -def temperature_logit_process(logits, temperature: float): +@register_logits_processor("temperature") +def apply_temperature(logits, temperature: float): """ apply temperature scaling. """ @@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float): return logits if temperature == 1.0 else logits / temperature -@register_logit_processor("top_k") -def top_k_logit_processor(logits, top_k: int): +@register_logits_processor("top_k") +def apply_top_k(logits, top_k: int): """ top_k logit processor """ @@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int): return logits -@register_logit_processor("top_p") -def top_p_logit_processor(logits, top_p: float): +@register_logits_processor("top_p") +def apply_top_p(logits, top_p: float): """ top_p logit processor """ @@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float): return logits -def logit_processor(processor: str, logits, *args, **kwargs): +@register_logits_processor("forced_eos_token_id") +def apply_forced_eos_token_id( + logits: torch.Tensor, + sequence_lengths: Union[torch.Tensor, List[int]], + max_lengths: Union[torch.Tensor, List[int]], + eos_token_id: Union[int, List[int]], +): + """ + Enforces the specified token as the last generated token when the maximum output length + is reached. Notice that the maximum output lengths for different sequences, even if they're + in the same batch, can be different. 
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def get_logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.

@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-    return logits
+
+    return logits
diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py
index d3857a3bd..949d979bc 100644
--- a/colossalai/inference/sampler.py
+++ b/colossalai/inference/sampler.py
@@ -1,13 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from transformers.generation import GenerationConfig

-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor


 def greedy_sample(
-    generation_config,
     logprobs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -18,7 +17,6 @@ def greedy_sample(


 def multinomial_sample(
-    generation_config,
     probs: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -29,7 +27,7 @@ def multinomial_sample(


 def beam_search_sample(
-    generation_config,
+    beam_width: int,
     logprobs: torch.Tensor,
     is_prompt: bool = False,
 ) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(

     # NOTE: this beam search sample function is wrong now.
     """
-    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.
@@ -64,20 +61,8 @@ def beam_search_sample(
     return results


-def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
-    if generation_config.num_beams == 1:
-        if generation_config.do_sample:
-            sample_tokens = multinomial_sample(generation_config, probs)
-        else:
-            sample_tokens = greedy_sample(generation_config, logprobs)
-    else:
-        sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)
-
-    return sample_tokens
-
-
 def search_tokens(
-    generation_config: GenerationConfig,
+    generation_config: Union[GenerationConfig, dict],
     logits,
     is_prompt: bool = False,
     batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@ def search_tokens(
     Sample tokens for finished requests.
""" # NOTE: need to decide the granularity to process logits (sequence or batch) - config_dict = generation_config.to_dict() - # process repetition_penalty, no_repeat_ngram_size - for type in ["repetition_penalty", "no_repeat_ngram_size"]: - if type in config_dict and config_dict[type] is not None: - logits = logit_processor(type, logits, config_dict[type], batch_token_ids) - # do logit processor - if generation_config.do_sample: - # process temperature, top_k, top_p - for type in ["temperature", "top_k", "top_p"]: - if type in config_dict and config_dict[type] is not None: - logits = logit_processor(type, logits, config_dict[type]) + # convert GenerationConfig to dict + # temporary fix for compatibility with the usage of RPCInferenceEngine + if isinstance(generation_config, GenerationConfig): + generation_config = generation_config.to_dict() + + if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0: + logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids) + if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0: + logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids) + if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None: + sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))] + max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))] + logits = get_logits_processor( + "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id + ) + + if generation_config.get("do_sample"): + if (temperature := generation_config.get("temperature", 1.0)) != 1.0: + logits = get_logits_processor("temperature", logits, temperature) + if (top_k := generation_config.get("top_k", 0)) != 0: + logits = get_logits_processor("top_k", logits, top_k) + if (top_p := generation_config.get("top_p", 1.0)) < 1.0: + logits = get_logits_processor("top_p", logits, top_p) # calculate probs probs = torch.softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # sample the next tokens - sample_tokens = _sample(probs, logprobs, generation_config, is_prompt) + if generation_config.get("num_beams", 1) != 1: + raise NotImplementedError("Beam search is not supported yet.") + if generation_config.get("do_sample", False): + sample_tokens = multinomial_sample(probs) + else: + sample_tokens = greedy_sample(logprobs) + return sample_tokens