adapted to pad_context_forward

pull/5258/head
yuehuayingxueluo 2024-01-09 13:52:53 +08:00 committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions

View File

@ -1,6 +1,5 @@
""" """
Our config consists of one part: Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
""" """
import logging import logging
@ -94,9 +93,12 @@ class InferenceConfig:
torch.float32, torch.float32,
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16" ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" assert self.quant_mode in [
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'" "smoothquant",
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
assert ( assert (
self.max_input_len + self.max_output_len <= self.max_seq_len self.max_input_len + self.max_output_len <= self.max_seq_len
), "The sum of max_input_len and max_output_len must be smaller than max_seq_len." ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."

View File

@ -51,6 +51,8 @@ class InferenceEngine:
self.model_config = model.config self.model_config = model.config
self.device = torch.device("cuda") self.device = torch.device("cuda")
model = model.eval()
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32: if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32 self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16: elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
@ -85,12 +87,12 @@ class InferenceEngine:
Verify the input config Verify the input config
""" """
if not isinstance(self.model, nn.Module): if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}") raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance( if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer self.tokenizer, PreTrainedTokenizer
): ):
raise TypeError( raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}" f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
) )
assert ( assert (
self.model.__class__.__name__ in _supported_models self.model.__class__.__name__ in _supported_models

View File

@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import * from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
class RunningList: class RunningList:
@ -93,17 +96,23 @@ class RequestHandler:
# Try to allocate cache blocks for the sequence using a priority of prompt length. # Try to allocate cache blocks for the sequence using a priority of prompt length.
for lst in reversed(self.waiting_list): for lst in reversed(self.waiting_list):
if lst: if lst:
remove_list = []
for seq in lst: for seq in lst:
if seq.input_len > self.inference_config.max_input_len: if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence. # If the prompt length is longer than max_input_len, abort the sequence.
logger.warning(
f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
)
self.abort_sequence(seq.request_id) self.abort_sequence(seq.request_id)
break remove_list.append(seq)
# Try to allocate cache blocks for the sequence. # Try to allocate cache blocks for the sequence.
if self.cache_manager.check_allocation(seq): if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list. # If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq) self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len) self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
lst.clear() for seq in remove_list:
lst.remove(seq)
if self.running_list.ready_for_prefill(): if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill: for seq in self.running_list.prefill:
seq.mark_running() seq.mark_running()
@ -130,10 +139,9 @@ class RequestHandler:
""" """
Abort the request. Abort the request.
""" """
seq, priority = self._find_sequence(request_id) seq, _ = self._find_sequence(request_id)
if seq.status.is_waiting: if seq.status.is_waiting:
seq.mark_aborted() seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running(): elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table) self.cache_manager.free_block_table(seq.block_table)
self.running_list.remove(seq) self.running_list.remove(seq)

View File

@ -112,7 +112,7 @@ class KVCacheManager:
def get_kv_cache(self): def get_kv_cache(self):
"""Get k_cache and v_cache""" """Get k_cache and v_cache"""
return self._kv_caches[0], self._kv_caches[1] return self._kv_caches
def get_max_blocks_per_sequence(self) -> int: def get_max_blocks_per_sequence(self) -> int:
"""Get the maximum number of blocks that can be allocated for a single sequence.""" """Get the maximum number of blocks that can be allocated for a single sequence."""

View File

@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import (
from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
from colossalai.inference.struct import BatchInfo from colossalai.inference.struct import BatchInfo
from flash_attn.bert_padding import index_first_axis, pad_input # noqa from flash_attn.bert_padding import index_first_axis # noqa
def rotate_half(x): def rotate_half(x):
@ -167,20 +167,8 @@ def llama_attn_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
) )
else: else:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = pad_decoding_forward( attn_output = pad_decoding_forward(
query_states, query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
self.layer_idx,
self.attention_dropout,
self.training,
) )
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
@ -215,9 +203,6 @@ def pad_decoding_forward(
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
attn_mask: torch.Tensor = None, attn_mask: torch.Tensor = None,
layer_id: int = 0,
attention_dropout: float = None,
training: bool = False,
): ):
bsz, query_length, num_heads, head_size = query.shape bsz, query_length, num_heads, head_size = query.shape
seq_len = max(lengths) seq_len = max(lengths)
@ -247,9 +232,7 @@ def pad_decoding_forward(
attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min) attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
attn_weights += attn_mask attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
attn_output = torch.matmul(attn_weights, value) attn_output = torch.matmul(attn_weights, value)
if attn_output.size() != (bsz, num_heads, 1, head_size): if attn_output.size() != (bsz, num_heads, 1, head_size):
@ -277,8 +260,6 @@ def pad_context_forward(
block_size = k_cache.shape[-1] block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size block_tables.shape[-1] * block_size
shape = (bsz, seq_len, num_heads, head_size)
input_shape = shape[:2]
# Copy kv to memory(rotary embedded) # Copy kv to memory(rotary embedded)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)

View File

@ -268,7 +268,7 @@ class BatchInfo:
for seq, token in zip(self.sequences_set, tokens): for seq, token in zip(self.sequences_set, tokens):
if not isinstance(token, list): if not isinstance(token, list):
if not isinstance(token, int): if not isinstance(token, int):
raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.") raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
token = [token] token = [token]
seq.output_token_id += token seq.output_token_id += token
seq.check_finish() seq.check_finish()

View File

@ -133,7 +133,7 @@ class CaiInferEngine:
""" """
assert isinstance( assert isinstance(
input_list, (BatchEncoding, dict) input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding): if isinstance(input_list, BatchEncoding):
input_list = input_list.data input_list = input_list.data
out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))

View File

@ -28,20 +28,24 @@ def check_inference_engine(test_cai=False):
) )
).cuda() ).cuda()
model = model.eval()
inputs = [ inputs = [
"介绍一下今天的北京,", "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
"介绍一下武汉,", "介绍一下武汉,",
] ]
output_len = 16 output_len = 128
do_sample = True do_sample = True
top_p = 0.5
top_k = 50
if test_cai: if test_cai:
inference_config = InferenceConfig(max_output_len=output_len) inference_config = InferenceConfig(max_output_len=output_len)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine.add_request(prompts=inputs) inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting() assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50) generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config) outputs = inference_engine.generate(generation_config)
else: else:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
@ -49,7 +53,11 @@ def check_inference_engine(test_cai=False):
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda() inputs = inputs.cuda()
generation_config = GenerationConfig( generation_config = GenerationConfig(
do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
) )
outputs = model.generate(inputs, generation_config=generation_config) outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)