[Inference] Optimize request handler of llama (#5512)

* optimize request_handler

* fix code style
pull/5537/head
傅剑寒 2024-03-26 16:37:14 +08:00 committed by GitHub
parent 6251d68dc9
commit e6496dd371
2 changed files with 9 additions and 7 deletions


@@ -298,8 +298,8 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
+        config_dict = generation_config.to_dict()
         for type in ["top_k", "top_p", "min_p"]:
-            config_dict = generation_config.to_dict()
             if type in config_dict and config_dict[type] is not None:
                 logits = logit_processor(type, logits, config_dict[type])
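This hunk hoists the `generation_config.to_dict()` call out of the per-parameter loop, so the config is serialized once per decoding step instead of once for every sampling key. A minimal sketch of the resulting pattern, assuming a stand-in `GenerationConfig` and a no-op `logit_processor` (neither is the real ColossalAI class or function):

# Minimal sketch of the optimized logit-processing loop (illustrative only:
# GenerationConfig and logit_processor below are stand-ins, not ColossalAI's).
from dataclasses import dataclass, asdict
from typing import Optional

import torch


@dataclass
class GenerationConfig:
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    min_p: Optional[float] = None

    def to_dict(self):
        return asdict(self)


def logit_processor(name: str, logits: torch.Tensor, attr):
    # no-op placeholder for the real dispatch in the logit processor module
    return logits


def process_logits(logits: torch.Tensor, generation_config: GenerationConfig) -> torch.Tensor:
    # Build the dict once per step instead of once per sampling parameter.
    config_dict = generation_config.to_dict()
    for name in ["top_k", "top_p", "min_p"]:
        if name in config_dict and config_dict[name] is not None:
            logits = logit_processor(name, logits, config_dict[name])
    return logits


print(process_logits(torch.randn(1, 8), GenerationConfig()).shape)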


@@ -36,21 +36,23 @@ def top_p_logit_processor(logits, top_p: float):
     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
     sorted_indices_to_remove = cumulative_probs > top_p
-    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1)
     sorted_indices_to_remove[..., 0] = 0
     indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
     logits[indices_to_remove] = -float("inf")
     return logits


-def logit_processor(processor:str, logits , attrs):
+def logit_processor(processor: str, logits, attrs):
     """
     do logit process for given logits.

     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
         attrs(dict): attrs of the logit processor

     Returns:
         logits after process
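The top-p hunk replaces the slice-assignment-plus-`clone()` right shift with `torch.roll` along the last dimension, avoiding the intermediate copy while still guaranteeing the highest-probability token is never masked. A standalone check of the equivalence, assuming 2-D `[batch, vocab]` logits (the function names here are illustrative, not the library's):

# Standalone check that the torch.roll rewrite matches the old clone/shift idiom
# for the top-p mask; assumes 2-D [batch, vocab] logits, values are made up.
import torch
import torch.nn.functional as F


def top_p_mask_old(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # shift right via an explicit copy
    remove[..., 0] = 0                          # never mask the top-1 token
    return remove.scatter(dim=1, index=sorted_indices, src=remove)


def top_p_mask_new(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    remove = torch.roll(remove, 1, -1)          # shift right without the clone
    remove[..., 0] = 0                          # wrapped-around last element is discarded here
    return remove.scatter(dim=1, index=sorted_indices, src=remove)


logits = torch.randn(2, 8)
assert torch.equal(top_p_mask_old(logits, 0.9), top_p_mask_new(logits, 0.9))

Rolling wraps the last element into position 0, but that slot is immediately overwritten by the `sorted_indices_to_remove[..., 0] = 0` line, so both variants produce the same shifted mask before it is scattered back from sorted order to vocabulary order.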
@@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs):
         func = _LOGIT_PROCESSOR_MAP[processor]
         try:
             logits = func(logits, attrs)
-        except Exception as e:
+        except Exception:
             return logits
     return logits
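The last hunk only drops the unused `as e` binding: on any processor failure the raw logits are returned unchanged. A minimal sketch of that dispatch-with-fallback shape, assuming an illustrative `_LOGIT_PROCESSOR_MAP` entry (the real registry and its contents live in the same module):

# Minimal sketch of the dispatch-with-fallback pattern in the last hunk; the
# registry entry below is an illustrative stand-in, not the library's contents.
import torch

_LOGIT_PROCESSOR_MAP = {
    "temperature": lambda logits, t: logits / t,  # stand-in processor
}


def logit_processor(processor: str, logits: torch.Tensor, attrs) -> torch.Tensor:
    """Apply the named processor; fall back to the raw logits on any failure."""
    if processor not in _LOGIT_PROCESSOR_MAP:
        return logits
    func = _LOGIT_PROCESSOR_MAP[processor]
    try:
        logits = func(logits, attrs)
    except Exception:  # the unused `as e` binding is what this commit removes
        return logits
    return logits


print(logit_processor("temperature", torch.tensor([[1.0, 2.0]]), 0.5))
print(logit_processor("temperature", torch.tensor([[1.0, 2.0]]), None))  # raises inside, raw logits returned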