From 3ad1f3b78b830c90079ed9f1e0b5cd26601194fa Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 4 Jan 2024 16:48:53 +0800 Subject: [PATCH] fix beam_width --- colossalai/inference/modeling/models/llama.py | 4 ++++ colossalai/inference/sampler.py | 5 ++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index 1331cc021..b4246d947 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -176,8 +176,12 @@ def llama_attn_forward( def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor: + # Replace this code and use a more flexible method to obtain padding_id, avoiding directly setting padding_id like this. padding_id = 2 attention_mask = input_ids.ne(padding_id).long() position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) return position_ids + +# def unpad_inputs(input_ids: torch.Tensor): + diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py index 1c0c518f9..d3a10ede7 100644 --- a/colossalai/inference/sampler.py +++ b/colossalai/inference/sampler.py @@ -42,9 +42,8 @@ def beam_search_sample( # NOTE: this beam search sample function is wrong now. """ - - # beam_width = generation_config.best_of - beam_width = 1 + + beam_width = generation_config.num_beams results = [] if is_prompt: # Prompt phase.