mirror of https://github.com/hpcaitech/ColossalAI

adapted to pad_context_forward

parent 47e53eaa1c
commit fa4fbdbffb
@@ -1,6 +1,5 @@
 """
-Our config consists of one part:
-1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
+Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """

 import logging
@@ -94,9 +93,12 @@ class InferenceConfig:
             torch.float32,
             torch.float16,
             torch.bfloat16,
-        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
-        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
+        assert self.quant_mode in [
+            "smoothquant",
+            "gptq",
+            None,
+        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
         assert (
             self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
+        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."
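For orientation, a minimal usage sketch (not part of the diff) of the fields validated above; the field names come from the asserts in this hunk, while the import path, defaults, and concrete values are assumptions.

# Hypothetical usage sketch -- field names taken from the asserts above;
# the import path and the chosen values are assumptions, not part of the commit.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    dtype="fp16",        # one of 'fp16', 'fp32', 'bf16' or the torch dtypes listed above
    quant_mode=None,     # one of 'smoothquant', 'gptq', None
    max_input_len=256,
    max_output_len=128,
    max_seq_len=512,     # must satisfy max_input_len + max_output_len <= max_seq_len
)
# With the f-string messages added here, a bad value now reports itself, e.g.
# AssertionError: quant should be one of 'smoothquant', 'gptq', but got awq.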
@@ -51,6 +51,8 @@ class InferenceEngine:
         self.model_config = model.config
         self.device = torch.device("cuda")

+        model = model.eval()
+
         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
         elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
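The added model.eval() call switches off training-only behavior (dropout, batch-norm updates) before generation; a generic PyTorch illustration, unrelated to ColossalAI internals:

import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(4)

layer.train()
print(layer(x))  # elements randomly zeroed, survivors scaled to 2.0

layer.eval()
print(layer(x))  # tensor([1., 1., 1., 1.]) -- dropout is a no-op in eval mode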
@@ -85,12 +87,12 @@ class InferenceEngine:
         Verify the input config
         """
         if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
         if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
             self.tokenizer, PreTrainedTokenizer
         ):
             raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
             )
         assert (
             self.model.__class__.__name__ in _supported_models
@@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)


 class RunningList:
@@ -93,17 +96,23 @@ class RequestHandler:
         # Try to allocate cache blocks for the sequence using a priority of prompt length.
         for lst in reversed(self.waiting_list):
             if lst:
+                remove_list = []
                 for seq in lst:
                     if seq.input_len > self.inference_config.max_input_len:
                         # If the prompt length is longer than max_input_len, abort the sequence.
+                        logger.warning(
+                            f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
+                        )
                         self.abort_sequence(seq.request_id)
-                        break
+                        remove_list.append(seq)
                     # Try to allocate cache blocks for the sequence.
                     if self.cache_manager.check_allocation(seq):
                         # If succeed, add the sequence to running list.
+                        remove_list.append(seq)
                         self.running_list.append(seq)
                         self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
-                lst.clear()
+                for seq in remove_list:
+                    lst.remove(seq)
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
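Collecting handled sequences in remove_list and deleting them after the loop avoids mutating lst while iterating over it, and no longer wipes still-waiting sequences the way the old lst.clear() did. A standalone illustration of the pitfall and the pattern (generic Python, not the RequestHandler code):

# Removing items from a list while iterating over it skips elements:
items = [1, 3, 2, 4]
for x in items:
    if x % 2 == 1:
        items.remove(x)
print(items)  # [3, 2, 4] -- the 3 was skipped because removing 1 shifted it under the iterator

# The pattern used above: collect first, remove afterwards.
items = [1, 3, 2, 4]
to_remove = [x for x in items if x % 2 == 1]
for x in to_remove:
    items.remove(x)
print(items)  # [2, 4]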
@@ -130,10 +139,9 @@ class RequestHandler:
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
+        seq, _ = self._find_sequence(request_id)
         if seq.status.is_waiting:
             seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
@@ -112,7 +112,7 @@ class KVCacheManager:

     def get_kv_cache(self):
         """Get k_cache and v_cache"""
-        return self._kv_caches[0], self._kv_caches[1]
+        return self._kv_caches

     def get_max_blocks_per_sequence(self) -> int:
         """Get the maximum number of blocks that can be allocated for a single sequence."""
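Returning self._kv_caches directly is equivalent for callers that unpack the result, provided the attribute is a two-element (k, v) container, which the old return statement implies; a small illustration of why the call sites keep working:

# Both return styles unpack identically at the call site
# (stand-in strings instead of the real cache tensors).
kv_caches = ["k_cache", "v_cache"]

def old_style():
    return kv_caches[0], kv_caches[1]  # builds a fresh tuple

def new_style():
    return kv_caches                   # hands back the existing container

k1, v1 = old_style()
k2, v2 = new_style()
assert (k1, v1) == (k2, v2)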
@@ -16,7 +16,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
 from colossalai.inference.struct import BatchInfo

-from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
+from flash_attn.bert_padding import index_first_axis  # noqa


 def rotate_half(x):
@@ -167,20 +167,8 @@ def llama_attn_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )
     else:
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = pad_decoding_forward(
-            query_states,
-            key_states,
-            value_states,
-            k_cache,
-            v_cache,
-            sequence_lengths,
-            block_tables,
-            attention_mask,
-            self.layer_idx,
-            self.attention_dropout,
-            self.training,
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
         )

     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
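Dropping the repeat_kv calls from the caller suggests the grouped-query head expansion is now handled inside pad_decoding_forward itself; that is an inference from this hunk, not something the diff shows. For reference, repeat_kv in transformers' LLaMA modeling expands the cached KV heads to match the query heads roughly like this (reproduced from memory):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)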
@@ -215,9 +203,6 @@ def pad_decoding_forward(
     lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
     block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
     attn_mask: torch.Tensor = None,
-    layer_id: int = 0,
-    attention_dropout: float = None,
-    training: bool = False,
 ):
     bsz, query_length, num_heads, head_size = query.shape
     seq_len = max(lengths)
@@ -247,9 +232,7 @@ def pad_decoding_forward(
         attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)

     attn_weights += attn_mask
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
     attn_output = torch.matmul(attn_weights, value)

     if attn_output.size() != (bsz, num_heads, 1, head_size):
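Removing the nn.functional.dropout call (together with the attention_dropout/training parameters dropped in the previous hunk) matches the fact that this path only runs at inference time, where attention dropout must be a no-op. For orientation, a worked single-token decode step with masking and KV-cache plumbing stripped away (not the library's code):

import math
import torch

# One decoding step: the new token's query attends over all cached keys/values.
num_heads, head_size, kv_len = 8, 64, 32
q = torch.randn(num_heads, 1, head_size)        # query of the newly generated token
k = torch.randn(num_heads, kv_len, head_size)   # keys read back from the KV cache
v = torch.randn(num_heads, kv_len, head_size)

attn_weights = q @ k.transpose(-1, -2) / math.sqrt(head_size)  # (num_heads, 1, kv_len)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = attn_weights @ v                                 # (num_heads, 1, head_size)
assert attn_output.shape == (num_heads, 1, head_size)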
@@ -277,8 +260,6 @@ def pad_context_forward(
     block_size = k_cache.shape[-1]
     assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
     block_tables.shape[-1] * block_size
-    shape = (bsz, seq_len, num_heads, head_size)
-    input_shape = shape[:2]

     # Copy kv to memory(rotary embedded)
     copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
@@ -268,7 +268,7 @@ class BatchInfo:
         for seq, token in zip(self.sequences_set, tokens):
             if not isinstance(token, list):
                 if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but get {type(token)}.")
+                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
                 token = [token]
             seq.output_token_id += token
             seq.check_finish()
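The touched branch normalizes a bare int into a one-element list before extending the sequence's outputs; a tiny standalone version of the same check:

def normalize_token(token):
    # Mirror of the type handling above, outside of BatchInfo.
    if not isinstance(token, list):
        if not isinstance(token, int):
            raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
        token = [token]
    return token

assert normalize_token(7) == [7]
assert normalize_token([7, 8]) == [7, 8]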
@@ -133,7 +133,7 @@ class CaiInferEngine:
         """
         assert isinstance(
             input_list, (BatchEncoding, dict)
-        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+        ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}."
         if isinstance(input_list, BatchEncoding):
             input_list = input_list.data
         out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
@@ -28,20 +28,24 @@ def check_inference_engine(test_cai=False):
         )
     ).cuda()

+    model = model.eval()
+
     inputs = [
-        "介绍一下今天的北京,",
+        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
     ]

-    output_len = 16
+    output_len = 128
     do_sample = True
+    top_p = 0.5
+    top_k = 50

     if test_cai:
         inference_config = InferenceConfig(max_output_len=output_len)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
@@ -49,7 +53,11 @@ def check_inference_engine(test_cai=False):
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
-            do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
         )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
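For context, the non-CAI branch of this test is plain Hugging Face generation; a self-contained equivalent with the same sampling knobs (the tiny checkpoint name is only a placeholder, and the prompts are English stand-ins for the Chinese ones above):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(["Introduce Beijing,", "Introduce Wuhan,"], padding=True, return_tensors="pt")

generation_config = GenerationConfig(
    do_sample=True,
    top_p=0.5,
    top_k=50,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=128,
)
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))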