mirror of https://github.com/hpcaitech/ColossalAI
parent
39f2582e98
commit
77a9328304
|
@ -1,7 +1,6 @@
|
|||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from transformers import BloomForCausalLM, LlamaForCausalLM
|
||||
from transformers.generation import GenerationConfig
|
||||
|
@ -74,9 +73,14 @@ class TPInferEngine:
|
|||
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
|
||||
)
|
||||
self.layer_num = num_hidden_layers
|
||||
self.multi_query_group_num = (
|
||||
model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
|
||||
)
|
||||
|
||||
self.multi_query_group_num = 0
|
||||
|
||||
if hasattr(model.config, "multi_query_group_num"):
|
||||
self.multi_query_group_num = model.config.multi_query_group_num
|
||||
|
||||
if hasattr(model.config, "num_key_value_heads"):
|
||||
self.multi_query_group_num = model.config.num_key_value_heads
|
||||
|
||||
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
|
||||
self.cache_manager = None
|
||||
|
@ -97,6 +101,7 @@ class TPInferEngine:
|
|||
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
|
||||
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
|
||||
self.head_num //= self.tp_size # update sharded number of heads
|
||||
|
||||
if self.multi_query_group_num:
|
||||
# NOTE the logic of MQA tensor parallelism should be specified.
|
||||
assert (
|
||||
|
@ -116,13 +121,15 @@ class TPInferEngine:
|
|||
|
||||
def _post_init_gptq_buffer(self, model: nn.Module) -> None:
|
||||
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
|
||||
|
||||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn('CUDA gptq is not installed')
|
||||
warnings.warn("CUDA gptq is not installed")
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
for name, submodule in model.named_modules():
|
||||
|
@ -130,8 +137,9 @@ class TPInferEngine:
|
|||
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
|
||||
|
||||
if self.use_act_order:
|
||||
self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
|
||||
submodule.outfeatures)
|
||||
self.max_inner_outer_dim = max(
|
||||
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
|
||||
)
|
||||
self.bits = submodule.bits
|
||||
if not (HAS_GPTQ_CUDA and self.bits == 4):
|
||||
return
|
||||
|
@ -141,15 +149,16 @@ class TPInferEngine:
|
|||
max_input_len = self.max_input_len
|
||||
# The temp_state buffer is required to reorder X in the act-order case.
|
||||
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
|
||||
dtype=torch.float16,
|
||||
device=torch.cuda.current_device())
|
||||
self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
|
||||
dtype=torch.float16,
|
||||
device=torch.cuda.current_device())
|
||||
self.gptq_temp_state_buffer = torch.zeros(
|
||||
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
|
||||
)
|
||||
self.gptq_temp_dq_buffer = torch.zeros(
|
||||
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
|
||||
self.gptq_temp_dq_buffer)
|
||||
gptq_cuda.prepare_buffers(
|
||||
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
|
||||
)
|
||||
# Using the default from exllama repo here.
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
|
|
|
@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
|
|||
base = float(base)
|
||||
|
||||
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
|
||||
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
|
||||
|
||||
if ntk_alpha is not None:
|
||||
ntk_alpha = float(ntk_alpha)
|
||||
|
|
|
@ -5,7 +5,13 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
|
|||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
|
||||
from colossalai.kernel.triton import (
|
||||
llama2_context_attn_fwd,
|
||||
llama_context_attn_fwd,
|
||||
rotary_embedding_fwd,
|
||||
token_attention_fwd,
|
||||
)
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
|
@ -138,6 +144,7 @@ class LlamaInferenceForwards:
|
|||
seq_len = infer_state.seq_len
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
@ -261,8 +268,8 @@ class LlamaInferenceForwards:
|
|||
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# NOTE might want to revise
|
||||
# need some way to record the length of past key values cache
|
||||
|
@ -274,11 +281,11 @@ class LlamaInferenceForwards:
|
|||
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
|
||||
|
||||
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
|
||||
rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
|
||||
|
||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation
|
||||
|
@ -294,15 +301,26 @@ class LlamaInferenceForwards:
|
|||
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
if self.num_key_value_groups == 1:
|
||||
llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
else:
|
||||
llama2_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
|
@ -330,17 +348,29 @@ class LlamaInferenceForwards:
|
|||
# (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
|
||||
if self.num_key_value_groups == 1:
|
||||
token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
else:
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
infer_state.other_kv_index,
|
||||
)
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
|
|
@ -9,7 +9,7 @@ except ImportError:
|
|||
|
||||
# There may exist import error even if we have triton installed.
|
||||
if HAS_TRITON:
|
||||
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
|
||||
from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
|
||||
from .copy_kv_cache_dest import copy_kv_cache_to_dest
|
||||
from .fused_layernorm import layer_norm
|
||||
from .gptq_triton import gptq_fused_linear_triton
|
||||
|
@ -20,6 +20,7 @@ if HAS_TRITON:
|
|||
|
||||
__all__ = [
|
||||
"llama_context_attn_fwd",
|
||||
"llama2_context_attn_fwd",
|
||||
"bloom_context_attn_fwd",
|
||||
"softmax",
|
||||
"layer_norm",
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import LlamaForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
TPSIZE = 2
|
||||
BATCH_SIZE = 8
|
||||
MAX_INPUT_LEN = 12
|
||||
MAX_OUTPUT_LEN = 100
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": TPSIZE,
|
||||
}
|
||||
],
|
||||
)
|
||||
def run_llama_test(test_config):
|
||||
llama_config = LlamaConfig(
|
||||
num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
|
||||
)
|
||||
model = LlamaForCausalLM(llama_config)
|
||||
model = model.half()
|
||||
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
)
|
||||
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
|
||||
"attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
|
||||
}
|
||||
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama():
|
||||
spawn(check_llama, TPSIZE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama()
|
Loading…
Reference in New Issue