mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Support the logic related to ignoring EOS token (#5693)
* Adapt the temperature processing logic
* Add a ValueError for top_p and top_k
* Add a GQA test
* Fix except_msg
* Support ignoring the EOS token
* Change variable names
* Fix annotation
parent  9c2fe7935f
commit  d482922035
@@ -111,6 +111,7 @@ class InferenceConfig:
         use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
         max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
         high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """

     # NOTE: arrange configs according to their importance and frequency of usage
@@ -156,6 +157,7 @@ class InferenceConfig:
     # cuda_graph
     use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
     max_context_len_to_capture: int = 512
+    ignore_eos: bool = False

     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
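With the new field in place, enabling the behavior is a one-line config change. A minimal usage sketch, under the assumptions that InferenceConfig is importable from colossalai.inference.config as in this repo, that max_input_len and max_output_len are ordinary fields, and that every field not set here has a workable default:

from colossalai.inference.config import InferenceConfig

# Sketch only: field names follow the hunks above; all other settings keep their defaults.
inference_config = InferenceConfig(
    max_input_len=256,
    max_output_len=64,
    ignore_eos=True,  # keep generating past the EOS token, up to max_output_len
)

# Per the __post_init__ shown above, the captured context length is recomputed
# from the input/output budget instead of keeping the raw 512 default.
print(inference_config.max_context_len_to_capture)  # expected: 320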
@@ -662,6 +662,7 @@ class InferenceEngine:
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
                 max_output_len=max_new_tokens,
+                ignore_eos=self.inference_config.ignore_eos,
             )
             self.request_handler.add_sequence(sequence)

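The hunk above is where the engine threads the flag from the engine-wide config into every request it admits. A self-contained sketch of that propagation pattern, using hypothetical stand-in classes rather than the real InferenceEngine and Sequence signatures:

from dataclasses import dataclass, field
from typing import List

@dataclass
class EngineConfig:  # hypothetical stand-in for InferenceConfig
    ignore_eos: bool = False

@dataclass
class Request:  # hypothetical stand-in for Sequence
    prompt: str
    max_output_len: int
    ignore_eos: bool = False
    output_token_id: List[int] = field(default_factory=list)

def admit_request(config: EngineConfig, prompt: str, max_new_tokens: int) -> Request:
    # Each per-request object inherits the engine-wide flag, mirroring
    # ignore_eos=self.inference_config.ignore_eos in the hunk above.
    return Request(prompt=prompt, max_output_len=max_new_tokens, ignore_eos=config.ignore_eos)

req = admit_request(EngineConfig(ignore_eos=True), "Hello", max_new_tokens=32)
print(req.ignore_eos)  # True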
@@ -60,6 +60,7 @@ class Sequence:
         eos_token_id (int): The eos token id for this inference process.
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
+        ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
     """

     request_id: int
@@ -70,6 +71,8 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
+    ignore_eos: bool = False

     def __post_init__(self):
         self.output_token_id = []
@@ -107,7 +110,9 @@ class Sequence:
             return True

         if self.output_token_id:
-            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
+            if (
+                self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos
+            ) or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True

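The rewritten condition above is the behavioral core of the commit: EOS now terminates a sequence only when ignore_eos is off, while the max_output_len budget always applies. A standalone sketch of that predicate (an illustrative helper, not the repo's Sequence.check_finish, and it approximates output_len with len(output_token_id)):

from typing import List, Optional

def is_finished(output_token_id: List[int],
                eos_token_id: Optional[int],
                max_output_len: int,
                ignore_eos: bool = False) -> bool:
    # Mirrors the updated check: EOS stops generation only when the flag is off;
    # the length budget is always enforced.
    if not output_token_id:
        return False
    hit_eos = output_token_id[-1] == eos_token_id and not ignore_eos
    return hit_eos or len(output_token_id) >= max_output_len

# EOS id 2 is just an example value.
print(is_finished([5, 9, 2], eos_token_id=2, max_output_len=8))                   # True
print(is_finished([5, 9, 2], eos_token_id=2, max_output_len=8, ignore_eos=True))  # False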