mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug
* fix
* fix multiquery
* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

Branch: pull/5024/head
parent: c36e782d80
commit: ef4c14a5e2
@@ -77,14 +77,15 @@ class TPInferEngine:
             )
             self.layer_num = num_hidden_layers

-        self.multi_query_group_num = 0
+        self.multi_query_group_num = model.config.num_attention_heads
+        # default to attention_heads
+        self.multi_query_attention = model.config.multi_query_attention

         if hasattr(model.config, "multi_query_group_num"):
             self.multi_query_group_num = model.config.multi_query_group_num

         if hasattr(model.config, "num_key_value_heads"):
             self.multi_query_group_num = model.config.num_key_value_heads

         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
@@ -107,7 +108,7 @@ class TPInferEngine:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads

-        if self.multi_query_group_num:
+        if self.multi_query_attention:
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
                 self.multi_query_group_num % self.tp_size == 0
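
Note on the two TPInferEngine hunks above: the engine now defaults the key/value group count to `num_attention_heads`, records whether the model uses multi-query attention, and overrides the group count from `multi_query_group_num` (ChatGLM2) or `num_key_value_heads` (Llama-2-style GQA) when present; sharding then requires that count to divide evenly by the tensor-parallel size. A minimal standalone sketch of that resolution logic follows; the helper name is hypothetical, the divide-by-tp step after the assert is assumed (the hunk is cut off there), and `SimpleNamespace` merely stands in for a transformers config.

from types import SimpleNamespace


def resolve_kv_groups(config, tp_size: int) -> int:
    """Illustrative only: mirror how the engine picks the KV group count before sharding."""
    kv_groups = config.num_attention_heads  # default to attention heads (plain MHA)
    if hasattr(config, "multi_query_group_num"):  # ChatGLM2-style multi-query attention
        kv_groups = config.multi_query_group_num
    if hasattr(config, "num_key_value_heads"):  # Llama-2-style grouped-query attention
        kv_groups = config.num_key_value_heads
    # Same divisibility requirement the engine asserts before sharding.
    assert kv_groups % tp_size == 0, f"Cannot shard {kv_groups} kv groups with tp size {tp_size}"
    return kv_groups // tp_size  # assumed per-rank count; not shown in the truncated hunk


# ChatGLM2-6B exposes 2 multi-query groups, so tp_size=2 leaves 1 group per rank.
chatglm2_cfg = SimpleNamespace(num_attention_heads=32, multi_query_group_num=2)
print(resolve_kv_groups(chatglm2_cfg, tp_size=2))  # -> 1
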
@@ -395,9 +395,9 @@ class ChatGLM2InferenceForwards:
         assert use_cache is True, "use_cache should be set to True using this chatglm attention"
         # hidden_states: original :[sq, b, h] --> this [b, sq, h]
         batch_size = hidden_states.shape[0]
+        hidden_size = hidden_states.shape[-1]
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)

         if self.multi_query_attention:
             (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                 [
@@ -437,7 +437,6 @@ class ChatGLM2InferenceForwards:
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

         cos, sin = infer_state.position_cos, infer_state.position_sin

         chatglm2_rotary_emb_fwd(
@@ -466,10 +465,10 @@ class ChatGLM2InferenceForwards:
             value_layer = value_layer.reshape(
                 -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
             )

         if infer_state.is_context_stage:
             # first token generation:
             # copy key and value calculated in current step to memory manager

             copy_kv_to_mem_cache(
                 infer_state.decode_layer_id,
                 key_layer,
@@ -477,8 +476,7 @@ class ChatGLM2InferenceForwards:
                 infer_state.context_mem_index,
                 infer_state.cache_manager,
             )
-
-            attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))

            # NOTE: no bug in context attn fwd (del it )
             lightllm_llama2_context_attention_fwd(
@@ -514,7 +512,7 @@ class ChatGLM2InferenceForwards:
             )

             # second token and follows
-            attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
                 : infer_state.decode_mem_end, :, :
             ]
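
Note on the two `attn_output` hunks above: both the context stage and the decode stage now call `.contiguous()` before `.view()`. After the fused QKV projection is split along its last dimension, `query_layer` is a strided view, and `view()` cannot always re-flatten such a tensor. A small self-contained reproduction with toy shapes (independent of the model code):

import torch

# Slicing a fused QKV output along the last dim yields non-contiguous tensors.
fused = torch.randn(4, 3, 6)              # (batch, seq, fused q/k/v), toy sizes
q, k, v = fused.split([2, 2, 2], dim=-1)  # strided slices of the last dimension
print(q.is_contiguous())                  # False

try:
    q.view(-1, 6)                         # merging (seq, head_dim) fails on the strided input
except RuntimeError as err:
    print("view() failed:", type(err).__name__)

ok = q.contiguous().view(-1, 6)           # copy into contiguous memory first, then view
print(ok.shape)                           # torch.Size([4, 6])
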
@@ -542,6 +540,6 @@ class ChatGLM2InferenceForwards:
         # =================
         # Output:[b,sq, h]
         # =================
-        output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
+        output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)

         return output, kv_cache
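
Note on the hunks at -395 and -542: the forward pass now records `hidden_size = hidden_states.shape[-1]` and uses it for the final reshape instead of `self.projection_size`. The toy shape check below illustrates the presumed reason: under tensor parallelism each rank's attention output is only a `projection_size // tp_size` slice, while the row-parallel `dense` still emits the full hidden size, so reshaping with a per-partition projection size would silently produce the wrong sequence length. The plain `nn.Linear` here only stands in for the parallel layer and skips the all-reduce; whether `projection_size` is rescaled per rank by the policy is an assumption.

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 5, 64
tp_size = 2
projection_size = hidden_size               # ChatGLM2: kv_channels * num_attention_heads
proj_per_rank = projection_size // tp_size  # slice of the projection held by one TP rank

attn_output = torch.randn(batch_size * seq_len, proj_per_rank)
dense = nn.Linear(proj_per_rank, hidden_size)  # stand-in for the row-parallel dense layer

output = dense(attn_output).reshape(batch_size, -1, hidden_size)
print(output.shape)  # torch.Size([2, 5, 64])

# Reshaping with the per-rank projection size gives (2, 10, 32): no error, just the wrong shape.
wrong = dense(attn_output).reshape(batch_size, -1, proj_per_rank)
print(wrong.shape)  # torch.Size([2, 10, 32])
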
@@ -48,7 +48,10 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=SelfAttention
         )
+        if self.shard_config.enable_tensor_parallelism:
+            policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
+                self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+            )
         # for rmsnorm and others, we need to check the shape
         return policy
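
Note on the ChatGLM2InferPolicy hunk above: when tensor parallelism is enabled, the policy rewrites `num_multi_query_groups_per_partition` to the config value divided by the tensor-parallel size. A quick sanity check of that arithmetic; the value mirrors ChatGLM2-6B (`multi_query_group_num = 2`), everything else is illustrative.

# Per-partition KV-group arithmetic set up by the policy attribute replacement.
multi_query_group_num = 2

for tensor_parallel_size in (1, 2, 4):
    if multi_query_group_num % tensor_parallel_size != 0:
        print(f"tp={tensor_parallel_size}: unsupported, {multi_query_group_num} groups do not split evenly")
        continue
    per_partition = multi_query_group_num // tensor_parallel_size
    print(f"tp={tensor_parallel_size}: num_multi_query_groups_per_partition = {per_partition}")
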
@@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule):
         out_features = module.out_features
         bias = module.bias is not None
         device = module.weight.device

         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
             assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
@@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module):
         )

         self.core_attention = CoreAttention(config, self.layer_number)

         # Output.
         self.dense = nn.Linear(
             self.projection_size,
@@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
         )

         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=[
@@ -180,7 +180,6 @@ class ModelSharder(object):
             assert target_module is not None, "target_module should not be None"

             native_sub_module = getattr_(org_layer, suffix, ignore=True)

             # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
             if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
                 continue
@@ -13,13 +13,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 try:
-    import lightllm
+    import lightllm  # noqa

     HAS_LIGHTLLM_KERNEL = True
 except:
     HAS_LIGHTLLM_KERNEL = False

 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 1
+TPSIZE = 2
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
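
Note on the test-file hunk above: `lightllm` stays optional behind a try/except guard; the change only adds `# noqa` to silence the unused-import lint and bumps `TPSIZE` to 2 so the test actually exercises tensor parallelism. A minimal sketch of the same guard pattern follows; narrowing the bare `except` to `ImportError` is a deviation from the code above, not something this commit does.

# Optional-dependency guard; the ImportError narrowing is an assumption, not from the diff.
try:
    import lightllm  # noqa: F401  (imported only to probe availability)

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    HAS_LIGHTLLM_KERNEL = False

print("lightllm kernels available:", HAS_LIGHTLLM_KERNEL)
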
@@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port):
     run_chatglm2_test()


-@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
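
Note on the final hunk: the one-line `skipif` is only reflowed across multiple lines; its condition still depends on `CUDA_SUPPORT`, which is defined elsewhere in the test file and not shown in this diff. A hypothetical sketch of such a guard, inferred solely from the skip reason string (the real definition may differ):

import pytest
import torch

# Hypothetical reconstruction of the CUDA_SUPPORT flag referenced above; the actual
# definition is outside this diff, so the >= 11.5 check is an assumption.
_cuda = torch.version.cuda or "0.0"
CUDA_SUPPORT = torch.cuda.is_available() and tuple(int(p) for p in _cuda.split(".")[:2]) >= (11, 5)


@pytest.mark.skipif(
    not CUDA_SUPPORT,
    reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
def test_requires_recent_cuda():
    assert torch.cuda.is_available()
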