diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index d5ef37fee..e4c4a2d70 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -1,7 +1,6 @@
 from typing import Any, Callable, List, Optional, Union
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from transformers import BloomForCausalLM, LlamaForCausalLM
 from transformers.generation import GenerationConfig
@@ -74,9 +73,14 @@ class TPInferEngine:
             model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
         )
         self.layer_num = num_hidden_layers
-        self.multi_query_group_num = (
-            model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
-        )
+
+        self.multi_query_group_num = 0
+
+        if hasattr(model.config, "multi_query_group_num"):
+            self.multi_query_group_num = model.config.multi_query_group_num
+
+        if hasattr(model.config, "num_key_value_heads"):
+            self.multi_query_group_num = model.config.num_key_value_heads
 
         self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
@@ -97,6 +101,7 @@ class TPInferEngine:
         assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size  # update sharded number of heads
+
         if self.multi_query_group_num:
             # NOTE the logic of MQA tensor parallelism should be specified.
             assert (
@@ -116,13 +121,15 @@ class TPInferEngine:
 
     def _post_init_gptq_buffer(self, model: nn.Module) -> None:
         from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
+
         HAS_GPTQ_CUDA = False
         try:
             from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
             gptq_cuda = GPTQBuilder().load()
             HAS_GPTQ_CUDA = True
         except ImportError:
-            warnings.warn('CUDA gptq is not installed')
+            warnings.warn("CUDA gptq is not installed")
             HAS_GPTQ_CUDA = False
 
         for name, submodule in model.named_modules():
@@ -130,8 +137,9 @@ class TPInferEngine:
                 self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
 
                 if self.use_act_order:
-                    self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
-                                                   submodule.outfeatures)
+                    self.max_inner_outer_dim = max(
+                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+                    )
                 self.bits = submodule.bits
         if not (HAS_GPTQ_CUDA and self.bits == 4):
             return
@@ -141,15 +149,16 @@ class TPInferEngine:
         max_input_len = self.max_input_len
         # The temp_state buffer is required to reorder X in the act-order case.
         # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-        self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
-                                                  dtype=torch.float16,
-                                                  device=torch.cuda.current_device())
-        self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
-                                               dtype=torch.float16,
-                                               device=torch.cuda.current_device())
+        self.gptq_temp_state_buffer = torch.zeros(
+            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+        )
+        self.gptq_temp_dq_buffer = torch.zeros(
+            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+        )
 
-        gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
-                                  self.gptq_temp_dq_buffer)
+        gptq_cuda.prepare_buffers(
+            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+        )
         # Using the default from exllama repo here.
         matmul_recons_thd = 8
         matmul_fused_remap = False
diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py
index e476c3132..068b64b4f 100644
--- a/colossalai/inference/tensor_parallel/modeling/_utils.py
+++ b/colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     base = float(base)
 
     # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-    ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
 
     if ntk_alpha is not None:
         ntk_alpha = float(ntk_alpha)
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index a7661cee1..ac4ae72f3 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -5,7 +5,13 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import (
+    llama2_context_attn_fwd,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 
 from ._utils import copy_kv_to_mem_cache
@@ -138,6 +144,7 @@ class LlamaInferenceForwards:
             seq_len = infer_state.seq_len
             infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
             infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -261,8 +268,8 @@ class LlamaInferenceForwards:
         # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
 
         # NOTE might want to revise
         #   need some way to record the length of past key values cache
@@ -274,11 +281,11 @@ class LlamaInferenceForwards:
 
         # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
         rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
 
         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
 
         if infer_state.is_context_stage:
             # first token generation
@@ -294,15 +301,26 @@ class LlamaInferenceForwards:
 
             attn_output = torch.empty_like(query_states)
 
-            llama_context_attn_fwd(
-                query_states,
-                key_states,
-                value_states,
-                attn_output,
-                infer_state.start_loc,
-                infer_state.seq_len,
-                infer_state.cache_manager.past_key_values_length,
-            )
+            if self.num_key_value_groups == 1:
+                llama_context_attn_fwd(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
+            else:
+                llama2_context_attn_fwd(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_output,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
         else:
             if infer_state.decode_is_contiguous:
                 # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -330,17 +348,29 @@ class LlamaInferenceForwards:
 
             # (batch_size, seqlen, nheads, headdim)
             attn_output = torch.empty_like(query_states)
-            token_attention_fwd(
-                query_states,
-                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
-                attn_output,
-                infer_state.block_loc,
-                infer_state.start_loc,
-                infer_state.seq_len,
-                infer_state.cache_manager.past_key_values_length,
-            )
-
+            if self.num_key_value_groups == 1:
+                token_attention_fwd(
+                    query_states,
+                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+                    attn_output,
+                    infer_state.block_loc,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                )
+            else:
+                Llama2TokenAttentionForwards.token_attn(
+                    query_states,
+                    infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                    infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+                    attn_output,
+                    infer_state.block_loc,
+                    infer_state.start_loc,
+                    infer_state.seq_len,
+                    infer_state.cache_manager.past_key_values_length,
+                    infer_state.other_kv_index,
+                )
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 983069158..f065b2100 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -9,7 +9,7 @@ except ImportError:
 
 # There may exist import error even if we have triton installed.
 if HAS_TRITON:
-    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+    from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
@@ -20,6 +20,7 @@ if HAS_TRITON:
 
     __all__ = [
         "llama_context_attn_fwd",
+        "llama2_context_attn_fwd",
         "bloom_context_attn_fwd",
         "softmax",
         "layer_norm",
diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py
new file mode 100644
index 000000000..0eebed889
--- /dev/null
+++ b/tests/test_infer/test_llama2_infer.py
@@ -0,0 +1,69 @@
+import os
+
+import pytest
+import torch
+from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+import colossalai
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+TPSIZE = 2
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": TPSIZE,
+        }
+    ],
+)
+def run_llama_test(test_config):
+    llama_config = LlamaConfig(
+        num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
+    )
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()
+
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+
+    assert outputs is not None
+
+
+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_llama_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama():
+    spawn(check_llama, TPSIZE)
+
+
+if __name__ == "__main__":
+    test_llama()
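
Usage sketch (illustrative, not part of the patch): with the engine and modeling changes above, a Llama config whose num_key_value_heads is smaller than num_attention_heads yields num_key_value_groups > 1, so prefill goes through llama2_context_attn_fwd and decode through Llama2TokenAttentionForwards.token_attn. The sketch below mirrors tests/test_infer/test_llama2_infer.py but assumes a single-process launch with tensor parallelism disabled; the port number, toy config sizes, and batch/length values are illustrative choices rather than settings taken from this PR.

```python
# Minimal sketch of driving the GQA inference path added in this PR.
# Assumptions (not from the patch): one CUDA device, single-process launch on
# port 29500, and toy config sizes; the committed test instead spawns 2 TP ranks.
import torch
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig


def run_gqa_inference():
    # 8 KV heads against LlamaConfig's default 32 attention heads -> 4-way
    # grouped-query attention, which exercises the llama2_* Triton kernels exported above.
    config = LlamaConfig(num_hidden_layers=2, num_key_value_heads=8, hidden_size=1024, vocab_size=1200)
    model = LlamaForCausalLM(config).half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    engine = TPInferEngine(model, shard_config, 4, 12, 32)  # batch size, max input len, max output len

    input_tokens = {
        "input_ids": torch.randint(1, 1000, (4, 12), device="cuda"),
        "attention_mask": torch.ones((4, 12), device="cuda"),
    }
    return engine.generate(input_tokens, max_new_tokens=32, do_sample=False)


if __name__ == "__main__":
    # The distributed context must be initialized before building the engine,
    # as check_llama() does in the new test.
    colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=29500, backend="nccl")
    print(run_gqa_inference())
```

For a multi-GPU run, the same code would be launched once per rank (as the test does via spawn) with enable_tensor_parallelism=True, in which case the engine shards the 32 attention heads and the 8 KV heads across ranks, subject to the divisibility asserts added in prepare_shard_config.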