[Inference] Support the logic related to ignoring EOS token (#5693)

* Adapt temperature processing logic

* add ValueError for top_p and top_k

* add GQA Test

* fix except_msg

* support ignore EOS token

* change variable's name

* fix annotation
yuehuayingxueluo authored 2024-05-08 19:59:10 +08:00, committed by GitHub
commit d482922035, parent 9c2fe7935f
3 changed files with 9 additions and 1 deletion


@@ -111,6 +111,7 @@ class InferenceConfig:
         use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
         max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
         high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """

     # NOTE: arrange configs according to their importance and frequency of usage
@@ -156,6 +157,7 @@ class InferenceConfig:
     # cuda_graph
     use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
+    ignore_eos: bool = False

     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
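The first file only introduces the switch itself: an ignore_eos field on InferenceConfig, documented and defaulting to False. Below is a minimal sketch of toggling it; the import path colossalai.inference.config, and the assumption that the remaining dataclass fields all have defaults, go beyond what this hunk shows and should be treated as assumptions.

from colossalai.inference.config import InferenceConfig  # assumed module path

# Default behaviour: generation stops at the first EOS token.
default_cfg = InferenceConfig()
assert default_cfg.ignore_eos is False

# With the flag on, only max_output_len ends generation (see the Sequence hunk below).
cfg = InferenceConfig(ignore_eos=True)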


@@ -662,6 +662,7 @@ class InferenceEngine:
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
                 max_output_len=max_new_tokens,
+                ignore_eos=self.inference_config.ignore_eos,
             )
             self.request_handler.add_sequence(sequence)
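The engine change is plumbing: inference_config.ignore_eos is copied onto every Sequence the engine builds from a request. The end-to-end sketch below shows how a user would exercise it; the launch call, import paths, InferenceEngine constructor and generate() signature are modeled on the inference quickstart of that period and are assumptions, not part of this commit.

import transformers

import colossalai
from colossalai.inference.config import InferenceConfig        # assumed path
from colossalai.inference.core.engine import InferenceEngine   # assumed path

colossalai.launch_from_torch()

# Any HuggingFace causal LM/tokenizer pair will do; the checkpoint is only an example.
model = transformers.LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# ignore_eos=True is forwarded to each Sequence, so decoding runs until
# max_output_len instead of stopping at the first EOS token.
config = InferenceConfig(max_input_len=128, max_output_len=64, ignore_eos=True)
engine = InferenceEngine(model, tokenizer, config, verbose=True)

print(engine.generate(prompts=["Once upon a time"]))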


@@ -60,6 +60,7 @@ class Sequence:
         eos_token_id (int): The eos token id for this inference process.
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """

     request_id: int
@@ -70,6 +71,8 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    # NOTE(caidi) This is a temporary solution. It would be better to move the logic that turns this flag on or off into the sampling module in the future.
+    ignore_eos: bool = False

     def __post_init__(self):
         self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
             return True

         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
+            if (
+                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
+            ) or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
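The behavioural change lives in this last hunk: an EOS token now ends a sequence only when ignore_eos is off, while the max_output_len cap still applies unconditionally. The self-contained sketch below reproduces that predicate; _SequenceSketch is an illustrative stand-in, not the real Sequence class, and it uses the length of output_token_id in place of output_len.

from dataclasses import dataclass, field


@dataclass
class _SequenceSketch:
    """Illustrative stand-in holding only the fields used by the finish check."""
    eos_token_id: int
    max_output_len: int = 256
    ignore_eos: bool = False
    output_token_id: list = field(default_factory=list)

    def check_finish(self) -> bool:
        # Mirrors the patched condition: EOS terminates only when ignore_eos
        # is False; the output-length cap terminates unconditionally.
        if self.output_token_id:
            if (
                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
            ) or len(self.output_token_id) >= self.max_output_len:
                return True
        return False


seq = _SequenceSketch(eos_token_id=2, max_output_len=4, ignore_eos=True)
seq.output_token_id += [5, 2]   # EOS was sampled, but it is ignored
assert not seq.check_finish()
seq.output_token_id += [7, 9]   # hard stop: max_output_len reached
assert seq.check_finish()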