mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Feat] Add quant kvcache interface (#5700)
* add quant kvcache interface
* delete unused output
* complete args comments
parent 492520dbdb
commit bfad39357b
@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:
 
     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None
 
     # generation configs
    prompt_template: Optional[str] = None
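As a quick illustration of how the new field is meant to be used (a minimal sketch; the import path and the exact set of constructor arguments shown are assumptions based on this repo's layout, not part of the diff):

# Hedged usage sketch: opting into the quantized KV cache via the new config field.
import torch
from colossalai.inference.config import InferenceConfig  # assumed import path

config = InferenceConfig(
    dtype=torch.float16,      # weights/activations stay fp16
    kv_cache_dtype="fp8",     # new field introduced by this commit
    use_cuda_kernel=True,     # fp8 cache requires the CUDA kernel path (see the check below)
    max_input_len=256,
    max_output_len=256,
)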
@@ -177,6 +179,12 @@ class InferenceConfig:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), "FP8 kv_cache is only supported when use_cuda_kernel is enabled"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
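The mapping to torch.uint8 is worth spelling out: the quantized cache is held as raw bytes, presumably reinterpreted as fp8 by the CUDA kernels, so the config resolves the user-facing "fp8" string to a byte-sized storage dtype. A minimal standalone sketch of that resolution logic (the helper name is hypothetical, mirroring the __post_init__ check above):

from typing import Optional
import torch

def resolve_kv_cache_dtype(kv_cache_dtype: Optional[str], use_cuda_kernel: bool) -> Optional[torch.dtype]:
    """Hypothetical helper mirroring the config check in this commit."""
    if not kv_cache_dtype:
        return None  # KVCacheManager will fall back to the model dtype
    assert use_cuda_kernel and kv_cache_dtype == "fp8", (
        "FP8 kv_cache is only supported when use_cuda_kernel is enabled"
    )
    return torch.uint8  # one byte per element, reinterpreted as fp8 by the kernels

print(resolve_kv_cache_dtype("fp8", use_cuda_kernel=True))   # torch.uint8
print(resolve_kv_cache_dtype(None, use_cuda_kernel=False))   # None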
@@ -53,6 +53,12 @@ class KVCacheManager:
         self.tp_size = config.tp_size
         # Model settings
         self.dtype = config.dtype
+
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
+
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = model_config.num_hidden_layers
         self.head_num = model_config.num_attention_heads
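For a sense of what the fallback means in practice (illustrative numbers only, not taken from the repo): the element size of the chosen cache dtype determines the allocation footprint, so a uint8-backed fp8 cache takes half the memory of an fp16 cache of the same shape.

import torch

def kv_block_bytes(cache_dtype: torch.dtype, num_heads: int = 32,
                   head_dim: int = 128, block_size: int = 16) -> int:
    # K and V each store num_heads * head_dim values per token, block_size tokens per block.
    elem_size = torch.tensor([], dtype=cache_dtype).element_size()
    return 2 * num_heads * head_dim * block_size * elem_size

print(kv_block_bytes(torch.float16))  # 262144 bytes per layer-block
print(kv_block_bytes(torch.uint8))    # 131072 bytes, i.e. half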
@@ -488,6 +494,6 @@ class KVCacheManager:
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
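Finally, a self-contained sketch of the allocation pattern changed above (the shapes, layer count, and device choice are made-up stand-ins for the manager's real kalloc_shape/valloc_shape and attributes):

from typing import List
import torch

num_layers = 2
kalloc_shape = (8, 32, 16, 128)  # stand-in for the manager's real K block shape
valloc_shape = (8, 32, 16, 128)  # stand-in for the manager's real V block shape
kv_cache_dtype = torch.uint8     # what kv_cache_dtype="fp8" resolves to in this commit
device = "cuda" if torch.cuda.is_available() else "cpu"

k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(num_layers):
    # Buffers are zero-initialized in the storage dtype of the (possibly quantized) cache.
    k_cache.append(torch.zeros(kalloc_shape, dtype=kv_cache_dtype, device=device))
    v_cache.append(torch.zeros(valloc_shape, dtype=kv_cache_dtype, device=device))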