mirror of https://github.com/hpcaitech/ColossalAI
[inference] chatglm2 infer demo (#4724)
* add chatglm2 * add * gather needed kernels * fix some bugs * finish context forward * finish context stage * fix * add * pause * add * fix bugs * finish chatglm * fix bug * change some logic * fix bugs * change some logics * add * add * add * fix * fix tests * fix
pull/4778/head
parent 946ab56c48
commit ce7ade3882
@ -16,7 +16,13 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
||||
|
||||
_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
|
||||
_supported_models = [
|
||||
"LlamaForCausalLM",
|
||||
"LlamaModel",
|
||||
"BloomForCausalLM",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMForConditionalGeneration",
|
||||
]
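(Illustrative note, not part of the diff: the list above is presumably what the engine checks incoming models against. A minimal sketch of such a guard follows; `model` is a placeholder and the exact check used by the engine is an assumption.)
# Hypothetical guard: reject models whose class name is not registered above.
model_name = model.__class__.__name__
assert model_name in _supported_models, f"Model {model_name} is not supported for TP inference"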
|
||||
|
||||
|
||||
class TPInferEngine:
|
||||
|
@ -64,7 +70,13 @@ class TPInferEngine:
|
|||
|
||||
self.head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
self.head_num = model.config.num_attention_heads
|
||||
self.layer_num = model.config.num_hidden_layers
|
||||
num_hidden_layers = (
|
||||
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
|
||||
)
|
||||
self.layer_num = num_hidden_layers
|
||||
self.multi_query_group_num = (
|
||||
model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
|
||||
)
|
||||
|
||||
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
|
||||
self.cache_manager = None
|
||||
|
@ -85,9 +97,22 @@ class TPInferEngine:
|
|||
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
|
||||
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
|
||||
self.head_num //= self.tp_size # update sharded number of heads
|
||||
self.cache_manager = MemoryManager(
|
||||
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
|
||||
)
|
||||
if self.multi_query_group_num:
|
||||
# NOTE the logic of MQA tensor parallelism should be specified.
|
||||
assert (
|
||||
self.multi_query_group_num % self.tp_size == 0
|
||||
), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
|
||||
self.cache_manager = MemoryManager(
|
||||
self.max_total_token_num,
|
||||
self.dtype,
|
||||
self.multi_query_group_num // self.tp_size,
|
||||
self.head_dim,
|
||||
self.layer_num,
|
||||
)
|
||||
else:
|
||||
self.cache_manager = MemoryManager(
|
||||
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
|
||||
)
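(Illustrative note: to make the two cache-manager branches above concrete, a tiny sketch with hypothetical, roughly ChatGLM2-6B-like numbers; none of these values come from this diff.)
# With MQA, the cache stores multi_query_group_num // tp_size KV heads per layer
# instead of head_num // tp_size, which is what shrinks the cache.
head_num, multi_query_group_num, tp_size = 32, 2, 2
kv_heads_per_rank = multi_query_group_num // tp_size if multi_query_group_num else head_num // tp_size
# -> 1 KV head per rank instead of 16; KV-cache memory scales down proportionally.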
|
||||
|
||||
def _post_init_gptq_buffer(self, model: nn.Module) -> None:
|
||||
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import _utils
|
||||
|
||||
from .bloom import BloomInferenceForwards
|
||||
from .chatglm2 import ChatGLM2InferenceForwards
|
||||
from .llama import LlamaInferenceForwards
|
||||
|
||||
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
|
||||
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
|
||||
|
||||
|
||||
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
||||
return
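(Illustrative usage sketch for the helper above. The MemoryManager arguments mirror the positional call in the engine; the import path and all sizes are assumptions.)
import torch
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager  # assumed path
mem_manager = MemoryManager(2048, torch.float16, 2, 128, 28)  # max tokens, dtype, kv heads, head dim, layers
context_mem_index = mem_manager.alloc(16)  # reserve 16 cache slots for a 16-token prompt
key_buffer = torch.randn(16, 2, 128, dtype=torch.float16, device="cuda")
value_buffer = torch.randn_like(key_buffer)
copy_kv_to_mem_cache(0, key_buffer, value_buffer, context_mem_index, mem_manager)  # copy layer 0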
|
|
@ -0,0 +1,540 @@
|
|||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
|
||||
from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||
ChatGLMForConditionalGeneration,
|
||||
ChatGLMModel,
|
||||
GLMBlock,
|
||||
GLMTransformer,
|
||||
SelfAttention,
|
||||
split_tensor_along_last_dim,
|
||||
)
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
|
||||
# This function is the same as the Llama model's init_to_get_rotary; we should move them into _utils.py
|
||||
def _init_to_get_rotary(self, base=10000):
|
||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
||||
if not hasattr(self.config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
||||
if hasattr(self.config, "max_sequence_length"):
|
||||
max_seq_len = self.config.max_sequence_length
|
||||
elif hasattr(self.config, "max_position_embeddings"):
|
||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
|
||||
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
try:
|
||||
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
|
||||
assert ntk_alpha >= 1
|
||||
if ntk_alpha > 1:
|
||||
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
|
||||
max_seq_len *= ntk_alpha
|
||||
base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula
|
||||
except:
|
||||
pass
|
||||
n_elem = self.config.head_dim_ // 2
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
||||
return
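(Worked example of the NTK base change above; the numbers are purely illustrative.)
# With head_dim_ = 128 and INFER_NTK_ALPHA=2, max_seq_len is doubled and the rotary base becomes
base = 10000.0 * 2 ** (128 / (128 - 2))  # ~= 20222, stretching the rotary frequencies for longer contexts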
|
||||
|
||||
|
||||
def get_masks(self, input_ids, past_length, padding_mask=None):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
||||
full_attention_mask.tril_()
|
||||
if past_length:
|
||||
full_attention_mask = torch.cat(
|
||||
(
|
||||
torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
|
||||
full_attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if padding_mask is not None:
|
||||
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
||||
if not past_length and padding_mask is not None:
|
||||
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
||||
full_attention_mask = (full_attention_mask < 0.5).bool()
|
||||
full_attention_mask.unsqueeze_(1)
|
||||
return full_attention_mask
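(Illustrative shape note for the helper above; `model` and `input_ids` are placeholders.)
# With batch_size=2, seq_length=4 and past_length=0, this returns a BoolTensor of shape
# [2, 1, 4, 4] in which True marks positions to be masked out (the upper triangle).
mask = get_masks(model, input_ids, past_length=0)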
|
||||
|
||||
|
||||
class ChatGLM2InferenceForwards:
|
||||
"""
|
||||
This class holds forwards for Chatglm2 inference.
|
||||
We intend to replace the forward methods of ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def chatglm_for_conditional_generation_forward(
|
||||
self: ChatGLMForConditionalGeneration,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_last_logit: Optional[bool] = False,
|
||||
):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
infer_state = self.infer_state
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
# NOT READY FOR PRIME TIME
|
||||
# dummy but works; revise it later
|
||||
past_key_values_length = infer_state.cache_manager.past_key_values_length
|
||||
seq_length_with_past = seq_length + past_key_values_length
|
||||
infer_state.seq_length_with_past = seq_length_with_past
|
||||
|
||||
# prefill stage at first
|
||||
if use_cache and seq_length != 1:
|
||||
infer_state.is_context_stage = True
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
infer_state.init_block_loc(
|
||||
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
|
||||
)
|
||||
else:
|
||||
infer_state.is_context_stage = False
|
||||
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
|
||||
if alloc_mem is not None:
|
||||
infer_state.decode_is_contiguous = True
|
||||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
print(f" *** Encountered allocation non-contiguous")
|
||||
print(
|
||||
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
|
||||
)
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
|
||||
# related to rotary embedding
|
||||
if infer_state.is_context_stage:
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
else:
|
||||
seq_len = infer_state.seq_len
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
if return_last_logit:
|
||||
hidden_states = hidden_states[-1:]
|
||||
lm_logits = self.transformer.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
lm_logits = lm_logits.to(torch.float32)
|
||||
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
lm_logits = lm_logits.to(hidden_states.dtype)
|
||||
loss = loss.to(hidden_states.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def chatglm_model_forward(
|
||||
self: ChatGLMModel,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
infer_state: BatchInferState = None,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size,
|
||||
device=input_ids.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
||||
attention_mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = get_masks(
|
||||
self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
|
||||
)
|
||||
|
||||
# Run encoder.
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
kv_caches=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
# update indices
|
||||
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
infer_state.max_len_in_batch += 1
|
||||
infer_state.cache_manager.past_key_values_length += seq_length
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def chatglm_encoder_forward(
|
||||
self: GLMTransformer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
kv_caches=None,
|
||||
use_cache: Optional[bool] = True,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
):
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
if not kv_caches:
|
||||
kv_caches = [None for _ in range(self.num_layers)]
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
infer_state.decode_layer_id = 0
|
||||
for index in range(self.num_layers):
|
||||
layer = self.layers[index]
|
||||
|
||||
layer_ret = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
kv_cache=kv_caches[index],
|
||||
use_cache=use_cache,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
infer_state.decode_layer_id += 1
|
||||
|
||||
hidden_states, kv_cache = layer_ret
|
||||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
# Final layer norm.
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states, presents, all_hidden_states, all_self_attentions
|
||||
|
||||
@staticmethod
|
||||
def chatglm_glmblock_forward(
|
||||
self: GLMBlock,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
kv_cache=None,
|
||||
use_cache=True,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
):
|
||||
# hidden_states: [s, b, h]
|
||||
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
# Self attention.
|
||||
attention_output, kv_cache = self.self_attention(
|
||||
layernorm_output,
|
||||
attention_mask,
|
||||
kv_cache=kv_cache,
|
||||
use_cache=use_cache,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
# Residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
||||
layernorm_input = residual + layernorm_input
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
# MLP.
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
|
||||
# Second residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
||||
output = residual + output
|
||||
return output, kv_cache
|
||||
|
||||
@staticmethod
|
||||
def chatglm_flash_attn_kvcache_forward(
|
||||
self: SelfAttention,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
kv_cache=None,
|
||||
use_cache=True,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
):
|
||||
assert use_cache is True, "use_cache should be set to True when using this chatglm attention"
|
||||
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
|
||||
batch_size = hidden_states.shape[0]
|
||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query_layer = query_layer.view(
|
||||
query_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
key_layer = key_layer.view(
|
||||
key_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_multi_query_groups_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
value_layer = value_layer.view(
|
||||
value_layer.size()[:-1]
|
||||
+ (
|
||||
self.num_multi_query_groups_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
cos, sin = infer_state.position_cos, infer_state.position_sin
|
||||
|
||||
Llama2Forwards.rotary_emb_fwd(
|
||||
query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
|
||||
)
|
||||
if self.multi_query_attention:
|
||||
Llama2Forwards.rotary_emb_fwd(
|
||||
key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
else:
|
||||
Llama2Forwards.rotary_emb_fwd(
|
||||
key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
|
||||
# reshape q, k, v to [bsz * seq_len, num_heads, head_dim], e.g. 2 * 1, 32 / 2, 128
|
||||
query_layer = query_layer.reshape(
|
||||
-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
|
||||
)
|
||||
key_layer = key_layer.reshape(
|
||||
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
|
||||
)
|
||||
value_layer = value_layer.reshape(
|
||||
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
|
||||
)
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation:
|
||||
# copy key and value calculated in current step to memory manager
|
||||
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_layer,
|
||||
value_layer,
|
||||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
|
||||
|
||||
# NOTE: context attn fwd has been verified bug-free (this note can be deleted)
|
||||
llama2_context_attn_fwd(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.seq_length_with_past,
|
||||
)
|
||||
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_k.copy_(key_layer)
|
||||
cache_v.copy_(value_layer)
|
||||
else:
|
||||
# if decode is not contiguous, use triton kernel to copy key and value cache
|
||||
# k, v shape: [batch_size, num_heads, head_dim / embed_size_per_head]
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_layer,
|
||||
value_layer,
|
||||
infer_state.decode_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
# second token and subsequent tokens
|
||||
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
: infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
|
||||
: infer_state.decode_mem_end, :, :
|
||||
]
|
||||
|
||||
# ==================================
|
||||
# core attention computation is replaced by triton kernel
|
||||
# ==================================
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
query_layer,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_output,
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.max_len_in_batch,
|
||||
infer_state.other_kv_index,
|
||||
)
|
||||
|
||||
# print('after attention',torch.isnan(attn_output).any())
|
||||
|
||||
# =================
|
||||
# Output:[b,sq, h]
|
||||
# =================
|
||||
|
||||
output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
|
||||
return output, kv_cache
|
|
@ -100,7 +100,7 @@ class LlamaInferenceForwards:
|
|||
# NOTE: differentiate with prefill stage
|
||||
# block_loc require different value-assigning method for two different stage
|
||||
if use_cache and seq_length != 1:
|
||||
# NOTE assuem prefill stage
|
||||
# NOTE assume prefill stage
|
||||
# allocate memory block
|
||||
infer_state.is_context_stage = True # set prefill stage, notify attention layer
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .bloom import BloomModelInferPolicy
|
||||
from .chatglm2 import ChatGLM2InferPolicy
|
||||
from .llama import LlamaModelInferPolicy
|
||||
|
||||
__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]
|
||||
__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"]
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||
ChatGLMForConditionalGeneration,
|
||||
ChatGLMModel,
|
||||
GLMBlock,
|
||||
GLMTransformer,
|
||||
SelfAttention,
|
||||
)
|
||||
# import colossalai
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
|
||||
|
||||
from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton.rms_norm import rmsnorm_forward
|
||||
HAS_TRITON_RMSNORM = True
|
||||
except:
|
||||
print("you should install triton from https://github.com/openai/triton")
|
||||
HAS_TRITON_RMSNORM = False
|
||||
|
||||
|
||||
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
self.shard_config._infer()
|
||||
|
||||
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
|
||||
method_replacement = {'forward': model_infer_forward}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
|
||||
|
||||
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
|
||||
method_replacement = {'forward': encoder_infer_forward}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=GLMTransformer)
|
||||
|
||||
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
|
||||
method_replacement = {'forward': encoder_layer_infer_forward}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
|
||||
|
||||
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
|
||||
method_replacement = {'forward': attn_infer_forward}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=SelfAttention)
|
||||
|
||||
# for rmsnorm and others, we need to check the shape
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
_init_to_get_rotary(self.model)
|
||||
return self.model
|
||||
|
||||
|
||||
class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
|
||||
method_replacement = {'forward': partial(model_infer_forward)}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=policy,
|
||||
target_key=ChatGLMForConditionalGeneration)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return super().postprocess()
|
|
@ -11,7 +11,6 @@ except ImportError:
|
|||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
|
||||
if HAS_TRITON:
|
||||
"""
|
||||
this function is modified from
|
||||
|
@ -240,3 +239,328 @@ if HAS_TRITON:
|
|||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_latest(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
kv_group_num,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
|
||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_old(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_tmp_b,
|
||||
stride_tmp_h,
|
||||
stride_tmp_s,
|
||||
kv_group_num,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
|
||||
# t_ptrs = TMP + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
if triton.__version__ >= "2.1.0":
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
sm_scale = 1.0 / (Lq**0.5)  # compute the attention scale factor
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, blocks along the sequence)
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
_fwd_kernel_latest[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
elif triton.__version__ == "2.0.0":
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
# num_warps = 4
|
||||
_fwd_kernel_old[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
tmp,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
tmp.stride(0),
|
||||
tmp.stride(1),
|
||||
tmp.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
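(Usage sketch for the wrapper above. The packed layout is inferred from the kernel, so treat the shapes as assumptions: all prompt tokens of the batch are concatenated along dim 0, and K/V may carry fewer heads than Q for GQA/MQA.)
import torch
total_tokens, num_heads, num_kv_heads, head_dim = 32, 32, 2, 128  # two prompts of 16 tokens each
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)
o = torch.empty(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
b_start_loc = torch.tensor([0, 16], dtype=torch.int32, device="cuda")  # per-sequence start offsets
b_seq_len = torch.tensor([16, 16], dtype=torch.int32, device="cuda")   # per-sequence lengths
llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, 16)        # 16 = max_input_len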
|
||||
|
|
|
@ -105,3 +105,108 @@ def rotary_embedding_fwd(q, cos, sin):
|
|||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class Llama2Forwards:
|
||||
@staticmethod
|
||||
@triton.jit
|
||||
def _rotary_kernel(
|
||||
Q,
|
||||
Cos,
|
||||
Sin,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_cosbs,
|
||||
stride_cosd,
|
||||
stride_sinbs,
|
||||
stride_sind,
|
||||
max_total_len,
|
||||
H,  # number of attention heads
|
||||
BLOCK_HEAD: tl.constexpr,
|
||||
BLOCK_SEQ: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
):
|
||||
cur_head_index = tl.program_id(0)
|
||||
cur_seq_index = tl.program_id(1)
|
||||
|
||||
cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||
cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
||||
|
||||
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
|
||||
dim_range1 = dim_range0 + 1
|
||||
off_q0 = (
|
||||
cur_seq_range[:, None, None] * stride_qbs
|
||||
+ cur_head_range[None, :, None] * stride_qh
|
||||
+ dim_range0[None, None, :] * stride_qd
|
||||
)
|
||||
off_q1 = (
|
||||
cur_seq_range[:, None, None] * stride_qbs
|
||||
+ cur_head_range[None, :, None] * stride_qh
|
||||
+ dim_range1[None, None, :] * stride_qd
|
||||
)
|
||||
|
||||
cos_range = tl.arange(0, BLOCK_DMODEL // 2)
|
||||
off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
|
||||
|
||||
q0 = tl.load(
|
||||
Q + off_q0,
|
||||
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
|
||||
other=0.0,
|
||||
)
|
||||
q1 = tl.load(
|
||||
Q + off_q1,
|
||||
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
|
||||
sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
|
||||
|
||||
out0 = q0 * cos - q1 * sin
|
||||
out1 = q0 * sin + q1 * cos
|
||||
|
||||
tl.store(
|
||||
Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
|
||||
)
|
||||
tl.store(
|
||||
Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def rotary_emb_fwd(q, cos, sin):
|
||||
total_len = q.shape[0]
|
||||
head_num = q.shape[1]
|
||||
head_dim = q.shape[2] // 2
|
||||
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_SEQ = 32
|
||||
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
|
||||
if head_dim >= 128:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
Llama2Forwards._rotary_kernel[grid](
|
||||
q,
|
||||
cos,
|
||||
sin,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
cos.stride(0),
|
||||
cos.stride(1),
|
||||
sin.stride(0),
|
||||
sin.stride(1),
|
||||
total_len,
|
||||
head_num,
|
||||
BLOCK_HEAD=BLOCK_HEAD,
|
||||
BLOCK_SEQ=BLOCK_SEQ,
|
||||
BLOCK_DMODEL=head_dim,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
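(Usage sketch: the kernel applies the rotary embedding to q in place; `query_states`, `position_cos` and `position_sin` are placeholder names.)
# q: [total_tokens, num_heads, head_dim]; cos/sin rows are selected per token position from the
# caches built in _init_to_get_rotary, and only the leading rotary channels of each head are rotated.
Llama2Forwards.rotary_emb_fwd(query_states, position_cos, position_sin)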
|
||||
|
|
|
@ -402,3 +402,440 @@ if HAS_TRITON:
|
|||
prob = None
|
||||
|
||||
return
|
||||
|
||||
|
||||
class Llama2TokenAttentionForwards:
|
||||
@staticmethod
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Logics,
|
||||
V,
|
||||
Out,
|
||||
B_Loc,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
max_input_len,
|
||||
stride_logic_h,
|
||||
stride_logic_bs,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
other_kv_index,  # a valid cache index used as the fallback for masked loads, to avoid reading NaNs
|
||||
kv_group_num,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s
|
||||
|
||||
v_ptrs = V + off_v
|
||||
|
||||
e_max = float("-inf")
|
||||
e_sum = 0.0
|
||||
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
v_index = tl.load(
|
||||
B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_seq_len,
|
||||
other=other_kv_index,
|
||||
)
|
||||
|
||||
qk = tl.load(
|
||||
Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
|
||||
mask=start_n + offs_n < cur_batch_seq_len,
|
||||
other=float("-inf"),
|
||||
)
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
|
||||
old_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max)
|
||||
e_sum = e_sum * old_scale + tl.sum(p, 0)
|
||||
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
|
||||
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
|
||||
e_max = n_e_max
|
||||
|
||||
acc = acc / e_sum
|
||||
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
|
||||
BLOCK = 64
|
||||
batch, head = b_seq_len.shape[0], logics.shape[0]
|
||||
grid = (batch, head)
|
||||
kv_group_num = logics.shape[0] // v.shape[1]
|
||||
|
||||
num_warps = 1
|
||||
Llama2TokenAttentionForwards._fwd_kernel[grid](
|
||||
logics,
|
||||
v,
|
||||
o,
|
||||
b_loc,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_input_len,
|
||||
logics.stride(0),
|
||||
logics.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
other_kv_index,
|
||||
kv_group_num,
|
||||
BLOCK_DMODEL=v.shape[-1],
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=3,
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@triton.jit
|
||||
def _fwd_kernel_token_softmax(
|
||||
Logics,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Prob_Out,
|
||||
stride_logic_h,
|
||||
stride_logic_bs,
|
||||
stride_prob_h,
|
||||
stride_prob_bs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
row = tl.load(
|
||||
Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,
|
||||
mask=col_offsets < cur_batch_seq_len,
|
||||
other=-float("inf"),
|
||||
).to(tl.float32)
|
||||
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
|
||||
tl.store(
|
||||
Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs,
|
||||
softmax_output,
|
||||
mask=col_offsets < cur_batch_seq_len,
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):
|
||||
BLOCK_SIZE = triton.next_power_of_2(max_input_len)
|
||||
batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]
|
||||
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
|
||||
Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)](
|
||||
Logics,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Prob_Out,
|
||||
Logics.stride(0),
|
||||
Logics.stride(1),
|
||||
Prob_Out.stride(0),
|
||||
Prob_Out.stride(1),
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@triton.jit
|
||||
def _fwd_kernel_token_att1(
|
||||
Q,
|
||||
K,
|
||||
sm_scale,
|
||||
B_Loc,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
max_input_len,
|
||||
Att_Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
att_stride_h,
|
||||
att_stride_bs,
|
||||
kv_group_num,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
cur_batch_start_index = max_input_len - cur_batch_seq_len
|
||||
cur_batch_end_index = max_input_len
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_start_index = start_n * BLOCK_N
|
||||
block_mask = tl.where(block_start_index < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
q = tl.load(Q + off_q + start_mark)
|
||||
offs_n_new = cur_batch_start_index + offs_n
|
||||
k_loc = tl.load(
|
||||
B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new,
|
||||
mask=offs_n_new < cur_batch_end_index,
|
||||
other=0,
|
||||
)
|
||||
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
|
||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
|
||||
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
|
||||
BLOCK = 32
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-1]
|
||||
assert Lq == Lk
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
sm_scale = 1.0 / (Lk**0.5)
|
||||
|
||||
batch, head_num = B_Loc.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
num_warps = 2
|
||||
|
||||
Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid](
|
||||
q,
|
||||
k,
|
||||
sm_scale,
|
||||
B_Loc,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
max_input_len,
|
||||
att_out,
|
||||
B_Loc.stride(0),
|
||||
B_Loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
att_out.stride(0),
|
||||
att_out.stride(1),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@triton.jit
|
||||
def _fwd_kernel_token_att2(
|
||||
Prob,
|
||||
V,
|
||||
Out,
|
||||
B_Loc,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
max_input_len,  # B_Start_Loc is the cumulative sum of input lengths if inputs are contiguous
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_ph,
|
||||
stride_pbs,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
kv_group_num,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_start_index = max_input_len - cur_batch_seq_len
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s
|
||||
p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs
|
||||
v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
|
||||
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
|
||||
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
p_value = tl.load(
|
||||
Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
|
||||
)
|
||||
v_loc = tl.load(
|
||||
B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
|
||||
)
|
||||
v_value = tl.load(
|
||||
V + v_offs + v_loc[:, None] * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
acc += tl.sum(p_value[:, None] * v_value, 0)
|
||||
|
||||
acc = acc.to(tl.float16)
|
||||
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
|
||||
if triton.__version__ >= "2.1.0":
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
batch, head = B_Loc.shape[0], prob.shape[0]
|
||||
grid = (batch, head)
|
||||
num_warps = 4
|
||||
dim = v.shape[-1]
|
||||
|
||||
kv_group_num = prob.shape[0] // v.shape[1]
|
||||
|
||||
Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid](
|
||||
prob,
|
||||
v,
|
||||
out,
|
||||
B_Loc,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
max_input_len,
|
||||
B_Loc.stride(0),
|
||||
B_Loc.stride(1),
|
||||
prob.stride(0),
|
||||
prob.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
# this is the interface of llama2 attn forward
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def token_attn(
|
||||
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index
|
||||
):
|
||||
total_token_num = k.shape[0]
|
||||
batch_size, head_num, head_dim = q.shape
|
||||
calcu_shape1 = (batch_size, head_num, head_dim)
|
||||
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
|
||||
|
||||
Llama2TokenAttentionForwards.token_att_fwd(
|
||||
q,
|
||||
k,
|
||||
att_m_tensor,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
)
|
||||
|
||||
if triton.__version__ == "2.0.0":
|
||||
prob = torch.empty_like(att_m_tensor)
|
||||
Llama2TokenAttentionForwards.token_softmax_fwd(
|
||||
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
|
||||
)
|
||||
att_m_tensor = None
|
||||
|
||||
Llama2TokenAttentionForwards.token_att_fwd2(
|
||||
prob,
|
||||
v,
|
||||
attn_out.view(calcu_shape1),
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
)
|
||||
|
||||
prob = None
|
||||
return
|
||||
|
||||
elif triton.__version__ >= "2.1.0":
|
||||
Llama2TokenAttentionForwards.token_softmax_reducev_fwd(
|
||||
att_m_tensor,
|
||||
v,
|
||||
attn_out.view(calcu_shape1),
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
other_kv_index,
|
||||
)
|
||||
else:
|
||||
raise Exception("unsupported triton version")
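(Decode-stage usage sketch for token_attn. The layouts below are inferred from the kernels above and from the token-attention test at the end of this diff, so treat them as assumptions.)
import torch
batch, num_heads, head_dim, seq_len = 2, 32, 128, 1024
q = torch.randn(batch, num_heads, head_dim, dtype=torch.float16, device="cuda")            # one new token per sequence
k = torch.randn(batch * seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")  # flat KV-cache buffer
v = torch.randn_like(k)
attn_out = torch.empty_like(q)
kv_cache_loc = torch.arange(batch * seq_len, dtype=torch.int32, device="cuda").view(batch, seq_len)
kv_cache_start_loc = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
Llama2TokenAttentionForwards.token_attn(
    q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len, 0
)  # max_len_in_batch=seq_len, other_kv_index=0 (any valid cache row)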
|
||||
|
|
|
@ -380,12 +380,10 @@ class SelfAttention(torch.nn.Module):
|
|||
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.layer_number = max(1, layer_number)
|
||||
|
||||
self.projection_size = config.kv_channels * config.num_attention_heads
|
||||
# Per attention head and per partition values.
|
||||
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
||||
self.num_attention_heads_per_partition = config.num_attention_heads
|
||||
|
||||
self.multi_query_attention = config.multi_query_attention
|
||||
self.qkv_hidden_size = 3 * self.projection_size
|
||||
if self.multi_query_attention:
|
||||
|
@ -445,7 +443,6 @@ class SelfAttention(torch.nn.Module):
|
|||
|
||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
|
@ -541,7 +538,6 @@ class SelfAttention(torch.nn.Module):
|
|||
# =================
|
||||
# Output. [sq, b, h]
|
||||
# =================
|
||||
|
||||
output = self.dense(context_layer)
|
||||
|
||||
return output, kv_cache
|
||||
|
|
|
@ -164,6 +164,13 @@ _INFER_POLICY_LIST = {
|
|||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelInferPolicy"
|
||||
),
|
||||
# ChatGLM2
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLM2InferPolicy"
|
||||
),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
@ -208,7 +215,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
|
|||
|
||||
if policy_location is None:
|
||||
raise NotImplementedError(
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location, inference_only)
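(Illustrative lookup, not part of the diff: with the registrations added above, an inference policy can be resolved directly from a model instance; `chatglm_model` is a placeholder.)
# inference_only=True makes get_autopolicy consult _INFER_POLICY_LIST, so a ChatGLMModel instance
# resolves to ChatGLM2InferPolicy and ChatGLMForConditionalGeneration to its conditional-generation variant.
policy = get_autopolicy(chatglm_model, inference_only=True)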
|
||||
|
|
|
@ -39,6 +39,21 @@ config = ChatGLMConfig(
|
|||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
num_attention_heads=8,
|
||||
kv_channels=16,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
infer_config = ChatGLMConfig(
|
||||
num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=128,
|
||||
num_attention_heads=8,
|
||||
multi_query_attention=True,
|
||||
multi_query_group_num=2,
|
||||
kv_channels=16,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True,
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo.transformers.chatglm2 import infer_config
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
TPSIZE = 1
|
||||
BATCH_SIZE = 8
|
||||
MAX_INPUT_LEN = 12
|
||||
MAX_OUTPUT_LEN = 100
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": TPSIZE,
|
||||
}
|
||||
],
|
||||
)
|
||||
def run_chatglm2_test(test_config):
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
|
||||
# pad_token_id = 0
|
||||
model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
|
||||
orig_model = model_fn()
|
||||
orig_model = orig_model.half()
|
||||
text = ["how is the weather today?"]
|
||||
input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
|
||||
shard_config = ShardConfig(
|
||||
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
|
||||
)
|
||||
infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
|
||||
outputs = infer_engine.generate(input_ids, **generate_kwargs)
|
||||
assert outputs is not None
|
||||
|
||||
# print("outputs.shape: ", outputs[0].shape)
|
||||
# print("outputs: ", outputs[0])
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
for o in outputs:
|
||||
output_text = tokenizer.decode(o)
|
||||
print(output_text)
|
||||
|
||||
|
||||
def check_chatglm2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_chatglm2_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_chatglm2():
|
||||
spawn(check_chatglm2, TPSIZE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chatglm2()
|
|
@ -0,0 +1,65 @@
|
|||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
pass
|
||||
|
||||
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
|
||||
xq = xq.view(bs, 1, num_head, head_dim)
|
||||
xk = xk.view(bs, seqlen, num_head, head_dim)
|
||||
xv = xv.view(bs, seqlen, num_head, head_dim)
|
||||
|
||||
logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
|
||||
prob = torch.softmax(logics, dim=1)
|
||||
prob = prob.view(bs, seqlen, num_head, 1)
|
||||
|
||||
return torch.sum(prob * xv, dim=1, keepdim=False)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||
)
|
||||
def test():
|
||||
Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
|
||||
dtype = torch.float16
|
||||
|
||||
# attn out: 2,4096
|
||||
q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
|
||||
k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
|
||||
v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
|
||||
o = torch.empty_like(q)
|
||||
# o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
|
||||
|
||||
max_kv_cache_len = seq_len
|
||||
kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
|
||||
kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
|
||||
kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
|
||||
other_kv_index = 2048
|
||||
|
||||
kv_cache_seq_len[:] = seq_len
|
||||
kv_cache_start_loc[0] = 0
|
||||
kv_cache_start_loc[1] = seq_len
|
||||
|
||||
for i in range(Z):
|
||||
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
|
||||
|
||||
Llama2TokenAttentionForwards.token_attn(
|
||||
q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
|
||||
)
|
||||
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
|
||||
assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()