mirror of https://github.com/hpcaitech/ColossalAI
[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)
* add
* fix
* fix
* pause
* fix
* fix pytest
* align
* fix
* license
* fix
* fix
* fix readme
* fix some bugs
* remove tokenizer config

pull/5376/head
parent 6fb4bcbb24
commit 1f8c7e7046
@@ -86,7 +86,7 @@ colossalai.launch_from_torch(config={})
 # Step 1: create a model in "transformers" way
 model_path = "lmsys/vicuna-7b-v1.3"
 model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
-tokenizer = transformers.LlamaTokenizer.from_pretrained(model_path)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

 # Step 2: create an inference_config
 inference_config = InferenceConfig(
@@ -100,13 +100,8 @@ inference_config = InferenceConfig(
 engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

 # Step 4: try inference
-generation_config = transformers.GenerationConfig(
-    pad_token_id=tokenizer.pad_token_id,
-    max_new_tokens=512,
-)
 prompts = ['Who is the best player in the history of NBA?']
-engine.add_request(prompts=prompts)
-response = engine.generate(generation_config)
+response = engine.generate(prompts)
 pprint(response)
 ```
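Read together, the two README hunks above describe the simplified user flow after this commit: the tokenizer is loaded with `AutoTokenizer`, and no hand-built `transformers.GenerationConfig` is needed because the engine derives one from the `InferenceConfig`. A minimal end-to-end sketch under those assumptions (the `InferenceConfig` arguments and the import path are illustrative, not taken from the diff):

```python
from pprint import pprint

import colossalai
import transformers
from colossalai.inference import InferenceConfig, InferenceEngine  # import path assumed

colossalai.launch_from_torch(config={})

# Step 1: create a model in "transformers" way
model_path = "lmsys/vicuna-7b-v1.3"
model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

# Step 2: create an inference_config (field values assumed for illustration)
inference_config = InferenceConfig(max_input_len=256, max_output_len=256)

# Step 3: create an engine; it builds its default GenerationConfig from inference_config
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Step 4: try inference -- prompts go straight into generate()
prompts = ["Who is the best player in the history of NBA?"]
response = engine.generate(prompts)
pprint(response)
```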
@@ -150,13 +145,16 @@ Notations:
 - [x] Paged Attention
 - [x] High-Performance Kernels
 - [x] Llama Modelling
+- [x] User Documentation
+- [ ] Speculative Decoding
 - [ ] Tensor Parallelism
 - [ ] Beam Search
-- [ ] Speculative Decoding
+- [ ] Early stopping
+- [ ] Logger system
+- [ ] SplitFuse
 - [ ] Continuous Batching
 - [ ] Online Inference
 - [ ] Benchmarking
-- [ ] User Documentation

 ## 🌟 Acknowledgement
@@ -8,6 +8,7 @@ from typing import Optional, Union

 import torch
 import torch.distributed as dist
+from transformers.generation import GenerationConfig

 GibiByte = 1024**3
@@ -60,15 +61,22 @@ class InferenceConfig:
     max_input_len: int = 256
     block_size: int = 16
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default

     tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
+    do_sample: bool = False
     beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     prefill_ratio: Optional[float] = 1.2
     pad_input: bool = False
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
+    early_stopping: Optional[bool] = False
+
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    min_p: Optional[float] = None
     prompt_template: Optional[str] = None

     def __post_init__(self):
@@ -93,7 +101,6 @@ class InferenceConfig:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

         # check prompt template
         if self.prompt_template is None:
             return
@@ -105,3 +112,20 @@ class InferenceConfig:
         assert (
             "{input_text}" in self.prompt_template
         ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
+
+    def to_generation_config(self, model_config) -> GenerationConfig:
+        meta_config = {
+            "max_length": self.max_input_len + self.max_output_len,
+            "max_new_tokens": self.max_output_len,
+            "early_stopping": self.early_stopping,
+            "do_sample": self.do_sample,
+            "num_beams": self.beam_width,
+        }
+        for type in ["top_k", "top_p", "min_p"]:
+            if hasattr(self, type):
+                meta_config[type] = getattr(self, type)
+        for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
+            if hasattr(model_config, type):
+                meta_config[type] = getattr(model_config, type)
+
+        return GenerationConfig.from_dict(meta_config)
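The new `to_generation_config` helper above folds the engine-level knobs (`max_output_len`, `do_sample`, `beam_width`, the optional `top_k`/`top_p`/`min_p`) and the model's special token ids into a `transformers.GenerationConfig`. A minimal sketch of calling it directly, assuming the distributed environment has been launched first (the config's `__post_init__` checks `tp_size * pp_size` against the world size) and using illustrative field values:

```python
import colossalai
from transformers import AutoConfig

from colossalai.inference.config import InferenceConfig  # import path as used in the diff

colossalai.launch_from_torch(config={})  # needed: the config verifies the world size

# Any config object carrying pad/bos/eos token ids works; a Llama config is just an example.
model_config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.3")

inference_config = InferenceConfig(
    max_input_len=256,
    max_output_len=256,  # illustrative values, not defaults taken from the diff
    do_sample=True,
    top_k=50,
    top_p=0.9,
)

generation_config = inference_config.to_generation_config(model_config)
# max_length = max_input_len + max_output_len, max_new_tokens = max_output_len,
# and bos/eos/pad ids are copied from model_config when present.
print(generation_config.max_new_tokens, generation_config.eos_token_id)
```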
@@ -33,7 +33,7 @@ class InferenceEngine:

     Args:
         model (nn.Module): Path or nn.Module of this model.
-        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
         model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
@@ -42,19 +42,20 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: Optional["InferenceConfig"] = None,
+        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
+        inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        assert tokenizer, "Please provide a tokenizer, either a defined one or str"
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
         self.dtype = inference_config.dtype
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.generation_config = inference_config.to_generation_config(self.model_config)
         model = model.eval()
         model.to(self.dtype)
@@ -80,6 +81,8 @@ class InferenceEngine:

         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        # DISCUSS maybe move this into batch info?
+
         self.counter = count()

     def _verify_config(self) -> None:
@@ -137,7 +140,7 @@ class InferenceEngine:
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        generation_config: GenerationConfig = None,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -158,6 +161,10 @@ class InferenceEngine:
         output_seqs_list = []
         output_tokens_list = []

+        # intuition: If user provide a generation config, we should replace the existing one.
+        if generation_config is not None:
+            self.generation_config = generation_config
+
         while self.request_handler.check_unfinished_seqs():
             output_seqs_list += self.step()
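With the hunk above, a `GenerationConfig` passed to `generate()` replaces the engine-level default derived from the `InferenceConfig` for the rest of the run. A short sketch of both call styles, assuming an `engine` built as in the README example earlier:

```python
from transformers import GenerationConfig

# 1) Rely on the default generation config the engine derived from inference_config.
outputs = engine.generate(prompts=["Hello, my name is"])

# 2) Override it; note the engine keeps the override for subsequent calls too.
custom_config = GenerationConfig(do_sample=True, top_k=50, top_p=0.9, max_new_tokens=64)
outputs = engine.generate(prompts=["Hello, my name is"], generation_config=custom_config)
```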
@@ -285,8 +292,8 @@ class InferenceEngine:

         if self.inference_config.pad_input:
             logits = logits[:, -1, :]

         self.request_handler.search_tokens(self.generation_config, logits)

         finished_sequences = self.request_handler.update()

         return finished_sequences
@@ -2,6 +2,7 @@ from typing import List

 import torch
 from transformers.configuration_utils import PretrainedConfig
+from transformers.generation import GenerationConfig

 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -94,6 +95,10 @@ class RequestHandler:
         head_dim = model_config.hidden_size // model_config.num_attention_heads

         fd_inter_tensor = FDIntermTensors()
+
+        if fd_inter_tensor._tensors_initialized:
+            fd_inter_tensor._reset()
+
         fd_inter_tensor.initialize(
             max_batch_size=self.max_batch_size,
             num_attn_heads=model_config.num_attention_heads,
@@ -170,6 +175,7 @@ class RequestHandler:
                 self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
         for seq in remove_list:
             lst.remove(seq)
+
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -229,7 +235,7 @@ class RequestHandler:

         return None

-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
+    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
         if generation_config.num_beams == 1:
             if generation_config.do_sample:
                 sample_tokens = multinomial_sample(generation_config, probs)
@@ -240,7 +246,7 @@ class RequestHandler:

         return sample_tokens

-    def mark_finished(self, sequence: Sequence, generation_config):
+    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_id
             or sequence.output_len >= generation_config.max_output_len
@@ -250,7 +256,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

-    def search_tokens(self, generation_config, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
         """
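The `_sample` and `search_tokens` changes above only add `GenerationConfig` type hints, but they annotate the selection rule visible in the hunk: with `num_beams == 1`, tokens come from multinomial sampling when `do_sample` is set and from a greedy pick otherwise. A standalone toy illustration of that rule (not ColossalAI's actual sampling kernels):

```python
import torch
from transformers import GenerationConfig


def pick_next_tokens(logits: torch.Tensor, generation_config: GenerationConfig) -> torch.Tensor:
    """Toy version of the num_beams == 1 branch: greedy vs. multinomial sampling."""
    probs = torch.softmax(logits.float(), dim=-1)
    if generation_config.num_beams == 1:
        if generation_config.do_sample:
            # stands in for multinomial_sample(generation_config, probs) in the real code
            return torch.multinomial(probs, num_samples=1).squeeze(-1)
        # stands in for the greedy path
        return torch.argmax(probs, dim=-1)
    raise NotImplementedError("beam search is still a TODO in the diff above")


config = GenerationConfig(do_sample=False, num_beams=1)
print(pick_next_tokens(torch.randn(2, 32000), config))
```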
@@ -12,6 +12,11 @@ class FDIntermTensors(metaclass=SingletonMeta):
     def __init__(self):
         self._tensors_initialized = False

+    def _reset(self):
+        self._tensors_initialized = False
+        del self._mid_output
+        del self._mid_output_lse
+
     @property
     def is_initialized(self):
         return self._tensors_initialized
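`FDIntermTensors` is a singleton (note the `SingletonMeta` metaclass in the hunk header), so the intermediate buffers allocated by `initialize()` survive across engine constructions; the new `_reset()` lets `RequestHandler` drop them and re-initialize, as the earlier hunk shows. A stripped-down sketch of that pattern, with a generic stand-in metaclass and buffer shapes rather than ColossalAI's actual `SingletonMeta` and tensors:

```python
import torch


class SingletonMeta(type):
    """Generic singleton metaclass used only for this illustration."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class IntermTensors(metaclass=SingletonMeta):
    def __init__(self):
        self._tensors_initialized = False

    def _reset(self):
        self._tensors_initialized = False
        del self._mid_output  # drop stale buffers so initialize() can rebuild them

    def initialize(self, max_batch_size: int):
        self._mid_output = torch.empty(max_batch_size, 8)
        self._tensors_initialized = True


# The second "construction" returns the same object, so it must be reset before it is
# re-initialized with a possibly different batch size -- which is what RequestHandler now does.
t = IntermTensors()
t.initialize(max_batch_size=4)
t2 = IntermTensors()
if t2._tensors_initialized:
    t2._reset()
t2.initialize(max_batch_size=8)
```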
@@ -72,7 +72,6 @@ def llama_model_forward(
     """
     input_ids = batch.get_1D_inputs()
     block_tables = batch.get_block_table_tensor()
-
     sequence_lengths = batch.get_sequence_lengths()
     batch_size = len(sequence_lengths)
     kv_seq_len = sequence_lengths.max().item()
@@ -31,7 +31,6 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         .cuda()
         .half()
     )
-
     model = model.eval()

     inputs = [
@@ -47,6 +46,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     if use_engine:
         inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
         generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)