mirror of https://github.com/hpcaitech/ColossalAI
Fixed a bug in the inference framework
parent 86853a37d5
commit 62fd08ee44
@@ -97,3 +97,6 @@ class InferenceConfig:
 ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
 assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
 assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+assert (
+    self.max_input_len + self.max_output_len <= self.max_seq_len
+), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
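
Note on the new config check: it ties the per-request limits to the overall sequence budget. A minimal standalone sketch of the same validation, with made-up values for the three fields:

    # Values are illustrative only; the real ones come from InferenceConfig.
    max_input_len, max_output_len, max_seq_len = 256, 128, 512
    assert (
        max_input_len + max_output_len <= max_seq_len
    ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
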
@@ -49,6 +49,7 @@ class InferenceEngine:
 self.tokenizer.pad_token = self.tokenizer.eos_token
 self.inference_config = inference_config
 self.model_config = model.config
+self.device = torch.device("cuda")
 
 if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
     self.dtype = torch.float32
@@ -76,6 +77,7 @@ class InferenceEngine:
 self.logger = get_dist_logger(__name__)
 
 self.request_handler = RequestHandler(self.inference_config, self.model_config)
+self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
 self.counter = count()
 
 def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:
 
 if prompts_token_ids is None:
     assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-    prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+    prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
 
+assert (
+    len(prompts_token_ids[0]) < self.inference_config.max_input_len
+), "The length of input prompts must be less than max_input_len."
+
 prompts_num = len(prompts_token_ids)
 
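
Note on the tokenization change: padding=True makes batch_encode_plus return rows of equal length, which the new length assertion (it only inspects prompts_token_ids[0]) implicitly relies on. A small sketch, assuming the same tokenizer the test file uses:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers ship without a pad token

    prompts = ["hello", "a noticeably longer prompt than the first one"]
    input_ids = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
    # Every row is padded to the longest prompt in the batch.
    assert len(input_ids[0]) == len(input_ids[1])
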
@@ -183,13 +189,14 @@ class InferenceEngine:
     prompt = None
 else:
     prompt = prompts[i]
+block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
 sequence = Sequence(
     request_id,
     prompt,
     prompts_token_ids[i],
     block_size,
     None,
-    None,
+    block_table,
     self.tokenizer.eos_token_id,
     self.inference_config.max_output_len,
 )
@@ -211,14 +218,15 @@ class InferenceEngine:
 self.logger.info("Running generation step")
 
 output_list = []
-batch, k_cache, v_cache = self.request_handler.schedule()
+batch = self.request_handler.schedule()
 
 logits = self.model(
     batch,
-    k_cache,
-    v_cache,
+    self.k_cahce,
+    self.v_cache,
 )
-self.request_handler.search_tokens(logits, self.generation_config)
+
+self.request_handler.search_tokens(self.generation_config, logits)
 
 finished_sequences = self.request_handler.update()
 
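
Note on the step flow: the KV cache tensors are now fetched once from the request handler when the engine is built (the k_cahce spelling is as it appears in the diff) instead of being returned by schedule() on every step. A rough sketch of the resulting loop, using stand-in objects rather than the real classes:

    class ToyEngine:
        def __init__(self, model, request_handler, generation_config):
            self.model = model
            self.request_handler = request_handler
            self.generation_config = generation_config
            # Fetched once at construction time, reused on every step.
            self.k_cahce, self.v_cache = request_handler.get_kvcache()

        def step(self):
            batch = self.request_handler.schedule()  # no longer returns the caches
            logits = self.model(batch, self.k_cahce, self.v_cache)
            self.request_handler.search_tokens(self.generation_config, logits)
            return self.request_handler.update()
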
@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
 
@@ -49,7 +48,7 @@ class RunningList:
 def ready_for_prefill(self):
     if not self.decoding:
         return len(self.prefill) > 0
-    return len(self.prefill) / len(self.decoding) >= self.ratio
+    return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
 
 def is_empty(self):
     return not self.decoding and not self.prefill
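
Note on the RunningList fix: the comparison now reads prefill_ratio, which matches the attribute name used elsewhere in this diff, instead of a bare self.ratio. A toy illustration of the scheduling condition with hypothetical numbers:

    prefill, decoding, prefill_ratio = ["seq3"], ["seq1", "seq2"], 1.2

    def ready_for_prefill():
        if not decoding:
            return len(prefill) > 0
        return len(prefill) / len(decoding) >= prefill_ratio

    print(ready_for_prefill())  # False: 1 / 2 = 0.5 is below the 1.2 threshold
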
@@ -72,8 +71,9 @@ class RequestHandler:
 self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
 self.waiting_list: List[List] = [[], [], []]
 self.done_list: List[Sequence] = []
-self.running_batch = BatchInfo(is_prompts=False)
-self.prefill_batch = BatchInfo(is_prompts=True)
+device = torch.cuda.current_device()
+self.running_batch = BatchInfo(is_prompts=False, device=device)
+self.prefill_batch = BatchInfo(is_prompts=True, device=device)
 
 def _init_cache(self, model_config):
     self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@
 def _has_waiting(self) -> bool:
     return any(lst for lst in self.waiting_list)
 
+def get_kvcache(self):
+    return self.cache_manager.get_kv_cache()
+
 def schedule(self):
     """
     The main logic of request handler.
@@ -90,7 +93,7 @@
 for lst in reversed(self.waiting_list):
     if lst:
         for seq in lst:
-            if seq.prompt_len > self.inference_config.max_input_len:
+            if seq.input_len > self.inference_config.max_input_len:
                 # If the prompt length is longer than max_input_len, abort the sequence.
                 self.abort_sequence(seq.request_id)
                 break
@@ -98,9 +101,8 @@
 if self.cache_manager.check_allocation(seq):
     # If succeed, add the sequence to running list.
     self.running_list.append(seq)
-    self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
-    lst.remove(seq)
-
+    self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
+    lst.clear()
 if self.running_list.ready_for_prefill():
     for seq in self.running_list.prefill:
         seq.mark_running()
@@ -115,10 +117,9 @@
 """
 assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
 assert (
-    req.prompt_len < self.inference_config.max_input_len
+    req.input_len < self.inference_config.max_input_len
 ), f"Sequence {req.request_id} exceeds input length limit"
-self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
-
+self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
 
 def abort_sequence(self, request_id: str):
     """
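
Note on the waiting-list bucketing: new requests are spread over the three waiting sub-lists by relative prompt length, using the index input_len * 3 // max_input_len. A quick check of the mapping, with an assumed max_input_len:

    max_input_len = 256
    for input_len in (10, 100, 200, 255):
        bucket = input_len * 3 // max_input_len
        print(input_len, "->", bucket)  # 10 -> 0, 100 -> 1, 200 -> 2, 255 -> 2
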
@@ -178,9 +179,12 @@
 """
 # do logit processor
 # NOTE: need to decide the granularity to process logits (sequence or batch)
-for type in ["top_p", "top_k", "min_p"]:
-    if type in generation_config:
-        logits = logit_processor(type, logits)
+# for type in ["top_p", "top_k", "min_p"]:
+#     config_dict = generation_config.to_dict()
+#     if type in config_dict:
+#         logits = logit_processor(type, logits, config_dict[type])
 
+torch.cuda.synchronize()
+
 # calculate probs
 probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -188,6 +192,9 @@
 
 # sample the next tokens
 sample_tokens = self._sample(probs, logprobs, generation_config)
+if not self.prefill_batch.is_empty:
+    self.prefill_batch.update_batch_tokens(sample_tokens)
+else:
     self.running_batch.update_batch_tokens(sample_tokens)
 
 def update(self):
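
Note on token routing: sampled tokens are now written back to whichever batch produced the logits, the prefill batch when it holds sequences and the decoding (running) batch otherwise. A stripped-down sketch that mirrors the new control flow; the batch objects are stand-ins:

    def route_sampled_tokens(prefill_batch, running_batch, sample_tokens):
        # Prefill results take priority whenever the prefill batch is non-empty.
        if not prefill_batch.is_empty:
            prefill_batch.update_batch_tokens(sample_tokens)
        else:
            running_batch.update_batch_tokens(sample_tokens)
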
@@ -112,7 +112,7 @@ class KVCacheManager:
 
 def get_kv_cache(self):
     """Get k_cache and v_cache"""
-    return self._kv_cache[0], self._kv_cache[1]
+    return self._kv_caches[0], self._kv_caches[1]
 
 def get_max_blocks_per_sequence(self) -> int:
     """Get the maximum number of blocks that can be allocated for a single sequence."""
@@ -122,7 +122,7 @@
 return self.max_blocks_per_sequence
 
 def check_allocation(self, seq: Sequence) -> bool:
-    num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
+    num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size
     return num_blocks_needed <= self.num_available_blocks
 
 def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
@@ -70,7 +70,10 @@ def llama_model_forward(
 seq_length = input_ids.shape[1]
 device = input_ids.device
 
-past_key_values_length = len(block_tables.shape[1])
+if batch.is_prompts:
+    past_key_values_length = 0
+else:
+    past_key_values_length = sequence_lengths[0].item() - 1
 
 position_ids = torch.arange(
     past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
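
Note on position ids: during prefill the positions start at 0, while during decoding the single new token is placed right after the tokens already in the sequence (sequence_lengths[0] - 1 in the diff). A small numeric illustration with assumed lengths:

    import torch

    # Prefill: an 8-token prompt, nothing cached yet.
    seq_length, past_key_values_length = 8, 0
    print(torch.arange(past_key_values_length, seq_length + past_key_values_length))  # tensor([0, 1, ..., 7])

    # Decode: one new token, eight tokens known so far (sequence_lengths[0] == 8).
    seq_length, past_key_values_length = 1, 8 - 1
    print(torch.arange(past_key_values_length, seq_length + past_key_values_length))  # tensor([7])
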
@@ -163,26 +166,17 @@ def llama_attn_forward(
 key_states = key_states.view(-1, self.num_heads, self.head_dim)
 value_states = value_states.view(-1, self.num_heads, self.head_dim)
 
-block_size = k_cache.shape[-1]
+k_cache.shape[-1]
 
-memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
+# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
 
-if is_prompts:
-    attn_output = context_attention_unpadded(
-        query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
-    )
-else:
-    attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
-    decoding_attention(
-        query_states,
-        k_cache,
-        v_cache,
-        block_tables,
-        sequence_lengths,
-        attn_output,
-        block_tables.shape[1],
-        block_size,
-    )
+# if is_prompts:
+#     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+# else:
+#     attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
+#     decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
+
+attn_output = query_states
 
 attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -190,19 +184,3 @@ def llama_attn_forward(
 attn_output = self.o_proj(attn_output)
 
 return attn_output
-
-
-def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
-    block_table_list = block_tables.tolist()
-    batch_size, seq_len, num_heads, head_dim = key
-
-    reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-    if seq_len == 1:
-        for i in range(batch_size):
-            k_cache[block_table_list[i][-1], :] = reshape_key[i]
-            v_cache[block_table_list[i][-1], :] = reshape_value[i]
-    else:
-        for i in range(batch_size):
-            k_cache[block_table_list[i], :] = reshape_key[i]
-            v_cache[block_table_list[i], :] = reshape_value[i]
@@ -1,7 +1,165 @@
+from functools import partial
+
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaFlashAttention2,
+    LlamaForCausalLM,
+    LlamaModel,
+    LlamaSdpaAttention,
+)
+
+from colossalai.inference.modeling.models.llama import (
+    llama_attn_forward,
+    llama_causal_lm_forward,
+    llama_decoder_layer_forward,
+    llama_model_forward,
+)
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+
+# import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
-    # The code here just for test and will be modified later.
     def __init__(self) -> None:
         super().__init__()
+
+    def module_policy(self):
+        policy = super().module_policy()
+        decoder_attribute_replacement = {
+            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            // self.shard_config.tensor_parallel_size,
+        }
+        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
+            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
+
+        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
+            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
+            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
+                ColW8A8BFP32OFP32Linear,
+                RowW8A8B8O8Linear,
+                RowW8A8BFP32O32LinearSiLU,
+                RowW8A8BFP32OFP32Linear,
+            )
+
+            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=RowW8A8B8O8Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=RowW8A8BFP32O32LinearSiLU,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=RowW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=ColW8A8BFP32OFP32Linear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
+        self.shard_config._infer()
+
+        infer_forward = llama_causal_lm_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
+        )
+
+        infer_forward = llama_model_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+        infer_forward = llama_decoder_layer_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+        )
+
+        infer_forward = llama_attn_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=LlamaAttention
+        )
+
+        infer_forward = llama_attn_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
+        )
+
+        infer_forward = llama_attn_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
+        )
+
+        return policy
@@ -1,6 +1,6 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Union
+from typing import Any, List, Tuple, Union
 
 import torch
 from ordered_set import OrderedSet
@@ -74,13 +74,6 @@ class Sequence:
     self.output_token_id = []
     self.status = RequestStatus.WAITING
 
-@property
-def prompt_len(self) -> int:
-    """
-    Get length of prompts
-    """
-    return len(self.input_token_id)
-
 @property
 def sentence_len(self) -> int:
     """
@@ -113,7 +106,7 @@
 return True
 
 if self.output_token_id:
-    if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+    if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
         self.status = RequestStatus.COMPLETED
         return True
 
@@ -143,11 +136,13 @@
 
 def __repr__(self) -> str:
     return (
-        f"Request ID(request_id={self.request_id}, "
+        f"(request_id={self.request_id}, "
         f"prompt={self.prompt}, "
         f"status={self.status.name}, "
         f"sample_params={self.sample_params}, "
-        f"logical block number={len(self.block_table_index)}"
+        f"logical_block_number={self.block_table.shape[0]},"
+        f"input_len={self.input_len}),"
+        f"output_len={self.output_len})"
     )
 
 
@@ -159,9 +154,15 @@ class BatchInfo:
 
 sequences_set: OrderedSet["Sequence"] = None
 is_prompts: bool = True
+device: torch.device = None
 
-@classmethod
-def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
+def __post_init__(self):
+    if self.device is None:
+        self.device = torch.cuda.current_device()
+    if self.sequences_set is None:
+        self.sequences_set = OrderedSet()
+
+def init_batch(self, seqs: List["Sequence"] = None):
     """
     Initializes inference batches by input sentence list.
 
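
Note on the BatchInfo change: the defaults move into a dataclass __post_init__, so each batch picks up a device and an empty OrderedSet at construction time instead of going through a classmethod. A generic sketch of the pattern; the toy class below uses a plain string device and a built-in set so it runs without CUDA:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyBatch:
        sequences_set: Optional[set] = None
        is_prompts: bool = True
        device: Optional[str] = None

        def __post_init__(self):
            # Fill in environment-dependent and mutable defaults lazily.
            if self.device is None:
                self.device = "cuda:0"  # the real code queries torch.cuda.current_device()
            if self.sequences_set is None:
                self.sequences_set = set()

    print(ToyBatch())  # ToyBatch(sequences_set=set(), is_prompts=True, device='cuda:0')
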
@@ -169,29 +170,29 @@
     seqs (List["Sequence"]): List of input sequence.
 """
 
-sequences_set = OrderedSet()
+assert len(self.sequences_set) == 0, "Sequences set has been initialized."
 
 if seqs is not None:
     if not isinstance(seqs, list):
         seqs = [seqs]
     for seq in seqs:
-        if seq in sequences_set:
+        if seq in self.sequences_set:
             logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
             continue
 
-        sequences_set.add(seq)
-
-return cls(sequences_set=sequences_set)
+        self.sequences_set.add(seq)
 
 def get_block_table_tensor(self) -> None:
     tesnor_list = []
     block_table = None
     for seq in self.sequences_set:
         block_table = seq.block_table
-        assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+        assert (
+            block_table is not None
+        ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
         tesnor_list.append(seq.block_table)
     assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
-    block_table = torch.concat(tesnor_list)
+    block_table = torch.stack(tesnor_list)
     return block_table
 
 def clear_batch(self) -> None:
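
Note on the block-table batching: with one fixed-length block table per sequence, torch.stack builds the expected [batch_size, max_blocks] tensor, whereas torch.concat would flatten everything into a single row. A small comparison with assumed shapes:

    import torch

    block_tables = [torch.full((4,), -1), torch.full((4,), -1)]  # two sequences, four block slots each
    print(torch.stack(block_tables).shape)   # torch.Size([2, 4]) - one row per sequence
    print(torch.concat(block_tables).shape)  # torch.Size([8])   - batch dimension lost
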
@@ -239,7 +240,7 @@
     seqs = [seqs]
 
 for seq in seqs:
-    if seq in self.sequences_set:
+    if self.sequences_set and seq in self.sequences_set:
         logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
         continue
     self.sequences_set.add(seq)
@@ -251,7 +252,7 @@
 """
 return not self.sequences_set
 
-def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
     """
     Add an output token for each sentence in the batch.
 
@@ -259,6 +260,9 @@
     tokens (List[int]): A batch of tokens
 """
 
+if isinstance(tokens, torch.Tensor):
+    tokens = tokens.tolist()
+
 assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
 
 for seq, token in zip(self.sequences_set, tokens):
@@ -287,19 +291,25 @@
 else:
     input_list.append([seq.output_token_id[-1]])
 
-return torch.tensor(input_list, dtype=torch.long)
+return torch.tensor(input_list, dtype=torch.long, device=self.device)
 
 def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
     """
     Flattening the input tokens.
     """
     input_list = []
+    input_len_list = []
     for seq in self.sequences_set:
         if self.is_prompts:
             input_list.extend(seq.input_token_id)
+            input_len_list.append(seq.sentence_len)
         else:
             input_list.append(seq.output_token_id[-1])
-    return torch.tensor(input_list, dtype=torch.long)
+            input_len_list.append(1)
+
+    return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
+        input_len_list, dtype=torch.int, device=device
+    )
 
 def get_sequence_lengths(self):
     """
@@ -307,5 +317,9 @@
 """
 len_list = []
 for seq in self.sequences_set:
-    len_list.append(seq.get_sentence_len())
-return torch.tensor(len_list, dtype=torch.int)
+    len_list.append(seq.sentence_len)
+
+return torch.tensor(len_list, dtype=torch.int, device=self.device)
+
+def __repr__(self) -> str:
+    return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
@@ -1,6 +1,6 @@
 import pytest
 import transformers
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GenerationConfig
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -11,21 +11,24 @@ from colossalai.testing import spawn
 def check_inference_engine():
     model = transformers.LlamaForCausalLM(
         transformers.LlamaConfig(
-            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
         )
     )
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    inference_config = InferenceConfig()
+    inference_config = InferenceConfig(max_output_len=5)
     inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
 
     inputs = [
-        "介绍一下北京",
+        "介绍一下今天的北京",
         "介绍一下武汉",
     ]
 
     inference_engine.add_request(prompts=inputs)
     assert inference_engine.request_handler._has_waiting()
-    # outputs = inference_engine.generate(None)
+    generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
+    outputs = inference_engine.generate(generation_config)
+
+    print("outputs: ", outputs)
+
     # Engine still gets some bug
 
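
Note on the test change: generation is now driven through a transformers GenerationConfig instead of passing None, so the sampling settings travel with a single object. A minimal usage sketch mirroring the test:

    from transformers import GenerationConfig

    generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
    print(generation_config.top_k, generation_config.top_p, generation_config.do_sample)  # 2 0.8 True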