mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Fix KV Heads Number Assignment in KVCacheManager (#5695)
- Fix key value heads number assignment in KVCacheManager, as well as the method of accessing model config attributes (pull/5697/head)

parent: 1ace1065e6
commit: f9afe0addd
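At the heart of the fix is where kv_head_num is read from. Before the patch, the hasattr check was performed on the inference config object rather than on the model's PretrainedConfig, so (presumably) the lookup never found num_key_value_heads and grouped-query-attention models silently fell back to num_attention_heads. A minimal sketch of the corrected lookup, using a stand-in config class; DummyModelConfig, resolve_kv_head_num, and the field values are illustrative only, not part of the patch:

from dataclasses import dataclass


@dataclass
class DummyModelConfig:
    # Illustrative stand-in for a Hugging Face PretrainedConfig of a Llama-style GQA model.
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    hidden_size: int = 4096
    num_key_value_heads: int = 8  # grouped-query attention: fewer KV heads than query heads


def resolve_kv_head_num(model_config) -> int:
    # Mirrors the patched logic: read num_key_value_heads from the *model* config
    # and fall back to num_attention_heads when the field is absent (plain MHA models).
    if hasattr(model_config, "num_key_value_heads"):
        return model_config.num_key_value_heads
    return model_config.num_attention_heads


print(resolve_kv_head_num(DummyModelConfig()))  # 8, instead of a wrong fallback to 32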
@@ -15,14 +15,6 @@ __all__ = ["KVCacheManager"]
 GIGABYTE = 1024**3


-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
-    if hasattr(config, attr_name):
-        return getattr(config, attr_name)
-    elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
-        return getattr(config, config.attribute_map[attr_name])
-    raise AttributeError(f"{attr_name} is not found in config")
-
-
 class KVCacheManager:
     """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).

@@ -53,7 +45,7 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """

-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()

@@ -62,14 +54,11 @@
         # Model settings
         self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
-        self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
-        self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
-
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
         else:
             self.kv_head_num = self.head_num

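The assignment matters because the KV cache is sized per KV head, not per attention head. A rough sketch of the impact, assuming the usual layout of one key and one value tensor per layer; the numbers are illustrative and not taken from the patch:

import torch

# Illustrative Llama-style GQA settings (not from the patch)
num_layers, head_num, kv_head_num, head_size = 32, 32, 8, 128
elem_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes

# Per-token KV cache footprint: key and value tensors across all layers
gqa_bytes_per_token = 2 * num_layers * kv_head_num * head_size * elem_size
mha_bytes_per_token = 2 * num_layers * head_num * head_size * elem_size

print(gqa_bytes_per_token)  # 131072 bytes (128 KiB)
print(mha_bytes_per_token)  # 524288 bytes (512 KiB) -- what a wrong fallback would budget for

With the old lookup, a GQA model like this would have its cache planned as if it needed four times the KV memory per token.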
@@ -141,9 +141,11 @@ class LlamaPolicy(Policy):
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
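The LlamaPolicy change makes the tensor-parallel check conditional on the config actually defining num_key_value_heads, and additionally requires at least one KV head per tensor-parallel rank. A small standalone sketch of the same condition; the function name and sample values are illustrative:

def kv_heads_compatible(num_key_value_heads: int, tensor_parallel_size: int) -> bool:
    # Same condition as the updated assert: enough KV heads for every TP rank,
    # and an even split across ranks.
    return (
        num_key_value_heads >= tensor_parallel_size
        and num_key_value_heads % tensor_parallel_size == 0
    )


print(kv_heads_compatible(8, 4))   # True: 2 KV heads per rank
print(kv_heads_compatible(8, 3))   # False: 8 is not divisible by 3
print(kv_heads_compatible(8, 16))  # False: fewer KV heads than ranks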