mirror of https://github.com/hpcaitech/ColossalAI
fix bugs in attention.py and request_handler.py
parent
bfd9b1b494
commit
47e53eaa1c
|
@ -214,9 +214,6 @@ class InferenceEngine:
|
||||||
List[str]: Decoded finished sequences generated by one step.
|
List[str]: Decoded finished sequences generated by one step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
self.logger.info("Running generation step")
|
|
||||||
|
|
||||||
output_list = []
|
output_list = []
|
||||||
batch = self.request_handler.schedule()
|
batch = self.request_handler.schedule()
|
||||||
|
|
||||||
|
@ -224,6 +221,7 @@ class InferenceEngine:
|
||||||
batch,
|
batch,
|
||||||
self.k_cahce,
|
self.k_cahce,
|
||||||
self.v_cache,
|
self.v_cache,
|
||||||
|
padding_id=self.tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
|
|
|
@ -110,6 +110,10 @@ class RequestHandler:
|
||||||
self.prefill_batch.init_batch(self.running_list.prefill)
|
self.prefill_batch.init_batch(self.running_list.prefill)
|
||||||
return self.prefill_batch
|
return self.prefill_batch
|
||||||
|
|
||||||
|
if not self.running_batch.is_empty:
|
||||||
|
for seq in self.running_batch.sequences_set:
|
||||||
|
self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
|
||||||
|
|
||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
def add_sequence(self, req: Sequence):
|
def add_sequence(self, req: Sequence):
|
||||||
|
|
|
@ -29,47 +29,50 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||||
for block_idx in range(block_num - 1):
|
for block_idx in range(block_num - 1):
|
||||||
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
|
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
|
||||||
token_id += block_size
|
token_id += block_size
|
||||||
cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
|
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
|
||||||
|
1, 2, 0
|
||||||
|
)
|
||||||
elif type == "decoding":
|
elif type == "decoding":
|
||||||
assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
|
assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
|
||||||
source = source.squeeze(1)
|
source = source.squeeze(1)
|
||||||
slot_idx = (lengths + block_size - 1) % block_size
|
slot_idx = (lengths + block_size - 1) % block_size
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1)
|
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
|
||||||
|
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
def convert_kvcache(source, cache, lengths, block_tables):
|
def convert_kvcache(cache, lengths, block_tables):
|
||||||
"""
|
"""
|
||||||
Func: convert key/value cache for calculation
|
Func: convert key/value cache for calculation
|
||||||
|
|
||||||
Args: key/value(source): shape [bsz, 1, num_heads, head_size]
|
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
|
||||||
cache: shape [num_blocks, num_heads, head_size, block_size]
|
|
||||||
lengths: key/value length
|
lengths: key/value length
|
||||||
block_tables
|
block_tables
|
||||||
"""
|
"""
|
||||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
num_blocks, num_heads, head_size, block_size = cache.shape
|
||||||
|
|
||||||
needed_blocks = (lengths + block_size - 1) // block_size
|
needed_blocks = (lengths + block_size - 1) // block_size
|
||||||
num_remaing_tokens = (lengths - 1) % block_size
|
num_remaing_tokens = lengths % block_size
|
||||||
|
num_remaing_tokens[num_remaing_tokens == 0] += block_size
|
||||||
bsz = block_tables.shape[0]
|
bsz = block_tables.shape[0]
|
||||||
seq_len = max(lengths)
|
seq_len = max(lengths)
|
||||||
padded_cache = []
|
padded_cache = []
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
|
cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
|
||||||
|
cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1)
|
||||||
|
|
||||||
_cache = torch.cat(
|
_cache = torch.cat(
|
||||||
(
|
(
|
||||||
cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
|
cache1,
|
||||||
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
|
cache2,
|
||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
concat_cache = torch.cat((_cache, source[i]), dim=0)
|
padding = seq_len - _cache.size(0)
|
||||||
padding = seq_len - concat_cache.size(0)
|
|
||||||
if padding > 0:
|
if padding > 0:
|
||||||
concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
|
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
|
||||||
padded_cache.append(concat_cache)
|
padded_cache.append(_cache)
|
||||||
|
|
||||||
return torch.stack(padded_cache, dim=0)
|
return torch.stack(padded_cache, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,22 @@
|
||||||
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
||||||
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
import torch.nn as nn
|
||||||
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaModel,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
|
||||||
from colossalai.inference.struct import BatchInfo
|
from colossalai.inference.struct import BatchInfo
|
||||||
from colossalai.kernel.triton import context_attention_unpadded
|
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
|
@ -27,24 +38,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
|
||||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
|
||||||
"""
|
|
||||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
||||||
if n_rep == 1:
|
|
||||||
return hidden_states
|
|
||||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
|
||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def llama_causal_lm_forward(
|
def llama_causal_lm_forward(
|
||||||
self: LlamaForCausalLM,
|
self: LlamaForCausalLM,
|
||||||
batch: BatchInfo = None,
|
batch: BatchInfo = None,
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
|
padding_id: int = None,
|
||||||
):
|
):
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
hidden_states = llama_model_forward(
|
hidden_states = llama_model_forward(
|
||||||
|
@ -52,6 +51,7 @@ def llama_causal_lm_forward(
|
||||||
batch=batch,
|
batch=batch,
|
||||||
k_caches=k_caches,
|
k_caches=k_caches,
|
||||||
v_caches=v_caches,
|
v_caches=v_caches,
|
||||||
|
padding_id=padding_id,
|
||||||
)
|
)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
return logits
|
return logits
|
||||||
|
@ -62,13 +62,20 @@ def llama_model_forward(
|
||||||
batch: BatchInfo = None,
|
batch: BatchInfo = None,
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
|
padding_id: int = None,
|
||||||
):
|
):
|
||||||
input_ids = batch.get_batch_inputs()
|
input_ids = batch.get_batch_inputs()
|
||||||
block_tables = batch.get_block_table_tensor()
|
block_tables = batch.get_block_table_tensor()
|
||||||
sequence_lengths = batch.get_sequence_lengths()
|
sequence_lengths = batch.get_sequence_lengths()
|
||||||
|
|
||||||
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
attention_mask = batch.get_attn_mask(padding_id)
|
||||||
position_ids = generate_padding_position_id(input_ids)
|
|
||||||
|
if batch.is_prompts:
|
||||||
|
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
||||||
|
position_ids = generate_padding_position_id(attention_mask)
|
||||||
|
else:
|
||||||
|
position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
for layer_id, decoder_layer in enumerate(self.layers):
|
for layer_id, decoder_layer in enumerate(self.layers):
|
||||||
|
@ -80,6 +87,7 @@ def llama_model_forward(
|
||||||
v_cache=v_caches[layer_id],
|
v_cache=v_caches[layer_id],
|
||||||
is_prompts=batch.is_prompts,
|
is_prompts=batch.is_prompts,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
@ -96,6 +104,7 @@ def llama_decoder_layer_forward(
|
||||||
v_cache: torch.Tensor = None,
|
v_cache: torch.Tensor = None,
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
sequence_lengths: int = None,
|
sequence_lengths: int = None,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
@ -109,6 +118,7 @@ def llama_decoder_layer_forward(
|
||||||
v_cache=v_cache,
|
v_cache=v_cache,
|
||||||
is_prompts=is_prompts,
|
is_prompts=is_prompts,
|
||||||
sequence_lengths=sequence_lengths,
|
sequence_lengths=sequence_lengths,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
@ -132,6 +142,7 @@ def llama_attn_forward(
|
||||||
v_cache: torch.Tensor = None,
|
v_cache: torch.Tensor = None,
|
||||||
is_prompts: bool = True,
|
is_prompts: bool = True,
|
||||||
sequence_lengths: torch.Tensor = None,
|
sequence_lengths: torch.Tensor = None,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -139,9 +150,7 @@ def llama_attn_forward(
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = sequence_lengths[0].item()
|
||||||
if not is_prompts:
|
|
||||||
kv_seq_len = kv_seq_len + sequence_lengths[0].item()
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
@ -153,20 +162,26 @@ def llama_attn_forward(
|
||||||
key_states = key_states.transpose(1, 2)
|
key_states = key_states.transpose(1, 2)
|
||||||
value_states = value_states.transpose(1, 2)
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
|
|
||||||
_, _, _, block_size = k_cache.shape
|
|
||||||
|
|
||||||
# NOTE: context_attention_unpadded is used for testing accuracy and we can only use aligned inputs.
|
|
||||||
# The code below will be uncommented after the development of attention-related kernel is completed.
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = pad_context_forward(
|
||||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
attn_output = pad_decoding_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables,
|
||||||
|
attention_mask,
|
||||||
|
self.layer_idx,
|
||||||
|
self.attention_dropout,
|
||||||
|
self.training,
|
||||||
)
|
)
|
||||||
# else:
|
|
||||||
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
|
||||||
|
|
||||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
@ -175,13 +190,129 @@ def llama_attn_forward(
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
|
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
# Replace this code and use a more flexible method to obtain padding_id, avoiding directly setting padding_id like this.
|
|
||||||
padding_id = 2
|
|
||||||
attention_mask = input_ids.ne(padding_id).long()
|
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
return position_ids
|
return position_ids
|
||||||
|
|
||||||
# def unpad_inputs(input_ids: torch.Tensor):
|
|
||||||
|
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
|
||||||
|
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
|
||||||
|
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
||||||
|
return (q, k, v, indices, seqlens)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_decoding_forward(
|
||||||
|
query: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
||||||
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
|
attn_mask: torch.Tensor = None,
|
||||||
|
layer_id: int = 0,
|
||||||
|
attention_dropout: float = None,
|
||||||
|
training: bool = False,
|
||||||
|
):
|
||||||
|
bsz, query_length, num_heads, head_size = query.shape
|
||||||
|
seq_len = max(lengths)
|
||||||
|
|
||||||
|
copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
||||||
|
copy_to_cache(value, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
|
||||||
|
|
||||||
|
key = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen,
|
||||||
|
value = convert_kvcache(v_cache, lengths, block_tables)
|
||||||
|
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
|
||||||
|
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
|
||||||
|
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, query.dtype, query_length)
|
||||||
|
|
||||||
|
attn_mask = AttentionMaskConverter._make_causal_mask(
|
||||||
|
(bsz, query_length), query.dtype, query.device, past_key_values_length=seq_len - query_length
|
||||||
|
)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
|
||||||
|
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, num_heads, 1, head_size):
|
||||||
|
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def pad_context_forward(
|
||||||
|
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
||||||
|
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
||||||
|
v: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
context_lengths: torch.Tensor, # [num_seqs]
|
||||||
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
|
attn_mask: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
# Firt, do shape verification
|
||||||
|
bsz, seq_len, num_heads, head_size = q.shape
|
||||||
|
num_kv_heads = k.shape[-2]
|
||||||
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
|
block_size = k_cache.shape[-1]
|
||||||
|
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
||||||
|
block_tables.shape[-1] * block_size
|
||||||
|
shape = (bsz, seq_len, num_heads, head_size)
|
||||||
|
input_shape = shape[:2]
|
||||||
|
|
||||||
|
# Copy kv to memory(rotary embedded)
|
||||||
|
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
|
||||||
|
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = repeat_kv(k.transpose(1, 2), num_kv_groups)
|
||||||
|
v = repeat_kv(v.transpose(1, 2), num_kv_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
|
||||||
|
|
||||||
|
attn_mask = AttentionMaskConverter._make_causal_mask(
|
||||||
|
(bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
|
||||||
|
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, num_heads, seq_len, head_size):
|
||||||
|
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
|
||||||
|
|
||||||
|
del attn_weights
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
|
@ -321,5 +321,13 @@ class BatchInfo:
|
||||||
|
|
||||||
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
||||||
|
|
||||||
|
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
|
||||||
|
past_values = []
|
||||||
|
|
||||||
|
for seq in self.sequences_set:
|
||||||
|
past_values.append(seq.input_token_id + seq.output_token_id)
|
||||||
|
|
||||||
|
return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
||||||
|
|
|
@ -9,7 +9,7 @@ from transformers import AutoTokenizer, GenerationConfig
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.inference.config import InferenceConfig
|
from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.core.engine import InferenceEngine
|
from colossalai.inference.core.engine import InferenceEngine
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
def setup_seed(seed):
|
def setup_seed(seed):
|
||||||
|
@ -24,21 +24,24 @@ def check_inference_engine(test_cai=False):
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
model = transformers.LlamaForCausalLM(
|
model = transformers.LlamaForCausalLM(
|
||||||
transformers.LlamaConfig(
|
transformers.LlamaConfig(
|
||||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||||
)
|
)
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
"介绍一下北京,",
|
"介绍一下今天的北京,",
|
||||||
"介绍一下武汉,",
|
"介绍一下武汉,",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
output_len = 16
|
||||||
|
do_sample = True
|
||||||
|
|
||||||
if test_cai:
|
if test_cai:
|
||||||
inference_config = InferenceConfig(max_output_len=1)
|
inference_config = InferenceConfig(max_output_len=output_len)
|
||||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||||
inference_engine.add_request(prompts=inputs)
|
inference_engine.add_request(prompts=inputs)
|
||||||
assert inference_engine.request_handler._has_waiting()
|
assert inference_engine.request_handler._has_waiting()
|
||||||
generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
|
generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
|
||||||
outputs = inference_engine.generate(generation_config)
|
outputs = inference_engine.generate(generation_config)
|
||||||
else:
|
else:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
@ -46,7 +49,7 @@ def check_inference_engine(test_cai=False):
|
||||||
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
||||||
inputs = inputs.cuda()
|
inputs = inputs.cuda()
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
|
do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
|
||||||
)
|
)
|
||||||
outputs = model.generate(inputs, generation_config=generation_config)
|
outputs = model.generate(inputs, generation_config=generation_config)
|
||||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
@ -64,6 +67,7 @@ def run_dist(rank, world_size, port):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
def test_inference_engine():
|
def test_inference_engine():
|
||||||
spawn(run_dist, 1)
|
spawn(run_dist, 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue