mirror of https://github.com/hpcaitech/ColossalAI
fix bugs in sampler
parent 02c1bf8b2a
commit bbfebfb9fc
@@ -180,9 +180,9 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        for type in ["top_p", "top_k", "min_p"]:
+        for type in ["top_k", "top_p", "min_p"]:
             config_dict = generation_config.to_dict()
-            if type in config_dict:
+            if type in config_dict and config_dict[type] is not None:
                 logits = logit_processor(type, logits, config_dict[type])

         torch.cuda.synchronize()
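The added guard matters because GenerationConfig.to_dict() can carry None for sampling options that were never set, and dispatching a None threshold into a top-k/top-p processor would fail. A minimal standalone sketch of the same dispatch pattern, with an illustrative processors mapping standing in for logit_processor:

# Minimal sketch of the guarded dispatch; "processors" is an illustrative stand-in.
from typing import Callable, Dict, Optional

import torch


def apply_sampling_processors(
    logits: torch.Tensor,
    config_dict: Dict[str, Optional[float]],
    processors: Dict[str, Callable[[torch.Tensor, float], torch.Tensor]],
) -> torch.Tensor:
    # Skip options that are absent *or* explicitly None, as in the fixed loop.
    for name in ["top_k", "top_p", "min_p"]:
        if name in config_dict and config_dict[name] is not None:
            logits = processors[name](logits, config_dict[name])
    return logits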
@@ -21,7 +21,7 @@ def multinomial_sample(
     """
     Sample tokens in a random phase.
     """
-    random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
+    random_results = torch.multinomial(probs, num_samples=1).squeeze(1)
     return random_results
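The sampler fix is about shape and device: torch.multinomial(probs, num_samples=1) returns a [batch, 1] tensor, so squeeze(1) yields one token id per sequence, and the result now stays on the same device as probs instead of being copied to the CPU (with num_samples=1 the dropped replacement=True flag had no effect anyway). A standalone shape check, independent of the repository code:

# Standalone shape check for the fixed sampling line.
import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # [batch, vocab] probabilities
sampled = torch.multinomial(probs, num_samples=1)      # [batch, 1] column of draws
token_ids = sampled.squeeze(1)                         # [batch] one token id per sequence
assert token_ids.shape == (4,)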
@@ -43,11 +43,12 @@ def check_config_and_inference():
     )

     assert sequence.sentence_len == 3
-    assert sequence.prompt_len == 3
+    assert sequence.input_len == 3
     assert sequence.output_len == 0
     assert sequence.check_finish() == False

-    batch = BatchInfo.init_batch([sequence])
+    batch = BatchInfo(is_prompts=False)
+    batch.init_batch([sequence])
     batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])
@@ -26,7 +26,7 @@ def check_inference_engine(test_cai=False):
         transformers.LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
         )
-    )
+    ).cuda()

     inputs = [
         "介绍一下北京,",
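The .cuda() call moves the tiny test model onto the GPU so it sits on the same device as the encoded inputs later in the test. A hedged sketch of that setup, assuming the config is wrapped in transformers.LlamaForCausalLM (the wrapping constructor is not visible in this hunk):

# Sketch of the tiny-Llama test model on GPU; LlamaForCausalLM is an assumption (see lead-in).
import transformers

model = transformers.LlamaForCausalLM(
    transformers.LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
    )
).cuda()  # requires a CUDA device; keeps the model on the same device as the inputs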
@@ -38,13 +38,16 @@ def check_inference_engine(test_cai=False):
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=False)
+        generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        generation_config = GenerationConfig(do_sample=False, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1)
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
+        )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
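Both branches now use the same sampling parameters (do_sample=True, top_p=0.5, top_k=50), so the HuggingFace reference path is configured like the engine branch rather than decoding greedily. A standalone sketch of the reference-branch configuration (model and tokenizer omitted):

# Standalone sketch of the reference-branch generation settings (transformers only).
from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True,     # sample instead of greedy decoding
    top_p=0.5,          # nucleus (top-p) threshold
    top_k=50,           # keep only the 50 most probable tokens
    max_new_tokens=1,   # the test only needs a single generated token
)
assert generation_config.to_dict()["top_p"] == 0.5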