[inference] refactored config (#5376)

pull/5379/head
Frank Lee 2024-02-08 14:04:14 +08:00 committed by GitHub
parent 1f8c7e7046
commit 9afa52061f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 22 deletions

View File

@ -35,49 +35,60 @@ class InferenceConfig:
"""The inference configuration. """The inference configuration.
Args: Args:
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
max_batch_size (int): Maximum batch size, defaults to 8. max_batch_size (int): Maximum batch size, defaults to 8.
max_output_len (int): Maximum output length, defaults to 256. max_output_len (int): Maximum output length, defaults to 256.
max_input_len (int): Maximum input length, defaults to 256. max_input_len (int): Maximum input length, defaults to 256.
block_size (int): The number of blocks in a logical block, defaults to 16.
dtype (Union[str, torch.dtype]): The data type for weights and activations. dtype (Union[str, torch.dtype]): The data type for weights and activations.
tp_size (int): Tensor parallel size, defaults to 1. prompt_template (Optional[str]): The prompt template for generation, defaults to None.
pp_size (int): Pipeline parallel size, defaults to 1. do_sample (bool): Whether to use sampling for generation, defaults to False.
beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1. beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
During generation, the beam width provided as sampling parameter should be less than or equivalent to this value. During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
when the actual value exceeds this ratio. when the actual value exceeds this ratio.
pad_input: Whether to pad all inputs to the max length. pad_input: Whether to pad all inputs to the max length.
quant_mode (Optional[str]): Quantization mode. early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use. top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text. top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
""" """
micro_batch_size: int = 1 # NOTE: arrange configs according to their importance and frequency of usage
micro_batch_buffer_size: int = None
# runtime limit
max_batch_size: int = 8 max_batch_size: int = 8
max_output_len: int = 256 max_output_len: int = 256
max_input_len: int = 256 max_input_len: int = 256
block_size: int = 16
# general configs
dtype: Union[str, torch.dtype] = torch.float16 # use fp16 by default dtype: Union[str, torch.dtype] = torch.float16 # use fp16 by default
tp_size: int = 1 # generation configs
pp_size: int = 1 prompt_template: Optional[str] = None
# TODO: beam search is not support for now
do_sample: bool = False do_sample: bool = False
beam_width: int = 1 beam_width: int = 1 # TODO: beam search is not support for now
# the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio prefill_ratio: Optional[
prefill_ratio: Optional[float] = 1.2 float
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
pad_input: bool = False pad_input: bool = False
quant_mode: Optional[str] = None
revision: Optional[str] = None
early_stopping: Optional[bool] = False early_stopping: Optional[bool] = False
top_k: Optional[int] = None top_k: Optional[int] = None
top_p: Optional[float] = None top_p: Optional[float] = None
min_p: Optional[float] = None min_p: Optional[float] = None
prompt_template: Optional[str] = None
# paged attention configs
block_size: int = 16
# model parallelism configs
tp_size: int = 1
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
def __post_init__(self): def __post_init__(self):
self._verify_config() self._verify_config()

View File

@ -130,7 +130,6 @@ class InferenceEngine:
enable_flash_attention=False, enable_flash_attention=False,
enable_jit_fused=False, enable_jit_fused=False,
enable_sequence_parallelism=False, enable_sequence_parallelism=False,
extra_kwargs={"quant": self.inference_config.quant_mode},
) )
shardformer = ShardFormer(shard_config=shardconfig) shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy) shard_model, _ = shardformer.optimize(model, model_policy)