From bfad39357b0fe31ecf6f7639e2c4056165078a3f Mon Sep 17 00:00:00 2001
From: 傅剑寒
Date: Thu, 9 May 2024 18:03:24 +0800
Subject: [PATCH] [Inference/Feat] Add quant kvcache interface (#5700)

* add quant kvcache interface

* delete unused output

* complete args comments
---
 colossalai/inference/config.py                   |  8 ++++++++
 colossalai/inference/kv_cache/kvcache_manager.py | 10 ++++++++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index ee1cd7cfb..aae2024e0 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:
 
     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None
 
     # generation configs
     prompt_template: Optional[str] = None
@@ -177,6 +179,12 @@ class InferenceConfig:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), f"FP8 kv_cache is only supported with use_cuda_kernel open now"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 302f379f9..1b9532a3c 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -53,6 +53,12 @@ class KVCacheManager:
         self.tp_size = config.tp_size
         # Model settings
         self.dtype = config.dtype
+
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
+
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = model_config.num_hidden_layers
         self.head_num = model_config.num_attention_heads
@@ -488,6 +494,6 @@ class KVCacheManager:
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
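
For reference, a minimal usage sketch of the interface this patch adds (not part of the patch itself). It assumes the remaining InferenceConfig fields have workable defaults and that use_cuda_kernel is an existing boolean field of the config, since __post_init__ only accepts kv_cache_dtype="fp8" together with the CUDA kernels:

    import torch
    from colossalai.inference.config import InferenceConfig

    # Request an FP8-quantized KV cache; the patch only allows this when
    # use_cuda_kernel is enabled, otherwise __post_init__ raises an AssertionError.
    config = InferenceConfig(
        dtype=torch.float16,
        use_cuda_kernel=True,
        kv_cache_dtype="fp8",
    )

    # __post_init__ rewrites the "fp8" string to torch.uint8; KVCacheManager then
    # allocates its k/v cache blocks with this dtype instead of config.dtype.
    print(config.kv_cache_dtype)  # torch.uint8

Note that the quantization itself happens in the CUDA attention kernels; this patch only wires the storage dtype through InferenceConfig and KVCacheManager, so the k/v blocks are allocated as uint8 buffers.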