mirror of https://github.com/hpcaitech/ColossalAI
adapted to pad_context_forward
parent 47e53eaa1c
commit fa4fbdbffb
@@ -1,6 +1,5 @@
 """
-Our config consists of one part:
-1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
+Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """

 import logging
@@ -94,9 +93,12 @@ class InferenceConfig:
             torch.float32,
             torch.float16,
             torch.bfloat16,
-        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
-        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
+        assert self.quant_mode in [
+            "smoothquant",
+            "gptq",
+            None,
+        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
         assert (
             self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
+        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."
@@ -51,6 +51,8 @@ class InferenceEngine:
         self.model_config = model.config
+        self.device = torch.device("cuda")
+
         model = model.eval()

         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
         elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
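Note: the engine accepts either string aliases ("fp32", "fp16", "bf16") or torch dtypes and normalizes them to a torch.dtype. An illustrative sketch of that normalization as a lookup table (not the engine's actual code):

```python
import torch

# Table-driven version of the dtype branching above; sketch only.
_DTYPE_ALIASES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    torch.float32: torch.float32,
    torch.float16: torch.float16,
    torch.bfloat16: torch.bfloat16,
}


def normalize_dtype(dtype) -> torch.dtype:
    try:
        return _DTYPE_ALIASES[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype {dtype!r}") from None


assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype(torch.bfloat16) is torch.bfloat16
```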
@@ -85,12 +87,12 @@ class InferenceEngine:
         Verify the input config
         """
         if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
         if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
             self.tokenizer, PreTrainedTokenizer
         ):
             raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
             )
         assert (
             self.model.__class__.__name__ in _supported_models
@@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)


 class RunningList:
@@ -93,17 +96,23 @@ class RequestHandler:
         # Try to allocate cache blocks for the sequence using a priority of prompt length.
         for lst in reversed(self.waiting_list):
             if lst:
+                remove_list = []
                 for seq in lst:
                     if seq.input_len > self.inference_config.max_input_len:
                         # If the prompt length is longer than max_input_len, abort the sequence.
+                        logger.warning(
+                            f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
+                        )
                         self.abort_sequence(seq.request_id)
                         break
-                    remove_list.append(seq)
                     # Try to allocate cache blocks for the sequence.
                     if self.cache_manager.check_allocation(seq):
                         # If succeed, add the sequence to running list.
+                        remove_list.append(seq)
                         self.running_list.append(seq)
                         self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
-                lst.clear()
+                for seq in remove_list:
+                    lst.remove(seq)
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
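Note: replacing `lst.clear()` with a `remove_list` that is drained after the loop keeps sequences that could not get cache blocks in the waiting list, and it avoids mutating `lst` while iterating over it. A small self-contained sketch of that pattern, with toy data rather than the real scheduler:

```python
# Toy illustration of the remove-after-iteration pattern used above.
waiting = ["seq_a", "seq_b", "seq_c"]


def can_allocate(seq: str) -> bool:  # stand-in for cache_manager.check_allocation
    return seq != "seq_b"            # pretend seq_b does not fit in the KV cache


scheduled, remove_list = [], []
for seq in waiting:                  # never remove from `waiting` inside this loop
    if can_allocate(seq):
        remove_list.append(seq)
        scheduled.append(seq)
for seq in remove_list:              # safe to mutate now that iteration is done
    waiting.remove(seq)

assert scheduled == ["seq_a", "seq_c"]
assert waiting == ["seq_b"]          # unlike lst.clear(), unscheduled work is kept
```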
@@ -130,10 +139,9 @@ class RequestHandler:
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
+        seq, _ = self._find_sequence(request_id)
         if seq.status.is_waiting:
             seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
@@ -112,7 +112,7 @@ class KVCacheManager:

     def get_kv_cache(self):
         """Get k_cache and v_cache"""
-        return self._kv_caches[0], self._kv_caches[1]
+        return self._kv_caches

     def get_max_blocks_per_sequence(self) -> int:
         """Get the maximum number of blocks that can be allocated for a single sequence."""
@@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
 from colossalai.inference.struct import BatchInfo

-from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
+from flash_attn.bert_padding import index_first_axis  # noqa


 def rotate_half(x):
@@ -167,20 +167,8 @@ def llama_attn_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )
     else:
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = pad_decoding_forward(
-            query_states,
-            key_states,
-            value_states,
-            k_cache,
-            v_cache,
-            sequence_lengths,
-            block_tables,
-            attention_mask,
-            self.layer_idx,
-            self.attention_dropout,
-            self.training,
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )

     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
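Note: the call site no longer expands the key/value heads with `repeat_kv` before calling `pad_decoding_forward`; presumably that duplication (or an equivalent) now happens inside the decoding path itself. For reference, a minimal sketch of what a `repeat_kv`-style grouped-query expansion does, assuming `[bsz, num_kv_heads, seq_len, head_dim]` inputs:

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times so num_kv_heads * n_rep matches the query heads
    (sketch of the behaviour of transformers' repeat_kv)."""
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.randn(2, 4, 16, 64)                            # 4 KV heads
assert repeat_kv_sketch(kv, 8).shape == (2, 32, 16, 64)   # shared by 32 query heads
```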
@@ -215,9 +203,6 @@ def pad_decoding_forward(
     lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
     block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
     attn_mask: torch.Tensor = None,
-    layer_id: int = 0,
-    attention_dropout: float = None,
-    training: bool = False,
 ):
     bsz, query_length, num_heads, head_size = query.shape
     seq_len = max(lengths)
@@ -247,9 +232,7 @@ def pad_decoding_forward(
         attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)

     attn_weights += attn_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
     attn_output = torch.matmul(attn_weights, value)

     if attn_output.size() != (bsz, num_heads, 1, head_size):
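Note: dropping the dropout call (and the `layer_id` / `attention_dropout` / `training` parameters above) leaves the plain softmax-attention math, which is all an inference-only path needs since dropout is inactive at eval time anyway. A self-contained sketch of that core computation for a single decoding step, with assumed toy shapes:

```python
import torch
from torch import nn

# Single-step decoding attention without dropout (toy shapes, not the real kernel).
bsz, num_heads, head_size, kv_len = 2, 8, 64, 32
query = torch.randn(bsz, num_heads, 1, head_size)        # one new token per sequence
key = torch.randn(bsz, num_heads, kv_len, head_size)     # cached keys
value = torch.randn(bsz, num_heads, kv_len, head_size)   # cached values
attn_mask = torch.zeros(bsz, 1, 1, kv_len)               # 0 = keep, -inf = masked out

attn_weights = torch.matmul(query, key.transpose(2, 3)) / (head_size**0.5)
attn_weights += attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)

assert attn_output.size() == (bsz, num_heads, 1, head_size)
```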
@@ -277,8 +260,6 @@ def pad_context_forward(
     block_size = k_cache.shape[-1]
     assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
     block_tables.shape[-1] * block_size
-    shape = (bsz, seq_len, num_heads, head_size)
-    input_shape = shape[:2]

     # Copy kv to memory(rotary embedded)
     copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
@@ -42,7 +42,7 @@ def beam_search_sample(

+    # NOTE: this beam search sample function is wrong now.
     """

     beam_width = generation_config.num_beams
     results = []
     if is_prompt:
@@ -268,7 +268,7 @@ class BatchInfo:
         for seq, token in zip(self.sequences_set, tokens):
             if not isinstance(token, list):
                 if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
                 token = [token]
             seq.output_token_id += token
             seq.check_finish()
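Note: the surrounding loop normalizes a bare `int` token into a one-element list before extending each sequence's output, so callers can pass either per-sequence ints or lists. A small standalone sketch of that normalization, with toy data rather than the real BatchInfo:

```python
from typing import List, Union


def normalize_token(token: Union[int, List[int]]) -> List[int]:
    """Accept an int or a list of ints; always return a list (sketch of the pattern above)."""
    if not isinstance(token, list):
        if not isinstance(token, int):
            raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
        token = [token]
    return token


output_token_id: List[int] = [1, 2]
output_token_id += normalize_token(3)        # int is wrapped
output_token_id += normalize_token([4, 5])   # list passes through
assert output_token_id == [1, 2, 3, 4, 5]
```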
@@ -133,7 +133,7 @@ class CaiInferEngine:
         """
         assert isinstance(
             input_list, (BatchEncoding, dict)
-        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+        ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}."
         if isinstance(input_list, BatchEncoding):
             input_list = input_list.data
         out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
@@ -28,20 +28,24 @@ def check_inference_engine(test_cai=False):
         )
     ).cuda()

+    model = model.eval()
+
     inputs = [
         "介绍一下今天的北京,",
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
     ]

-    output_len = 16
+    output_len = 128
     do_sample = True
+    top_p = 0.5
+    top_k = 50

     if test_cai:
         inference_config = InferenceConfig(max_output_len=output_len)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
@@ -49,7 +53,11 @@ def check_inference_engine(test_cai=False):
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
-            do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
         )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)