Fixed a bug in the inference framework

pull/5258/head
yuehuayingxueluo 2023-12-26 21:34:27 +08:00 committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

View File

@ -97,3 +97,6 @@ class InferenceConfig:
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.quant_mode in ["smoothquant", "gptq", None], "quant_mode should be one of 'smoothquant', 'gptq', or None"
assert (
self.max_input_len + self.max_output_len <= self.max_seq_len
), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."

View File

@ -49,6 +49,7 @@ class InferenceEngine:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
@ -76,6 +77,7 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
self.counter = count()
def _verify_config(self) -> None:
@ -170,7 +172,11 @@ class InferenceEngine:
if prompts_token_ids is None:
assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
assert (
len(prompts_token_ids[0]) < self.inference_config.max_input_len
), "The length of input prompts must be less than max_input_len."
prompts_num = len(prompts_token_ids)
@ -183,13 +189,14 @@ class InferenceEngine:
prompt = None
else:
prompt = prompts[i]
block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
None,
block_table,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
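For reference, the block_table passed into the Sequence above starts out fully unmapped; a minimal sketch of its shape and meaning (the max_seq_len value is an assumption):

import torch

max_seq_len = 8  # illustrative value
# One entry per logical block slot; -1 means no physical cache block assigned yet.
block_table = torch.full([max_seq_len], -1, dtype=torch.long)
print(block_table)  # tensor([-1, -1, -1, -1, -1, -1, -1, -1])
# The cache manager later fills entries with physical block ids during allocation.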
@ -211,14 +218,15 @@ class InferenceEngine:
self.logger.info("Running generation step")
output_list = []
batch, k_cache, v_cache = self.request_handler.schedule()
batch = self.request_handler.schedule()
logits = self.model(
batch,
k_cache,
v_cache,
self.k_cache,
self.v_cache,
)
self.request_handler.search_tokens(logits, self.generation_config)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
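The hunk above settles on the per-step flow sketched below. This is a hedged mirror of the calls shown in the diff, not a drop-in implementation; `engine` stands for an InferenceEngine instance and the attribute names follow the diff:

def one_step(engine):
    batch = engine.request_handler.schedule()                      # pick a prefill or decoding batch
    logits = engine.model(batch, engine.k_cache, engine.v_cache)   # KV caches are now engine attributes
    engine.request_handler.search_tokens(engine.generation_config, logits)  # sample next tokens
    return engine.request_handler.update()                         # collect finished sequences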

View File

@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
@ -49,7 +48,7 @@ class RunningList:
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.ratio
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
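A worked instance of the ratio check above (the prefill_ratio value is an assumption):

prefill_ratio = 1.2                      # assumed config value
num_prefill, num_decoding = 4, 3
if num_decoding == 0:
    ready = num_prefill > 0
else:
    ready = num_prefill / num_decoding >= prefill_ratio
print(ready)  # 4 / 3 ≈ 1.33 >= 1.2 -> True, so a prefill batch is scheduled next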
@ -72,8 +71,9 @@ class RequestHandler:
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.running_batch = BatchInfo(is_prompts=False)
self.prefill_batch = BatchInfo(is_prompts=True)
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
@ -81,6 +81,9 @@ class RequestHandler:
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
def schedule(self):
"""
The main logic of request handler.
@ -90,7 +93,7 @@ class RequestHandler:
for lst in reversed(self.waiting_list):
if lst:
for seq in lst:
if seq.prompt_len > self.inference_config.max_input_len:
if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence.
self.abort_sequence(seq.request_id)
break
@ -98,9 +101,8 @@ class RequestHandler:
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
lst.remove(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
lst.clear()
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@ -115,10 +117,9 @@ class RequestHandler:
"""
assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
assert (
req.prompt_len < self.inference_config.max_input_len
req.input_len < self.inference_config.max_input_len
), f"Sequence {req.request_id} exceeds input length limit"
self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
def abort_sequence(self, request_id: str):
"""
@ -178,9 +179,12 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
for type in ["top_p", "top_k", "min_p"]:
if type in generation_config:
logits = logit_processor(type, logits)
# for type in ["top_p", "top_k", "min_p"]:
# config_dict = generation_config.to_dict()
# if type in config_dict:
# logits = logit_processor(type, logits, config_dict[type])
torch.cuda.synchronize()
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@ -188,7 +192,10 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
self.running_batch.update_batch_tokens(sample_tokens)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
def update(self):
"""

View File

@ -112,7 +112,7 @@ class KVCacheManager:
def get_kv_cache(self):
"""Get k_cache and v_cache"""
return self._kv_cache[0], self._kv_cache[1]
return self._kv_caches[0], self._kv_caches[1]
def get_max_blocks_per_sequence(self) -> int:
"""Get the maximum number of blocks that can be allocated for a single sequence."""
@ -122,7 +122,7 @@ class KVCacheManager:
return self.max_blocks_per_sequence
def check_allocation(self, seq: Sequence) -> bool:
num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size
return num_blocks_needed <= self.num_available_blocks
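The new line computes a ceiling division in integer arithmetic; a worked example (block_size and lengths are assumptions):

block_size, max_output_length = 16, 128
input_len = 33
num_blocks_needed = (input_len + max_output_length + block_size - 1) // block_size
print(num_blocks_needed)  # (33 + 128 + 15) // 16 = 176 // 16 = 11 free blocks required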
def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:

View File

@ -70,7 +70,10 @@ def llama_model_forward(
seq_length = input_ids.shape[1]
device = input_ids.device
past_key_values_length = len(block_tables.shape[1])
if batch.is_prompts:
past_key_values_length = 0
else:
past_key_values_length = sequence_lengths[0].item() - 1
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
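Concretely, the branch above makes prefill positions start at 0, while a decoding step resumes one position before the current sequence length; a small illustration (the lengths are assumptions):

import torch

# prefill: is_prompts=True, past_key_values_length = 0, seq_length = prompt length
print(torch.arange(0, 4, dtype=torch.long))        # tensor([0, 1, 2, 3])

# decoding: seq_length = 1, past_key_values_length = sequence_lengths[0] - 1
past_key_values_length = 7 - 1
print(torch.arange(past_key_values_length, 1 + past_key_values_length, dtype=torch.long))  # tensor([6])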
@ -163,26 +166,17 @@ def llama_attn_forward(
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
block_size = k_cache.shape[-1]
k_cache.shape[-1]
memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
)
else:
attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
decoding_attention(
query_states,
k_cache,
v_cache,
block_tables,
sequence_lengths,
attn_output,
block_tables.shape[1],
block_size,
)
# if is_prompts:
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
# else:
# attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
attn_output = query_states
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -190,19 +184,3 @@ def llama_attn_forward(
attn_output = self.o_proj(attn_output)
return attn_output
def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
block_table_list = block_tables.tolist()
batch_size, seq_len, num_heads, head_dim = key
reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
if seq_len == 1:
for i in range(batch_size):
k_cache[block_table_list[i][-1], :] = reshape_key[i]
v_cache[block_table_list[i][-1], :] = reshape_value[i]
else:
for i in range(batch_size):
k_cache[block_table_list[i], :] = reshape_key[i]
v_cache[block_table_list[i], :] = reshape_value[i]
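The deleted helper above intended to scatter key/value states into cache blocks addressed by the block table, but it unpacked the tensor instead of its shape. A corrected, self-contained sketch of the intended idea (the cache layout, toy sizes, and helper name below are assumptions, not the real Triton path):

import torch

num_blocks, num_heads, head_dim, block_size = 8, 2, 4, 4   # toy sizes
k_cache = torch.zeros(num_blocks, num_heads, head_dim, block_size)

def write_token(cache, state, block_table, token_idx):
    # place one token's [num_heads, head_dim] state into its block and in-block slot
    block_id = block_table[token_idx // block_size]
    slot = token_idx % block_size
    cache[block_id, :, :, slot] = state

block_table = torch.tensor([3, 5, -1, -1, -1, -1, -1, -1])  # logical -> physical block ids
write_token(k_cache, torch.randn(num_heads, head_dim), block_table, token_idx=5)  # block 5, slot 1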

View File

@ -1,7 +1,165 @@
from functools import partial
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaSdpaAttention,
)
from colossalai.inference.modeling.models.llama import (
llama_attn_forward,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
# The code here is just for testing and will be modified later.
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaForCausalLM
)
infer_forward = llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
return policy
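All six replacements above follow the same pattern: rebind a module's forward to the inference implementation. A generic, self-contained illustration of that pattern (the toy class below is not the shardformer API):

from functools import partial

class ToyAttention:
    def forward(self, x):
        return x + 1          # original behaviour

def infer_forward(self, x):
    return x * 2              # replacement used at inference time

module = ToyAttention()
module.forward = partial(infer_forward, module)   # analogous to {"forward": partial(infer_forward)}
print(module.forward(3))      # 6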

View File

@ -1,6 +1,6 @@
import enum
from dataclasses import dataclass
from typing import Any, List, Union
from typing import Any, List, Tuple, Union
import torch
from ordered_set import OrderedSet
@ -74,13 +74,6 @@ class Sequence:
self.output_token_id = []
self.status = RequestStatus.WAITING
@property
def prompt_len(self) -> int:
"""
Get length of prompts
"""
return len(self.input_token_id)
@property
def sentence_len(self) -> int:
"""
@ -113,7 +106,7 @@ class Sequence:
return True
if self.output_token_id:
if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
self.status = RequestStatus.COMPLETED
return True
@ -143,11 +136,13 @@ class Sequence:
def __repr__(self) -> str:
return (
f"Request ID(request_id={self.request_id}, "
f"(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"logical block number={len(self.block_table_index)}"
f"logical_block_number={self.block_table.shape[0]},"
f"input_len={self.input_len}),"
f"output_len={self.output_len})"
)
@ -159,9 +154,15 @@ class BatchInfo:
sequences_set: OrderedSet["Sequence"] = None
is_prompts: bool = True
device: torch.device = None
@classmethod
def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
def __post_init__(self):
if self.device is None:
self.device = torch.cuda.current_device()
if self.sequences_set is None:
self.sequences_set = OrderedSet()
def init_batch(self, seqs: List["Sequence"] = None):
"""
Initializes the inference batch from the input sequence list.
@ -169,29 +170,29 @@ class BatchInfo:
seqs (List["Sequence"]): List of input sequence.
"""
sequences_set = OrderedSet()
assert len(self.sequences_set) == 0, "Sequences set has been initialized."
if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
for seq in seqs:
if seq in sequences_set:
if seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
sequences_set.add(seq)
return cls(sequences_set=sequences_set)
self.sequences_set.add(seq)
def get_block_table_tensor(self) -> None:
tensor_list = []
block_table = None
for seq in self.sequences_set:
block_table = seq.block_table
assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
assert (
block_table is not None
), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
tensor_list.append(seq.block_table)
assert tensor_list, "Batch has not been initialized yet. Please initialize batch first."
block_table = torch.concat(tensor_list)
block_table = torch.stack(tensor_list)
return block_table
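The substantive change in this hunk is torch.concat -> torch.stack; a quick shape check shows why (the lengths are assumptions):

import torch

t1 = torch.full([4], -1)            # per-sequence block table of length max_seq_len
t2 = torch.full([4], -1)
print(torch.cat([t1, t2]).shape)    # torch.Size([8])    -> flat vector, batch dimension lost
print(torch.stack([t1, t2]).shape)  # torch.Size([2, 4]) -> [batch_size, max_seq_len], as the model expects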
def clear_batch(self) -> None:
@ -239,7 +240,7 @@ class BatchInfo:
seqs = [seqs]
for seq in seqs:
if seq in self.sequences_set:
if self.sequences_set and seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
self.sequences_set.add(seq)
@ -251,7 +252,7 @@ class BatchInfo:
"""
return not self.sequences_set
def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
"""
Add an output token for each sentence in the batch.
@ -259,6 +260,9 @@ class BatchInfo:
tokens (Union[List[int], List[List[int]], torch.Tensor]): A batch of tokens
"""
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
for seq, token in zip(self.sequences_set, tokens):
@ -287,19 +291,25 @@ class BatchInfo:
else:
input_list.append([seq.output_token_id[-1]])
return torch.tensor(input_list, dtype=torch.long)
return torch.tensor(input_list, dtype=torch.long, device=self.device)
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
Flatten the input tokens into a 1D tensor.
"""
input_list = []
input_len_list = []
for seq in self.sequences_set:
if self.is_prompts:
input_list.extend(seq.input_token_id)
input_len_list.append(seq.sentence_len)
else:
input_list.append(seq.output_token_id[-1])
return torch.tensor(input_list, dtype=torch.long)
input_len_list.append(1)
return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
input_len_list, dtype=torch.int, device=self.device
)
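A worked example of the flattening above for a prefill batch of two prompts (the token ids are made up):

import torch

prompts = [[11, 12, 13], [21, 22]]
input_list, input_len_list = [], []
for tokens in prompts:
    input_list.extend(tokens)        # prefill path: every prompt token
    input_len_list.append(len(tokens))
print(torch.tensor(input_list))      # tensor([11, 12, 13, 21, 22])
print(torch.tensor(input_len_list))  # tensor([3, 2]); during decoding each entry would be 1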
def get_sequence_lengths(self):
"""
@ -307,5 +317,9 @@ class BatchInfo:
"""
len_list = []
for seq in self.sequences_set:
len_list.append(seq.get_sentence_len())
return torch.tensor(len_list, dtype=torch.int)
len_list.append(seq.sentence_len)
return torch.tensor(len_list, dtype=torch.int, device=self.device)
def __repr__(self) -> str:
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"

View File

@ -1,6 +1,6 @@
import pytest
import transformers
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import InferenceConfig
@ -11,21 +11,24 @@ from colossalai.testing import spawn
def check_inference_engine():
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
inference_config = InferenceConfig()
inference_config = InferenceConfig(max_output_len=5)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inputs = [
"介绍一下北京",
"介绍一下今天的北京",
"介绍一下武汉",
]
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
# outputs = inference_engine.generate(None)
generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
outputs = inference_engine.generate(generation_config)
print("outputs: ", outputs)
# The engine still has some bugs.
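The page cuts off here; for orientation, tests in this suite are usually driven by a spawn wrapper roughly like the sketch below (the wrapper and launch arguments are assumptions, not part of this diff):

def run_dist(rank, world_size, port):
    # hypothetical launcher call mirroring the usual ColossalAI test pattern
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
    check_inference_engine()

@pytest.mark.dist
def test_inference_engine():
    spawn(run_dist, 1)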