[inference] simplified config verification (#5346)

* [inference] simplified config verification

* polish

* polish
Branch: pull/5340/head^2
Frank Lee 2024-02-01 15:31:01 +08:00 committed by GitHub
parent df0aa49585
commit f8e456d202
2 changed files with 40 additions and 60 deletions

View File

@@ -14,23 +14,32 @@ GibiByte = 1024**3
 logger = logging.Logger(__name__)
 
+_DTYPE_MAPPING = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+
 
 @dataclass
 class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
-        max_batch_size (int): Maximum batch size.
-        max_output_len (int): Maximum output length.
-        max_input_len (int): Maximum input length.
-        block_size (int): The number of blocks in a logical block.
+        max_batch_size (int): Maximum batch size, defaults to 8.
+        max_output_len (int): Maximum output length, defaults to 256.
+        max_input_len (int): Maximum input length, defaults to 256.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size.
-        pp_size (int): Pipeline parallel size.
-        beam_width (int): The maximum beam width used to initialize KV Cache.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
            when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
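
The new module-level tables turn dtype handling into a dictionary lookup plus one membership check. Below is a minimal, self-contained sketch of that flow; the helper name resolve_dtype is hypothetical and added only for illustration, while _DTYPE_MAPPING and _ALLOWED_DTYPES mirror the tables introduced in the hunk above.

from typing import Union

import torch

_DTYPE_MAPPING = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


def resolve_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    # A string alias is first mapped to the corresponding torch dtype...
    if isinstance(dtype, str):
        assert dtype in _DTYPE_MAPPING, f"unknown dtype string: {dtype}"
        dtype = _DTYPE_MAPPING[dtype]
    # ...then every value, string-derived or not, must be one of the allowed torch dtypes.
    assert dtype in _ALLOWED_DTYPES, f"unsupported dtype: {dtype}"
    return dtype


print(resolve_dtype("fp16"))          # torch.float16
print(resolve_dtype(torch.bfloat16))  # torch.bfloat16
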
@@ -43,7 +52,7 @@ class InferenceConfig:
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
-    dtype: Union[str, torch.dtype] = torch.float32
+    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
     tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
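
Switching the field default from torch.float32 to torch.float16 halves the memory footprint of model weights loaded under the default config. A quick back-of-the-envelope check; the 7B parameter count is illustrative, not taken from the diff:

import torch

num_params = 7_000_000_000  # assumed model size for illustration
bytes_fp32 = num_params * torch.finfo(torch.float32).bits // 8
bytes_fp16 = num_params * torch.finfo(torch.float16).bits // 8
print(f"{bytes_fp32 / 2**30:.1f} GiB of weights in fp32")  # ~26.1 GiB
print(f"{bytes_fp16 / 2**30:.1f} GiB of weights in fp16")  # ~13.0 GiB
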
@@ -55,57 +64,24 @@ class InferenceConfig:
     revision: Optional[str] = None
 
     def __post_init__(self):
-        self._init_batch_size()
         self._verify_config()
-        self._get_dtype()
-
-    def _init_batch_size(self):
-        """
-        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
-        We take a simple method to determine it by GPU memory size, user can still set it manually.
-        """
-        if self.max_batch_size is not None:
-            # already set by user
-            return
-
-        device = torch.device("cuda")
-        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
-        self.max_batch_size = 8
-
-        if 40 < total_mem <= 60:
-            self.max_batch_size = 16
-        elif 60 < total_mem <= 80:
-            self.max_batch_size = 32
-
-        logger.info(
-            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
-        )
 
     def _verify_config(self) -> None:
         """
         Verify the input config
         """
+        # check dtype
+        if isinstance(self.dtype, str):
+            # convert string dtype to torch dtype
+            assert (
+                self.dtype in _DTYPE_MAPPING
+            ), f"Expected the dtype string argument to be in {list(_DTYPE_MAPPING.keys())} but found an unknown dtype: {self.dtype}"
+            self.dtype = _DTYPE_MAPPING[self.dtype]
+        assert (
+            self.dtype in _ALLOWED_DTYPES
+        ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
+
+        # check distributed
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
-        assert self.dtype in [
-            "fp16",
-            "fp32",
-            "bf16",
-            torch.float32,
-            torch.float16,
-            torch.bfloat16,
-        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
         assert self.quant_mode in [
            "smoothquant",
            "gptq",
            None,
        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-
-    def _get_dtype(self) -> None:
-        if self.dtype == "fp32" or self.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif self.dtype == "fp16" or self.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
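
With _get_dtype and _init_batch_size gone, __post_init__ only runs _verify_config, which both normalizes a string dtype and validates it against the parallel settings. A hedged usage sketch follows: the import path is an assumption (the diff view does not show the file name), and a one-process gloo group is created only because the config checks tp_size * pp_size against torch.distributed's world size.

import torch
import torch.distributed as dist

# Assumed import path; the changed file's name is not shown in this diff view.
from colossalai.inference.config import InferenceConfig

# _verify_config() calls dist.get_world_size(), so even a single-process run needs a process group.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

cfg = InferenceConfig(dtype="bf16", tp_size=1, pp_size=1)
print(cfg.dtype)  # torch.bfloat16, the string alias was replaced during __post_init__

dist.destroy_process_group()
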

View File

@@ -21,11 +21,15 @@ def setup_seed(seed):
 
 
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
         )
-    ).cuda()
+        .cuda()
+        .half()
+    )
     model = model.eval()
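
Because the engine side now defaults to fp16, the Transformers baseline built here is also moved to CUDA and cast to half precision, so both code paths generate with the same numerics. A self-contained sketch of the same construction (requires transformers and a CUDA device; the tiny model hyperparameters are copied from the test above):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny random-weight Llama used purely as a test fixture, placed on GPU and cast to fp16.
model = (
    LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
        )
    )
    .cuda()
    .half()
    .eval()
)
print(next(model.parameters()).dtype)  # torch.float16
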
@@ -70,7 +74,7 @@ def run_dist(rank, world_size, port):
     transformer_outputs = check_inference_engine(False)
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
-        assert s1 == s2
+        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
 
 
 @pytest.mark.dist
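
Adding a message to the assert means that when the two generations diverge, pytest shows both strings instead of a bare AssertionError. A generic sketch of the pattern with made-up inputs:

def compare_outputs(cai_outputs, transformer_outputs):
    # Fail with both texts visible if any pair of generations differs.
    for s1, s2 in zip(cai_outputs, transformer_outputs):
        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"


compare_outputs(["hello world"], ["hello world"])  # identical outputs pass silently
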