From 9c2fe7935ff5aaec4f174cfba6f324df623c7447 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Wed, 8 May 2024 17:58:29 +0800
Subject: [PATCH] [Inference]Adapt temperature processing logic (#5689)

* Adapt temperature processing logic

* add ValueError for top_p and top_k

* add GQA Test

* fix except_msg
---
 colossalai/inference/core/request_handler.py | 12 +++++-----
 colossalai/inference/logit_processors.py     | 23 ++++++++++++++++++++
 tests/test_infer/test_inference_engine.py    |  7 +++++-
 3 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index d80572599..10180ff2f 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -328,12 +328,14 @@ class RequestHandler:
         """
         Sample tokens for finished requests.
         """
+
         # do logit processor
-        # NOTE: need to decide the granularity to process logits (sequence or batch)
-        config_dict = generation_config.to_dict()
-        for type in ["top_k", "top_p", "min_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+        if generation_config.do_sample:
+            # NOTE: need to decide the granularity to process logits (sequence or batch)
+            config_dict = generation_config.to_dict()
+            for type in ["temperature", "top_k", "top_p"]:
+                if type in config_dict and config_dict[type] is not None:
+                    logits = logit_processor(type, logits, config_dict[type])
 
         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py
index 557b3df65..39044fcec 100644
--- a/colossalai/inference/logit_processors.py
+++ b/colossalai/inference/logit_processors.py
@@ -17,11 +17,30 @@ def register_logit_processor(process_type):
     return register
 
 
+@register_logit_processor("temperature")
+def temperature_logit_process(logits, temperature: float):
+    """
+    apply temperature scaling.
+    """
+
+    if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
+        except_msg = f"'temperature={temperature}' should be a strictly positive float less than or equal to 1.0."
+        if temperature == 0.0:
+            except_msg += " If you want to use greedy decoding, set `do_sample=False`."
+        raise ValueError(except_msg)
+
+    return logits if temperature == 1.0 else logits / temperature
+
+
 @register_logit_processor("top_k")
 def top_k_logit_processor(logits, top_k: int):
     """
     top_k logit processor
     """
+
+    if not isinstance(top_k, int) or top_k <= 0:
+        raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.")
+
     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits[indices_to_remove] = -float("inf")
     return logits
@@ -32,6 +51,10 @@ def top_p_logit_processor(logits, top_p: float):
     """
     top_p logit processor
     """
+
+    if top_p < 0 or top_p > 1.0:
+        raise ValueError(f"`top_p` should be a float between 0 and 1, but got {top_p}.")
+
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
 
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 8061c50d2..be1330898 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -28,7 +28,12 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
         LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            vocab_size=50000,
+            hidden_size=512,
+            intermediate_size=1536,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            num_hidden_layers=16,
         )
     ).cuda()
     model = model.eval()
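
Note (outside the patch): the hunks above assume a sampling path where temperature scaling is applied to the logits first, followed by top-k and top-p filtering, then a softmax and a multinomial draw. The sketch below is a minimal, self-contained illustration of that ordering under those assumptions; `sample_with_processors` and its default values are invented for this example and are not part of the ColossalAI API.

import torch
import torch.nn.functional as F


def sample_with_processors(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
    """Illustrative pipeline: temperature -> top_k -> top_p -> softmax -> multinomial."""
    # Temperature scaling: values below 1.0 sharpen the distribution; 1.0 is a no-op.
    if temperature != 1.0:
        logits = logits / temperature

    # Top-k: mask everything below the k-th largest logit to -inf.
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    # Top-p (nucleus): drop tokens once cumulative probability exceeds top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > top_p
    # Shift right so the first token crossing the threshold is still kept.
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    logits = logits.masked_fill(mask, float("-inf"))

    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    return torch.multinomial(probs, num_samples=1)


# Example: one sampled token id per sequence from a (batch, vocab) logit tensor.
next_token_ids = sample_with_processors(torch.randn(2, 32000), temperature=0.7)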