From f8e456d20295af52665ca06a21f9fd8b468204d7 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 1 Feb 2024 15:31:01 +0800
Subject: [PATCH] [inference] simplified config verification (#5346)

* [inference] simplified config verification

* polish

* polish
---
 colossalai/inference/config.py             | 86 ++++++++--------------
 tests/test_infer/test_inference_engine.py  | 14 ++--
 2 files changed, 40 insertions(+), 60 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index f54555857..6923d63e3 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -14,23 +14,32 @@ GibiByte = 1024**3
 logger = logging.Logger(__name__)
 
 
+_DTYPE_MAPPING = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
-        max_batch_size (int): Maximum batch size.
-        max_output_len (int): Maximum output length.
-        max_input_len (int): Maximum input length.
-        block_size (int): The number of blocks in a logical block.
+        max_batch_size (int): Maximum batch size, defaults to 8.
+        max_output_len (int): Maximum output length, defaults to 256.
+        max_input_len (int): Maximum input length, defaults to 256.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size.
-        pp_size (int): Pipeline parallel size.
-        beam_width (int): The maximum beam width used to initialize KV Cache.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
            when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
@@ -43,7 +52,7 @@ class InferenceConfig:
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
-    dtype: Union[str, torch.dtype] = torch.float32
+    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
     tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
@@ -55,57 +64,24 @@ class InferenceConfig:
     revision: Optional[str] = None
 
     def __post_init__(self):
-        self._init_batch_size()
         self._verify_config()
-        self._get_dtype()
-
-    def _init_batch_size(self):
-        """
-        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
-        We take a simple method to determine it by GPU memory size, user can still set it manually.
- """ - if self.max_batch_size is not None: - # already set by user - return - - device = torch.device("cuda") - total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte - self.max_batch_size = 8 - - if 40 < total_mem <= 60: - self.max_batch_size = 16 - elif 60 < total_mem <= 80: - self.max_batch_size = 32 - logger.info( - f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user." - ) def _verify_config(self) -> None: """ Verify the input config """ + # check dtype + if isinstance(self.dtype, str): + # convert string dtype to torch dtype + assert ( + self.dtype in _DTYPE_MAPPING + ), f"Expected the dtype string argument to be in {list(_DTYPE_MAPPING.keys())} but found an unknown dtype: {self.dtype}" + self.dtype = _DTYPE_MAPPING[self.dtype] + assert ( + self.dtype in _ALLOWED_DTYPES + ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}" + + # check distributed assert ( self.tp_size * self.pp_size == dist.get_world_size() ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})" - - assert self.dtype in [ - "fp16", - "fp32", - "bf16", - torch.float32, - torch.float16, - torch.bfloat16, - ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}." - assert self.quant_mode in [ - "smoothquant", - "gptq", - None, - ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}." - - def _get_dtype(self) -> None: - if self.dtype == "fp32" or self.dtype == torch.float32: - self.dtype = torch.float32 - elif self.dtype == "fp16" or self.dtype == torch.float16: - self.dtype = torch.float16 - else: - self.dtype = torch.bfloat16 diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 19e1a5636..49bbe6df3 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -21,11 +21,15 @@ def setup_seed(seed): def check_inference_engine(test_cai=False): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + model = ( + LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + ) ) - ).cuda() + .cuda() + .half() + ) model = model.eval() @@ -70,7 +74,7 @@ def run_dist(rank, world_size, port): transformer_outputs = check_inference_engine(False) for s1, s2 in zip(cai_outputs, transformer_outputs): - assert s1 == s2 + assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}" @pytest.mark.dist