[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)

* fix bug

* fix

* fix multiquery

* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
pull/5024/head
Jianghai 2023-11-07 15:01:50 +08:00 committed by GitHub
parent c36e782d80
commit ef4c14a5e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 21 additions and 19 deletions

View File

@ -77,14 +77,15 @@ class TPInferEngine:
)
self.layer_num = num_hidden_layers
self.multi_query_group_num = 0
self.multi_query_group_num = model.config.num_attention_heads
# default to attention_heads
self.multi_query_attention = model.config.multi_query_attention
if hasattr(model.config, "multi_query_group_num"):
self.multi_query_group_num = model.config.multi_query_group_num
if hasattr(model.config, "num_key_value_heads"):
self.multi_query_group_num = model.config.num_key_value_heads
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
@ -107,7 +108,7 @@ class TPInferEngine:
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
if self.multi_query_group_num:
if self.multi_query_attention:
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0

View File

@ -395,9 +395,9 @@ class ChatGLM2InferenceForwards:
assert use_cache is True, "use_cache should be set to True using this chatglm attention"
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
batch_size = hidden_states.shape[0]
hidden_size = hidden_states.shape[-1]
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
@ -437,7 +437,6 @@ class ChatGLM2InferenceForwards:
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
cos, sin = infer_state.position_cos, infer_state.position_sin
chatglm2_rotary_emb_fwd(
@ -466,10 +465,10 @@ class ChatGLM2InferenceForwards:
value_layer = value_layer.reshape(
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
)
if infer_state.is_context_stage:
# first token generation:
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_layer,
@ -477,8 +476,7 @@ class ChatGLM2InferenceForwards:
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
# NOTE: no bug in context attn fwd (del it )
lightllm_llama2_context_attention_fwd(
@ -514,7 +512,7 @@ class ChatGLM2InferenceForwards:
)
# second token and follows
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
: infer_state.decode_mem_end, :, :
]
@ -542,6 +540,6 @@ class ChatGLM2InferenceForwards:
# =================
# Output:[b,sq, h]
# =================
output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
return output, kv_cache

View File

@ -48,7 +48,10 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)
if self.shard_config.enable_tensor_parallelism:
policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
)
# for rmsnorm and others, we need to check the shape
return policy

View File

@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule):
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."

View File

@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module):
)
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = nn.Linear(
self.projection_size,

View File

@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy):
),
],
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[

View File

@ -180,7 +180,6 @@ class ModelSharder(object):
assert target_module is not None, "target_module should not be None"
native_sub_module = getattr_(org_layer, suffix, ignore=True)
# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue

View File

@ -13,13 +13,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
try:
import lightllm
import lightllm # noqa
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 1
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100
@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port):
run_chatglm2_test()
@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()