mirror of https://github.com/hpcaitech/ColossalAI
fix bugs in request_handler
parent
62fd08ee44
commit
62968588d1
|
@ -226,12 +226,15 @@ class InferenceEngine:
|
|||
self.v_cache,
|
||||
)
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
self.request_handler.search_tokens(self.generation_config, logits)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
print("finished_sequences: ", finished_sequences)
|
||||
|
||||
# Decode completed sentences.
|
||||
for seq in finished_sequences:
|
||||
print("seq.output_token_id: ", seq.output_token_id)
|
||||
if seq.prompt:
|
||||
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(seq.prompt + output_str)
|
||||
|
@ -239,4 +242,6 @@ class InferenceEngine:
|
|||
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(output_str)
|
||||
|
||||
print("len(output_list): ", len(output_list))
|
||||
|
||||
return output_list
|
||||
|
|
|
@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.kv_cache import KVCacheManager
|
||||
from colossalai.inference.logit_processors import logit_processor
|
||||
from colossalai.inference.sampler import *
|
||||
from colossalai.inference.struct import BatchInfo, Sequence
|
||||
|
||||
|
@ -179,10 +180,10 @@ class RequestHandler:
|
|||
"""
|
||||
# do logit processor
|
||||
# NOTE: need to decide the granularity to process logits (sequence or batch)
|
||||
# for type in ["top_p", "top_k", "min_p"]:
|
||||
# config_dict = generation_config.to_dict()
|
||||
# if type in config_dict:
|
||||
# logits = logit_processor(type, logits, config_dict[type])
|
||||
for type in ["top_p", "top_k", "min_p"]:
|
||||
config_dict = generation_config.to_dict()
|
||||
if type in config_dict:
|
||||
logits = logit_processor(type, logits, config_dict[type])
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
@ -207,11 +208,12 @@ class RequestHandler:
|
|||
self.running_list.prefill.clear()
|
||||
self.prefill_batch.clear_batch()
|
||||
|
||||
for seq in self.running_batch.sequences_set:
|
||||
if seq.check_finish():
|
||||
self.done_list.append(seq)
|
||||
self.running_list.remove(seq)
|
||||
self.running_batch.sequences_set.remove(seq)
|
||||
self.cache_manager.free_block_table(seq.block_table)
|
||||
finish_seqs = self.running_batch.fliter_batch()
|
||||
|
||||
return self.done_list
|
||||
for seq in finish_seqs:
|
||||
self.running_list.remove(seq)
|
||||
self.cache_manager.free_block_table(seq.block_table)
|
||||
|
||||
self.done_list.extend(finish_seqs)
|
||||
|
||||
return finish_seqs
|
||||
|
|
|
@ -106,7 +106,7 @@ class Sequence:
|
|||
return True
|
||||
|
||||
if self.output_token_id:
|
||||
if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
|
||||
if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
|
||||
self.status = RequestStatus.COMPLETED
|
||||
return True
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ def check_inference_engine():
|
|||
generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
|
||||
outputs = inference_engine.generate(generation_config)
|
||||
|
||||
print("len(outputs): ", len(outputs))
|
||||
print("outputs: ", outputs)
|
||||
|
||||
# Engine still gets some bug
|
||||
|
|
Loading…
Reference in New Issue