mirror of https://github.com/hpcaitech/ColossalAI
fix bugs in request_handler
parent 62fd08ee44
commit 62968588d1
@@ -226,12 +226,15 @@ class InferenceEngine:
             self.v_cache,
         )
 
+        logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
 
         finished_sequences = self.request_handler.update()
 
+        print("finished_sequences: ", finished_sequences)
         # Decode completed sentences.
         for seq in finished_sequences:
+            print("seq.output_token_id: ", seq.output_token_id)
             if seq.prompt:
                 output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
                 output_list.append(seq.prompt + output_str)
@@ -239,4 +242,6 @@
                 output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
                 output_list.append(output_str)
 
+        print("len(output_list): ", len(output_list))
+
         return output_list
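The added logits = logits[:, -1, :] narrows the model output to the final position before token search. A minimal sketch of the shapes involved, with the tensor sizes invented purely for illustration:

# Sketch (not part of the diff): only the last position of the logits tensor
# is needed to pick the next token for each running sequence.
import torch

batch_size, seq_len, vocab_size = 2, 8, 32000          # illustrative sizes only
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for the model output

next_token_logits = logits[:, -1, :]                   # [batch_size, vocab_size]
next_tokens = torch.argmax(next_token_logits, dim=-1)  # greedy pick, [batch_size]
print(next_tokens.shape)                               # torch.Size([2])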
@@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
+from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
 
@@ -179,10 +180,10 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        # for type in ["top_p", "top_k", "min_p"]:
-        #     config_dict = generation_config.to_dict()
-        #     if type in config_dict:
-        #         logits = logit_processor(type, logits, config_dict[type])
+        for type in ["top_p", "top_k", "min_p"]:
+            config_dict = generation_config.to_dict()
+            if type in config_dict:
+                logits = logit_processor(type, logits, config_dict[type])
 
         torch.cuda.synchronize()
 
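The loop above is driven by logit_processor(type, logits, value); the module itself is not part of this diff, so the sketch below is only a hypothetical registry-style implementation of the same call shape, using standard top-k/top-p filtering:

# Sketch (not part of the diff): a hypothetical logit_processor with the same
# signature as the call in search_tokens. The real
# colossalai.inference.logit_processors module is not shown here.
import torch


def _top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest logits per row, push everything else to -inf.
    kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth, float("-inf"))


def _top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    # Nucleus filtering: drop tokens whose cumulative probability (excluding
    # themselves) already exceeds p, always keeping at least the top token.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    exclusive_cumsum = probs.cumsum(dim=-1) - probs
    sorted_logits = sorted_logits.masked_fill(exclusive_cumsum > p, float("-inf"))
    return logits.scatter(-1, sorted_idx, sorted_logits)


_PROCESSORS = {"top_k": _top_k_filter, "top_p": _top_p_filter}


def logit_processor(processor_type: str, logits: torch.Tensor, value):
    # Unknown types (e.g. "min_p" in this sketch) pass through unchanged.
    fn = _PROCESSORS.get(processor_type)
    return fn(logits, value) if fn is not None else logits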
@@ -207,11 +208,12 @@ class RequestHandler:
         self.running_list.prefill.clear()
         self.prefill_batch.clear_batch()
 
-        for seq in self.running_batch.sequences_set:
-            if seq.check_finish():
-                self.done_list.append(seq)
-                self.running_list.remove(seq)
-                self.running_batch.sequences_set.remove(seq)
-                self.cache_manager.free_block_table(seq.block_table)
+        finish_seqs = self.running_batch.fliter_batch()
 
-        return self.done_list
+        for seq in finish_seqs:
+            self.running_list.remove(seq)
+            self.cache_manager.free_block_table(seq.block_table)
+
+        self.done_list.extend(finish_seqs)
+
+        return finish_seqs
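update() now delegates the finished/running split to the batch via fliter_batch(). That method's body is not part of this diff; the sketch below only illustrates the contract the new code relies on, namely that the batch drops its finished sequences and returns them:

# Sketch (not part of the diff): a hypothetical stand-in for BatchInfo.fliter_batch(),
# shown only to illustrate the contract assumed by the new update() code.
class _BatchInfoSketch:
    def __init__(self, sequences_set):
        self.sequences_set = set(sequences_set)

    def fliter_batch(self):
        # Pop every finished sequence out of the batch and hand it back.
        finished = [seq for seq in self.sequences_set if seq.check_finish()]
        for seq in finished:
            self.sequences_set.remove(seq)
        return finished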
@@ -106,7 +106,7 @@ class Sequence:
             return True
 
         if self.output_token_id:
-            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+            if self.output_token_id[-1] == self.eos_token_id or self.output_len >= self.max_output_len:
                 self.status = RequestStatus.COMPLETED
                 return True
 
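The old condition finished a sequence whenever its last token id happened to be numerically greater than eos_token_id, and only stopped on an exact length match; the new one compares against EOS for equality and caps the length with >=. A minimal illustration, with the token ids chosen arbitrarily and assuming output_len tracks len(output_token_id):

# Sketch (not part of the diff): why the relational operators were swapped.
eos_token_id, max_output_len = 2, 4
output_token_id = [15, 7]  # 7 is an ordinary token id, not EOS

old_finished = output_token_id[-1] >= eos_token_id or len(output_token_id) == max_output_len
new_finished = output_token_id[-1] == eos_token_id or len(output_token_id) >= max_output_len

print(old_finished)  # True  -- 7 >= 2, so the sequence is cut off prematurely
print(new_finished)  # False -- generation continues until EOS or the length cap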
@@ -28,6 +28,7 @@ def check_inference_engine():
     generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
     outputs = inference_engine.generate(generation_config)
 
+    print("len(outputs): ", len(outputs))
     print("outputs: ", outputs)
 
     # Engine still gets some bug
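For reference, the sampling options the test passes are exactly what the new logit-processor loop reads back through generation_config.to_dict(); a small sketch of that round trip (the printed values follow transformers defaults and may differ by version):

# Sketch (not part of the diff): the test's GenerationConfig as seen by search_tokens.
from transformers import GenerationConfig

generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
config_dict = generation_config.to_dict()

print(config_dict["top_k"], config_dict["top_p"])  # 2 0.8
# "min_p" may be missing on older transformers releases, which is why the loop
# guards each key with `if type in config_dict`.
print("min_p" in config_dict)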