[hotfix] Fix KV Heads Number Assignment in KVCacheManager (#5695)

- Fix KV heads number assignment in KVCacheManager, as well as the method of accessing model config attributes
Yuanheng Zhao 2024-05-07 23:13:14 +08:00 committed by GitHub
parent 1ace1065e6
commit f9afe0addd
2 changed files with 11 additions and 20 deletions
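
The core of the fix: the old `__init__` looked up `num_key_value_heads` on `config` (the ColossalAI `InferenceConfig`) instead of `model_config` (the Hugging Face `PretrainedConfig`), so for grouped-query-attention models the lookup fell through to the fallback and `kv_head_num` was set to the full attention head count, over-allocating the KV cache. A minimal sketch of the failure mode, using `SimpleNamespace` stand-ins rather than the real config classes (the `attribute_map` branch of the old code is omitted for brevity):

```python
from types import SimpleNamespace

# Hypothetical stand-ins: `config` mimics an InferenceConfig (no model attributes),
# `model_config` mimics a GQA model config (32 query heads, 8 KV heads).
config = SimpleNamespace(dtype="float16", max_batch_size=8)
model_config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8, hidden_size=4096)

head_num = model_config.num_attention_heads

# Old lookup: checks the wrong object, so the GQA branch never triggers.
kv_head_num = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else head_num
print(kv_head_num)  # 32 -> 4x more KV heads (and KV cache) than needed

# Fixed lookup: reads the model config directly.
kv_head_num = model_config.num_key_value_heads if hasattr(model_config, "num_key_value_heads") else head_num
print(kv_head_num)  # 8
```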

View File

@@ -15,14 +15,6 @@ __all__ = ["KVCacheManager"]
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
-    if hasattr(config, attr_name):
-        return getattr(config, attr_name)
-    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
-        return getattr(config, config.attribute_map[attr_name])
-    raise AttributeError(f"{attr_name} is not found in config")
-
-
 class KVCacheManager:
     """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).
@@ -53,7 +45,7 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
@@ -62,14 +54,11 @@ class KVCacheManager:
         # Model settings
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
-        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
         else:
             self.kv_head_num = self.head_num
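
With the corrected `kv_head_num`, the cache allocated per token shrinks by roughly a factor of `head_num / kv_head_num` for GQA models, since the other terms in the allocation are unchanged. A back-of-the-envelope sketch of what the old fallback cost (not ColossalAI code; Llama-2-70B-like shape, fp16):

```python
# Rough per-token KV cache footprint: K and V tensors across all layers.
num_layers, head_num, kv_head_num = 80, 64, 8   # Llama-2-70B-like shape
head_size = 8192 // head_num                    # 128
elem_size = 2                                   # bytes per fp16 element

def kv_bytes_per_token(kv_heads: int) -> int:
    return 2 * num_layers * kv_heads * head_size * elem_size

print(kv_bytes_per_token(head_num))     # 2,621,440 bytes (~2.5 MiB) -- old fallback to head_num
print(kv_bytes_per_token(kv_head_num))  # 327,680 bytes (~320 KiB) -- with the fix
```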

View File

@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,