diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 3be213274..2eadbcab1 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -77,14 +77,15 @@ class TPInferEngine: ) self.layer_num = num_hidden_layers - self.multi_query_group_num = 0 + self.multi_query_group_num = model.config.num_attention_heads + # default to attention_heads + self.multi_query_attention = model.config.multi_query_attention if hasattr(model.config, "multi_query_group_num"): self.multi_query_group_num = model.config.multi_query_group_num if hasattr(model.config, "num_key_value_heads"): self.multi_query_group_num = model.config.num_key_value_heads - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None @@ -107,7 +108,7 @@ class TPInferEngine: assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads - if self.multi_query_group_num: + if self.multi_query_attention: # NOTE the logic of MQA tensor parallelism should be specified. assert ( self.multi_query_group_num % self.tp_size == 0 diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 69a92c4fe..b8fe8eb54 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -395,9 +395,9 @@ class ChatGLM2InferenceForwards: assert use_cache is True, "use_cache should be set to True using this chatglm attention" # hidden_states: original :[sq, b, h] --> this [b, sq, h] batch_size = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -437,7 +437,6 @@ class ChatGLM2InferenceForwards: mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - cos, sin = infer_state.position_cos, infer_state.position_sin chatglm2_rotary_emb_fwd( @@ -466,10 +465,10 @@ class ChatGLM2InferenceForwards: value_layer = value_layer.reshape( -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head ) + if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( infer_state.decode_layer_id, key_layer, @@ -477,8 +476,7 @@ class ChatGLM2InferenceForwards: infer_state.context_mem_index, infer_state.cache_manager, ) - - attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) lightllm_llama2_context_attention_fwd( @@ -514,7 +512,7 @@ class ChatGLM2InferenceForwards: ) # second token and follows - attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ : infer_state.decode_mem_end, :, : ] @@ -542,6 +540,6 @@ class ChatGLM2InferenceForwards: # ================= # Output:[b,sq, h] # ================= + output = self.dense(attn_output).reshape(batch_size, -1, hidden_size) - output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size) return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index 90f8b4fd2..60dc511f5 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -48,7 +48,10 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy): self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=SelfAttention ) - + if self.shard_config.enable_tensor_parallelism: + policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = ( + self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size + ) # for rmsnorm and others, we need to check the shape return policy diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index cf2003877..9e6386223 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule): out_features = module.out_features bias = module.bias is not None device = module.weight.device - # ensure only one process group is passed if isinstance(process_group, (list, tuple)): assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index fdd49ecfe..71aa2296e 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module): ) self.core_attention = CoreAttention(config, self.layer_number) - # Output. self.dense = nn.Linear( self.projection_size, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index ab18d80b7..d1ad9f914 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy): ), ], ) - # optimization configuration self.append_or_create_submodule_replacement( description=[ diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index e3c0aa93d..0586ada9e 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -180,7 +180,6 @@ class ModelSharder(object): assert target_module is not None, "target_module should not be None" native_sub_module = getattr_(org_layer, suffix, ignore=True) - # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): continue diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 09bb8a949..a2ec35dcd 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -13,13 +13,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm + import lightllm # noqa + HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False - + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 1 +TPSIZE = 2 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 @@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port): run_chatglm2_test() -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()