mirror of https://github.com/hpcaitech/ColossalAI
[inference] refactored config (#5376)
parent 1f8c7e7046
commit 9afa52061f
@@ -35,49 +35,60 @@ class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
-        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         max_batch_size (int): Maximum batch size, defaults to 8.
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
-        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size, defaults to 1.
-        pp_size (int): Pipeline parallel size, defaults to 1.
+        prompt_template (Optional[str]): The prompt template for generation, defaults to None.
+        do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
             when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
-        quant_mode (Optional[str]): Quantization mode.
-        revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
+        early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
+        top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
+        top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
+        min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
 
     """
 
-    micro_batch_size: int = 1
-    micro_batch_buffer_size: int = None
+    # NOTE: arrange configs according to their importance and frequency of usage
+
+    # runtime limit
     max_batch_size: int = 8
     max_output_len: int = 256
     max_input_len: int = 256
-    block_size: int = 16
+
+    # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
 
-    tp_size: int = 1
-    pp_size: int = 1
-    # TODO: beam search is not support for now
+    # generation configs
+    prompt_template: Optional[str] = None
     do_sample: bool = False
-    beam_width: int = 1
-    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
-    prefill_ratio: Optional[float] = 1.2
+    beam_width: int = 1  # TODO: beam search is not support for now
+    prefill_ratio: Optional[
+        float
+    ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
-    quant_mode: Optional[str] = None
-    revision: Optional[str] = None
     early_stopping: Optional[bool] = False
     top_k: Optional[int] = None
     top_p: Optional[float] = None
     min_p: Optional[float] = None
-    prompt_template: Optional[str] = None
+
+    # paged attention configs
+    block_size: int = 16
+
+    # model parallelism configs
+    tp_size: int = 1
+    pp_size: int = 1
+    micro_batch_size: int = 1
+    micro_batch_buffer_size: int = None
 
     def __post_init__(self):
         self._verify_config()
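For orientation, a minimal usage sketch of the refactored config follows. It assumes the dataclass is importable as colossalai.inference.config.InferenceConfig (the import path is not part of this diff); only fields defined in the hunk above are used, and anything omitted keeps the default shown there.

import torch
from colossalai.inference.config import InferenceConfig  # assumed import path

# Pass only the frequently used knobs; the rest keep the defaults shown in the diff.
config = InferenceConfig(
    max_batch_size=16,
    max_input_len=512,
    max_output_len=128,
    dtype=torch.float16,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    tp_size=1,
)
# __post_init__ runs self._verify_config(), so an invalid combination of values
# fails when the config is constructed rather than later inside the engine.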
@@ -130,7 +130,6 @@ class InferenceEngine:
             enable_flash_attention=False,
             enable_jit_fused=False,
             enable_sequence_parallelism=False,
-            extra_kwargs={"quant": self.inference_config.quant_mode},
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
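With quant_mode and revision gone from InferenceConfig, the extra_kwargs bridge to ShardFormer is no longer needed, which is the single line dropped above. A minimal sketch of the resulting sharding call, assuming model, model_policy, and the remaining ShardConfig arguments are already in scope inside the engine method:

from colossalai.shardformer import ShardConfig, ShardFormer

# Only the flags visible in this hunk are spelled out; the surrounding method
# supplies the other ShardConfig arguments.
shardconfig = ShardConfig(
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)  # model and model_policy assumed in scope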