mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Optimize request handler of llama (#5512)
* optimize request_handler * fix ways of writingpull/5537/head
parent
6251d68dc9
commit
e6496dd371
|
@ -298,8 +298,8 @@ class RequestHandler:
|
|||
"""
|
||||
# do logit processor
|
||||
# NOTE: need to decide the granularity to process logits (sequence or batch)
|
||||
config_dict = generation_config.to_dict()
|
||||
for type in ["top_k", "top_p", "min_p"]:
|
||||
config_dict = generation_config.to_dict()
|
||||
if type in config_dict and config_dict[type] is not None:
|
||||
logits = logit_processor(type, logits, config_dict[type])
|
||||
|
||||
|
|
|
@ -36,21 +36,23 @@ def top_p_logit_processor(logits, top_p: float):
|
|||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
|
||||
sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1)
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = -float("inf")
|
||||
return logits
|
||||
|
||||
def logit_processor(processor:str, logits , attrs):
|
||||
|
||||
def logit_processor(processor: str, logits, attrs):
|
||||
"""
|
||||
do logit process for given logits.
|
||||
|
||||
Args:
|
||||
processor(str): the type of logit processor
|
||||
processor(str): the type of logit processor
|
||||
logits(torch.Tensor): input logits
|
||||
attrs(dict): attrs of the logit processor
|
||||
attrs(dict): attrs of the logit processor
|
||||
|
||||
Returns:
|
||||
logits after process
|
||||
|
@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs):
|
|||
func = _LOGIT_PROCESSOR_MAP[processor]
|
||||
try:
|
||||
logits = func(logits, attrs)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return logits
|
||||
return logits
|
||||
return logits
|
||||
|
|
Loading…
Reference in New Issue