fix bugs in sampler

pull/5258/head
yuehuayingxueluo 2024-01-04 15:03:18 +08:00 committed by FrankLeeeee
parent 02c1bf8b2a
commit bbfebfb9fc
4 changed files with 12 additions and 8 deletions


@@ -180,9 +180,9 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
- for type in ["top_p", "top_k", "min_p"]:
+ for type in ["top_k", "top_p", "min_p"]:
config_dict = generation_config.to_dict()
- if type in config_dict:
+ if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
torch.cuda.synchronize()
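
Note: top-k and top-p filtering do not commute (top-k caps the candidate count before the nucleus cutoff is computed), so applying top_k ahead of top_p matches the conventional sampling pipeline, and the added `is not None` guard keeps unset generation options out of the processor. Below is a minimal, self-contained sketch of that pattern, assuming illustrative filter helpers rather than the repo's logit_processor; min_p is omitted from the sketch.

import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest logits per row; push everything else to -inf.
    kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth, float("-inf"))

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    # Nucleus filtering: drop the low-probability tail once cumulative mass exceeds p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum_probs > p
    remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses the threshold
    remove[..., 0] = False                      # always keep the most likely token
    filtered = sorted_logits.masked_fill(remove, float("-inf"))
    return logits.scatter(-1, sorted_idx, filtered)  # map back to the original token order

filters = {"top_k": top_k_filter, "top_p": top_p_filter}
logits = torch.randn(2, 32000)
config = {"top_k": 50, "top_p": 0.5, "min_p": None}
for name in ["top_k", "top_p", "min_p"]:             # top_k is applied before top_p
    if name in config and config[name] is not None:  # unset options are skipped
        logits = filters[name](logits, config[name])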


@@ -21,7 +21,7 @@ def multinomial_sample(
"""
Sample tokens in a random phase.
"""
- random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
+ random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
return random_results
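
Note: with num_samples=1 the replacement flag has no effect, and torch.multinomial returns a (batch, 1) column, so squeeze(1) yields a flat (batch,) tensor of token ids that stays on the same device as probs instead of being synced to the CPU. A minimal sketch of that shape and device behaviour, with illustrative tensor sizes:

import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # (batch, vocab) sampling distribution
token_ids = torch.multinomial(probs, num_samples=1)    # (batch, 1), stays on probs' device
token_ids = token_ids.squeeze(1)                       # (batch,): one sampled token id per sequence
assert token_ids.shape == (4,)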


@@ -43,11 +43,12 @@ def check_config_and_inference():
)
assert sequence.sentence_len == 3
assert sequence.prompt_len == 3
assert sequence.input_len == 3
assert sequence.output_len == 0
assert sequence.check_finish() == False
- batch = BatchInfo.init_batch([sequence])
+ batch = BatchInfo(is_prompts=False)
+ batch.init_batch([sequence])
- batch.add_seqs([sequence2, sequence3])
+ batch.add_seqs([sequence])


@@ -26,7 +26,7 @@ def check_inference_engine(test_cai=False):
transformers.LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
- )
+ ).cuda()
inputs = [
"介绍一下北京,",
@@ -38,13 +38,16 @@ def check_inference_engine(test_cai=False):
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
- generation_config = GenerationConfig(do_sample=False)
+ generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
outputs = inference_engine.generate(generation_config)
else:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
- generation_config = GenerationConfig(do_sample=False, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1)
+ inputs = inputs.cuda()
+ generation_config = GenerationConfig(
+ do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
+ )
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
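
Note: the test model now lives on the GPU (hence `.cuda()` on both the model and the reference inputs), and both the ColossalAI engine and the Hugging Face baseline decode with sampling enabled (do_sample=True, top_p=0.5, top_k=50), so the fixed sampler path is actually exercised instead of greedy decoding. A minimal sketch of the Hugging Face reference side under these settings; the tokenizer checkpoint is an assumption for illustration, not necessarily the test's fixture:

import transformers
from transformers import GenerationConfig

# Assumed stand-in tokenizer; the test's own tokenizer fixture is not shown in this diff.
tokenizer = transformers.AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token

# Small random-weight Llama matching the config in the diff, moved to the GPU.
model = transformers.LlamaForCausalLM(
    transformers.LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536,
        num_attention_heads=4, num_hidden_layers=4,
    )
).cuda()

inputs = tokenizer(["介绍一下北京,"], padding=True, return_tensors="pt")["input_ids"].cuda()
generation_config = GenerationConfig(
    do_sample=True, top_p=0.5, top_k=50,
    pad_token_id=tokenizer.pad_token_id, max_new_tokens=1,
)
outputs = model.generate(inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))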