mirror of https://github.com/hpcaitech/ColossalAI
[inference] simplified config verification (#5346)
* [inference] simplified config verification
* polish
* polish
parent df0aa49585
commit f8e456d202
@@ -14,23 +14,32 @@ GibiByte = 1024**3
 logger = logging.Logger(__name__)
 
+_DTYPE_MAPPING = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+
 
 @dataclass
 class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
-        max_batch_size (int): Maximum batch size.
-        max_output_len (int): Maximum output length.
-        max_input_len (int): Maximum input length.
-        block_size (int): The number of blocks in a logical block.
+        max_batch_size (int): Maximum batch size, defaults to 8.
+        max_output_len (int): Maximum output length, defaults to 256.
+        max_input_len (int): Maximum input length, defaults to 256.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size.
-        pp_size (int): Pipeline parallel size.
-        beam_width (int): The maximum beam width used to initialize KV Cache.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
             when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
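The two module-level tables added above replace the old per-string dtype handling with a single lookup. Below is a minimal standalone sketch of that check-and-convert pattern; normalize_dtype is a hypothetical helper written only for illustration (it is not part of this commit), and torch is the only dependency.

from typing import Union

import torch

# Same tables as in the hunk above: string aliases on the left, torch dtypes on the right.
_DTYPE_MAPPING = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


def normalize_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Hypothetical helper: map a string alias to a torch dtype, then validate it."""
    if isinstance(dtype, str):
        assert dtype in _DTYPE_MAPPING, f"unknown dtype string: {dtype}"
        dtype = _DTYPE_MAPPING[dtype]
    assert dtype in _ALLOWED_DTYPES, f"unsupported dtype: {dtype}"
    return dtype


# Both spellings end up as the same torch dtype.
assert normalize_dtype("bf16") is torch.bfloat16
assert normalize_dtype(torch.float16) is torch.float16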
@@ -43,7 +52,7 @@ class InferenceConfig:
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
-    dtype: Union[str, torch.dtype] = torch.float32
+    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
     tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
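Switching the default dtype from torch.float32 to torch.float16 halves the memory used for weights and activations, the usual trade-off for inference. A rough, self-contained illustration of that factor of two (the 1024x1024 tensors are arbitrary, not the real model):

import torch

# fp16 stores 2 bytes per element, fp32 stores 4, so the same tensor takes half the memory.
w_fp32 = torch.empty(1024, 1024, dtype=torch.float32)
w_fp16 = torch.empty(1024, 1024, dtype=torch.float16)
print(w_fp32.element_size() * w_fp32.nelement())  # 4194304 bytes
print(w_fp16.element_size() * w_fp16.nelement())  # 2097152 bytes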
@@ -55,57 +64,24 @@ class InferenceConfig:
     revision: Optional[str] = None
 
     def __post_init__(self):
-        self._init_batch_size()
         self._verify_config()
-        self._get_dtype()
-
-    def _init_batch_size(self):
-        """
-        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
-        We take a simple method to determine it by GPU memory size, user can still set it manually.
-        """
-        if self.max_batch_size is not None:
-            # already set by user
-            return
-
-        device = torch.device("cuda")
-        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
-        self.max_batch_size = 8
-
-        if 40 < total_mem <= 60:
-            self.max_batch_size = 16
-        elif 60 < total_mem <= 80:
-            self.max_batch_size = 32
-        logger.info(
-            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
-        )
 
     def _verify_config(self) -> None:
         """
         Verify the input config
         """
+        # check dtype
+        if isinstance(self.dtype, str):
+            # convert string dtype to torch dtype
+            assert (
+                self.dtype in _DTYPE_MAPPING
+            ), f"Expected the dtype string argument to be in {list(_DTYPE_MAPPING.keys())} but found an unknown dtype: {self.dtype}"
+            self.dtype = _DTYPE_MAPPING[self.dtype]
+        assert (
+            self.dtype in _ALLOWED_DTYPES
+        ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
+
+        # check distributed
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
-
-        assert self.dtype in [
-            "fp16",
-            "fp32",
-            "bf16",
-            torch.float32,
-            torch.float16,
-            torch.bfloat16,
-        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
-        assert self.quant_mode in [
-            "smoothquant",
-            "gptq",
-            None,
-        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-
-    def _get_dtype(self) -> None:
-        if self.dtype == "fp32" or self.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif self.dtype == "fp16" or self.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
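Taken together, _verify_config now converts a string dtype in place and validates it in one pass, so the separate _get_dtype step and the long membership list are gone. Below is a stripped-down, runnable sketch of the new flow; _ConfigSketch is a hypothetical stand-in for InferenceConfig that keeps only the dtype field and omits the tp_size * pp_size == world_size assertion, since that check needs an initialized process group.

from dataclasses import dataclass
from typing import Union

import torch

_DTYPE_MAPPING = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


@dataclass
class _ConfigSketch:
    """Hypothetical stand-in for InferenceConfig, reduced to the dtype check."""

    dtype: Union[str, torch.dtype] = torch.float16

    def __post_init__(self):
        # Same order as the new _verify_config: map a string alias first, then
        # validate the resulting torch dtype against the allowed set.
        if isinstance(self.dtype, str):
            assert self.dtype in _DTYPE_MAPPING, f"unknown dtype string: {self.dtype}"
            self.dtype = _DTYPE_MAPPING[self.dtype]
        assert self.dtype in _ALLOWED_DTYPES, f"unsupported dtype: {self.dtype}"


assert _ConfigSketch("fp32").dtype is torch.float32  # string alias converted in place
assert _ConfigSketch().dtype is torch.float16        # matches the new fp16 default
try:
    _ConfigSketch("int8")
except AssertionError as err:
    print(err)  # unknown dtype string: int8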
@@ -21,11 +21,15 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
+    model = (
+        LlamaForCausalLM(
         LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
-    ).cuda()
+        )
+        .cuda()
+        .half()
+    )
 
     model = model.eval()
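The test now builds the reference LlamaForCausalLM in half precision, presumably to line up with the config's new fp16 default so both engines generate at the same precision. For reference, .half() simply casts every parameter and buffer to torch.float16, as this small illustrative snippet shows:

import torch
from torch import nn

# .half() casts the module's parameters and buffers to torch.float16 and returns the module.
layer = nn.Linear(8, 8)
assert layer.weight.dtype is torch.float32
layer = layer.half()
assert layer.weight.dtype is torch.float16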
@@ -70,7 +74,7 @@ def run_dist(rank, world_size, port):
     transformer_outputs = check_inference_engine(False)
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
-        assert s1 == s2
+        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
 
 
 @pytest.mark.dist
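Attaching an f-string message to the assert means a mismatch now reports both generations instead of a bare AssertionError. A tiny demonstration with dummy strings (not real model output):

s1, s2 = "hello world", "hello there"
try:
    assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
except AssertionError as err:
    print(err)
    # ColossalAI Output: hello world
    # Transformers Output: hello there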