From bbfebfb9fc5250c1e4d3a6f008af652f7a0a9ca0 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Thu, 4 Jan 2024 15:03:18 +0800
Subject: [PATCH] fix bugs in sampler

---
 colossalai/inference/core/request_handler.py | 4 ++--
 colossalai/inference/sampler.py              | 2 +-
 tests/test_infer/test_config_and_struct.py   | 5 +++--
 tests/test_infer/test_inference_engine.py    | 9 ++++++---
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index f9202b675..1754a8862 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -180,9 +180,9 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        for type in ["top_p", "top_k", "min_p"]:
+        for type in ["top_k", "top_p", "min_p"]:
             config_dict = generation_config.to_dict()
-            if type in config_dict:
+            if type in config_dict and config_dict[type] is not None:
                 logits = logit_processor(type, logits, config_dict[type])

         torch.cuda.synchronize()
diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py
index e139a6071..1c0c518f9 100644
--- a/colossalai/inference/sampler.py
+++ b/colossalai/inference/sampler.py
@@ -21,7 +21,7 @@ def multinomial_sample(
     """
     Sample tokens in a random phase.
     """
-    random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
+    random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
     return random_results


diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py
index b42308bfc..7feb1cd41 100755
--- a/tests/test_infer/test_config_and_struct.py
+++ b/tests/test_infer/test_config_and_struct.py
@@ -43,11 +43,12 @@ def check_config_and_inference():
     )

     assert sequence.sentence_len == 3
-    assert sequence.prompt_len == 3
+    assert sequence.input_len == 3
     assert sequence.output_len == 0
     assert sequence.check_finish() == False

-    batch = BatchInfo.init_batch([sequence])
+    batch = BatchInfo(is_prompts=False)
+    batch.init_batch([sequence])
     batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])

diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 72df88136..5315c7811 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -26,7 +26,7 @@ def check_inference_engine(test_cai=False):
         transformers.LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
         )
-    )
+    ).cuda()

     inputs = [
         "介绍一下北京,",
@@ -38,13 +38,16 @@ def check_inference_engine(test_cai=False):
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=False)
+        generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        generation_config = GenerationConfig(do_sample=False, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1)
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
+        )
         outputs = model.generate(inputs, generation_config=generation_config)

     outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
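
For context, below is a minimal, self-contained sketch of the sampling behavior this patch moves toward, written against plain PyTorch rather than the ColossalAI internals. The apply_logit_processors helper and the logit_processor callable it receives are illustrative stand-ins (the real processor lives in the repository), and the tensor shapes in the comments are assumptions.

import torch

def apply_logit_processors(generation_config, logits, logit_processor):
    """Mirror of the patched loop: apply top_k before top_p, and skip any
    option that is absent from the config dict or explicitly set to None."""
    config_dict = generation_config.to_dict()
    for name in ["top_k", "top_p", "min_p"]:
        if name in config_dict and config_dict[name] is not None:
            logits = logit_processor(name, logits, config_dict[name])
    return logits

def multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
    """Draw one token id per row of probs ([batch_size, vocab_size]).
    squeeze(1) flattens the [batch_size, 1] result to [batch_size], and the
    output stays on the same device as probs (no .cpu() transfer)."""
    return torch.multinomial(probs, num_samples=1).squeeze(1)

if __name__ == "__main__":
    probs = torch.softmax(torch.randn(4, 32), dim=-1)  # 4 sequences, 32-token vocab
    print(multinomial_sample(probs).shape)  # torch.Size([4])

Dropping .cpu() and replacement=True presumably keeps the sampled token ids on the GPU in the flat [batch_size] shape that downstream decoding expects, while the is-not-None guard protects against GenerationConfig entries (e.g. min_p in newer transformers releases) that are present in to_dict() but set to None.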