[Inference] Optimize request handler of llama (#5512)

* optimize request_handler

* fix coding style
pull/5537/head
傅剑寒 2024-03-26 16:37:14 +08:00 committed by GitHub
parent 6251d68dc9
commit e6496dd371
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 7 deletions

View File

@ -298,8 +298,8 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
for type in ["top_k", "top_p", "min_p"]:
config_dict = generation_config.to_dict()
for type in ["top_k", "top_p", "min_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])

View File

@ -36,13 +36,15 @@ def top_p_logit_processor(logits, top_p: float):
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1)
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = -float("inf")
return logits
def logit_processor(processor: str, logits, attrs):
"""
do logit process for given logits.
@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs):
func = _LOGIT_PROCESSOR_MAP[processor]
try:
logits = func(logits, attrs)
except Exception as e:
except Exception:
return logits
return logits