fix beam_width

pull/5258/head
yuehuayingxueluo 2024-01-04 16:48:53 +08:00 committed by FrankLeeeee
parent b2eb9cd186
commit 3ad1f3b78b
2 changed files with 6 additions and 3 deletions

@@ -176,8 +176,12 @@ def llama_attn_forward(
 def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
+    # Replace this code and use a more flexible method to obtain padding_id, avoiding directly setting padding_id like this.
     padding_id = 2
     attention_mask = input_ids.ne(padding_id).long()
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
+# def unpad_inputs(input_ids: torch.Tensor):

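The in-diff note above asks for a more flexible way to obtain padding_id than hard-coding 2. A minimal sketch of one option, assuming the caller can pass the padding token id in (the padding_id parameter and its default here are illustrative, not part of this commit):

import torch


def generate_padding_position_id(input_ids: torch.Tensor, padding_id: int = 2) -> torch.Tensor:
    # Same logic as the diff above, but the padding token id is supplied by the
    # caller (e.g. from tokenizer.pad_token_id) instead of being hard-coded.
    attention_mask = input_ids.ne(padding_id).long()
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids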

@@ -42,9 +42,8 @@ def beam_search_sample(
     # NOTE: this beam search sample function is wrong now.
     """
-    # beam_width = generation_config.best_of
-    beam_width = 1
+    beam_width = generation_config.num_beams
     results = []
     if is_prompt:
         # Prompt phase.