[Inference] Support the logic related to ignoring EOS token (#5693)

* Adapt temperature processing logic

* add ValueError for top_p and top_k

* add GQA Test

* fix except_msg

* support ignore EOS token

* change variable's name

* fix annotation
yuehuayingxueluo 2024-05-08 19:59:10 +08:00 committed by GitHub
parent 9c2fe7935f
commit d482922035
3 changed files with 9 additions and 1 deletion


@@ -111,6 +111,7 @@ class InferenceConfig:
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -156,6 +157,7 @@ class InferenceConfig:
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
ignore_eos: bool = False
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len


@@ -662,6 +662,7 @@ class InferenceEngine:
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
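
End to end, the flag only needs to be set once on the config; the engine then stamps it onto every Sequence it creates, as in the hunk above. The snippet below is a usage sketch only: the InferenceEngine constructor and generate() arguments are assumptions based on the surrounding code, not verified against this commit.

# Usage sketch only: the constructor and generate() signatures are assumed, not verified.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# With ignore_eos=True, every sequence keeps decoding past the EOS token
# and stops only when max_output_len is reached.
config = InferenceConfig(max_input_len=128, max_output_len=64, ignore_eos=True)
engine = InferenceEngine(model, tokenizer, config)

outputs = engine.generate(prompts=["Introduce some landmarks in Beijing"])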


@@ -60,6 +60,7 @@ class Sequence:
eos_token_id (int): The eos token id for this inference process.
pad_token_id (int): The pad token id for this inference process.
max_output_len (int): Maximum output length.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
"""
request_id: int
@@ -70,6 +71,8 @@
eos_token_id: int
pad_token_id: int
max_output_len: int = 256
# NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
ignore_eos: bool = False
def __post_init__(self):
self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
return True
if self.output_token_id:
if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
if (
self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
) or self.output_len >= self.max_output_len:
self.status = RequestStatus.COMPLETED
return True
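
In isolation, the updated stopping rule can be exercised with a stripped-down stand-in for Sequence. The class below is a simplified sketch that keeps only the fields the check needs; it is not the real dataclass.

# Minimal sketch of the updated finish check; field names follow the diff,
# everything else about Sequence is simplified away.
from dataclasses import dataclass, field
from typing import List

@dataclass
class SequenceSketch:
    eos_token_id: int
    max_output_len: int = 256
    ignore_eos: bool = False
    output_token_id: List[int] = field(default_factory=list)

    def check_finish(self) -> bool:
        if self.output_token_id:
            if (
                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
            ) or len(self.output_token_id) >= self.max_output_len:
                return True
        return False

seq = SequenceSketch(eos_token_id=2, ignore_eos=True)
seq.output_token_id.append(2)            # EOS was sampled...
assert not seq.check_finish()            # ...but it is ignored, so decoding continues
seq.output_token_id.extend([5] * 255)    # output reaches max_output_len (256 tokens)
assert seq.check_finish()                # now the length cap stops the sequence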