Fixed a bug in the inference framework

pull/5258/head
yuehuayingxueluo 2023-12-26 21:34:27 +08:00 committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

View File

@@ -97,3 +97,6 @@ class InferenceConfig:
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
assert (
self.max_input_len + self.max_output_len <= self.max_seq_len
), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."

View File

@@ -49,6 +49,7 @@ class InferenceEngine:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
@@ -76,6 +77,7 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
self.counter = count()
def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
assert (
len(prompts_token_ids[0]) < self.inference_config.max_input_len
), "The length of input prompts must be less than max_input_len."
prompts_num = len(prompts_token_ids)
@@ -183,13 +189,14 @@ class InferenceEngine:
prompt = None
else:
prompt = prompts[i]
block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
-None,
+block_table,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
@@ -211,14 +218,15 @@ class InferenceEngine:
self.logger.info("Running generation step")
output_list = []
-batch, k_cache, v_cache = self.request_handler.schedule()
+batch = self.request_handler.schedule()
logits = self.model(
batch,
-k_cache,
-v_cache,
+self.k_cache,
+self.v_cache,
)
-self.request_handler.search_tokens(logits, self.generation_config)
+self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
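
With this change the engine pulls the KV cache once at construction and reuses it every step, and search_tokens now takes the generation config first. A schematic of the resulting per-step flow (assumed simplified names, not the actual engine class):

class ToyEngine:
    def __init__(self, model, request_handler, generation_config):
        self.model = model
        self.request_handler = request_handler
        self.generation_config = generation_config
        # Cache tensors are fetched once here instead of being returned by schedule().
        self.k_cache, self.v_cache = request_handler.get_kvcache()

    def step(self):
        batch = self.request_handler.schedule()                 # batch only, no caches
        logits = self.model(batch, self.k_cache, self.v_cache)
        self.request_handler.search_tokens(self.generation_config, logits)
        return self.request_handler.update()                    # finished sequences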

View File

@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
@@ -49,7 +48,7 @@ class RunningList:
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
-return len(self.prefill) / len(self.decoding) >= self.ratio
+return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
@@ -72,8 +71,9 @@ class RequestHandler:
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
-self.running_batch = BatchInfo(is_prompts=False)
-self.prefill_batch = BatchInfo(is_prompts=True)
+device = torch.cuda.current_device()
+self.running_batch = BatchInfo(is_prompts=False, device=device)
+self.prefill_batch = BatchInfo(is_prompts=True, device=device)
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@ class RequestHandler:
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
def schedule(self):
"""
The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
for lst in reversed(self.waiting_list):
if lst:
for seq in lst:
-if seq.prompt_len > self.inference_config.max_input_len:
+if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence.
self.abort_sequence(seq.request_id)
break
@@ -98,9 +101,8 @@ class RequestHandler:
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
self.running_list.append(seq)
-self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
-lst.remove(seq)
+self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
+lst.clear()
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@@ -115,10 +117,9 @@ class RequestHandler:
"""
assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
assert (
-req.prompt_len < self.inference_config.max_input_len
+req.input_len < self.inference_config.max_input_len
), f"Sequence {req.request_id} exceeds input length limit"
-self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
+self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
def abort_sequence(self, request_id: str):
"""
@@ -178,9 +179,12 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
-for type in ["top_p", "top_k", "min_p"]:
-if type in generation_config:
-logits = logit_processor(type, logits)
+# for type in ["top_p", "top_k", "min_p"]:
+# config_dict = generation_config.to_dict()
+# if type in config_dict:
+# logits = logit_processor(type, logits, config_dict[type])
+torch.cuda.synchronize()
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -188,6 +192,9 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
def update(self):
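
add_sequence above buckets requests into the three waiting lists by relative input length. The index formula can be checked on its own (assumed max_input_len, not the real handler):

max_input_len = 512

def bucket(input_len: int) -> int:
    # Same expression as in add_sequence: three buckets by relative prompt length.
    return input_len * 3 // max_input_len

assert bucket(10) == 0    # short prompt
assert bucket(200) == 1   # medium prompt
assert bucket(500) == 2   # prompt close to the limit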

View File

@@ -112,7 +112,7 @@ class KVCacheManager:
def get_kv_cache(self):
"""Get k_cache and v_cache"""
-return self._kv_cache[0], self._kv_cache[1]
+return self._kv_caches[0], self._kv_caches[1]
def get_max_blocks_per_sequence(self) -> int:
"""Get the maximum number of blocks that can be allocated for a single sequence."""
@@ -122,7 +122,7 @@ class KVCacheManager:
return self.max_blocks_per_sequence
def check_allocation(self, seq: Sequence) -> bool:
-num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
+num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size
return num_blocks_needed <= self.num_available_blocks
def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
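
check_allocation now sizes a request by input_len plus the reserved output budget, rounded up to whole blocks. A quick numeric check with assumed values:

input_len = 100          # prompt tokens
max_output_length = 28   # output budget reserved up front
block_size = 16

num_blocks_needed = (input_len + max_output_length + block_size - 1) // block_size
assert num_blocks_needed == 8   # ceil(128 / 16)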

View File

@@ -70,7 +70,10 @@ def llama_model_forward(
seq_length = input_ids.shape[1]
device = input_ids.device
-past_key_values_length = len(block_tables.shape[1])
+if batch.is_prompts:
+past_key_values_length = 0
+else:
+past_key_values_length = sequence_lengths[0].item() - 1
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
@@ -163,26 +166,17 @@ def llama_attn_forward(
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
-block_size = k_cache.shape[-1]
-memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
-if is_prompts:
-attn_output = context_attention_unpadded(
-query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
-)
-else:
-attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
-decoding_attention(
-query_states,
-k_cache,
-v_cache,
-block_tables,
-sequence_lengths,
-attn_output,
-block_tables.shape[1],
-block_size,
-)
+k_cache.shape[-1]
+# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
+# if is_prompts:
+# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+# else:
+# attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
+# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
+attn_output = query_states
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -190,19 +184,3 @@ def llama_attn_forward(
attn_output = self.o_proj(attn_output)
return attn_output
-def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
-block_table_list = block_tables.tolist()
-batch_size, seq_len, num_heads, head_dim = key
-reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
-if seq_len == 1:
-for i in range(batch_size):
-k_cache[block_table_list[i][-1], :] = reshape_key[i]
-v_cache[block_table_list[i][-1], :] = reshape_value[i]
-else:
-for i in range(batch_size):
-k_cache[block_table_list[i], :] = reshape_key[i]
-v_cache[block_table_list[i], :] = reshape_value[i]
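
The attention kernels are stubbed out here (attn_output = query_states is a placeholder), but the new position-id logic is live: prefill starts at 0, a decode step starts at sequence_length - 1. A standalone check of that logic (toy shapes, not the model forward):

import torch

def make_position_ids(is_prompts: bool, seq_length: int, sequence_lengths: torch.Tensor) -> torch.Tensor:
    # Prefill: positions 0 .. seq_length - 1; decode: one position at sequence_length - 1.
    past = 0 if is_prompts else sequence_lengths[0].item() - 1
    return torch.arange(past, seq_length + past, dtype=torch.long)

assert make_position_ids(True, 4, torch.tensor([4])).tolist() == [0, 1, 2, 3]   # prefill
assert make_position_ids(False, 1, torch.tensor([5])).tolist() == [4]           # decode step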

View File

@@ -1,7 +1,165 @@
from functools import partial
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaSdpaAttention,
)
from colossalai.inference.modeling.models.llama import (
llama_attn_forward,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
# The code here is just for testing and will be modified later.
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaForCausalLM
)
infer_forward = llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
return policy
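
Each append_or_create_method_replacement call registers {"forward": partial(infer_forward)} so the shardformer rebinds the module's forward to the inference variant. Stripped of that machinery, the mechanism is plain method rebinding; a generic illustration with a hypothetical module (not the real policy code):

import types
from functools import partial

class TinyModule:
    def forward(self, x):
        return x + 1

def infer_forward(self, x):
    # Replacement forward used at inference time (illustrative only).
    return x * 2

m = TinyModule()
replacement = {"forward": partial(infer_forward)}             # what the policy registers
m.forward = types.MethodType(replacement["forward"].func, m)  # what applying it amounts to
assert m.forward(3) == 6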

View File

@@ -1,6 +1,6 @@
import enum
from dataclasses import dataclass
-from typing import Any, List, Union
+from typing import Any, List, Tuple, Union
import torch
from ordered_set import OrderedSet
@@ -74,13 +74,6 @@ class Sequence:
self.output_token_id = []
self.status = RequestStatus.WAITING
-@property
-def prompt_len(self) -> int:
-"""
-Get length of prompts
-"""
-return len(self.input_token_id)
@property
def sentence_len(self) -> int:
"""
@@ -113,7 +106,7 @@ class Sequence:
return True
if self.output_token_id:
-if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
+if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
self.status = RequestStatus.COMPLETED
return True
@@ -143,11 +136,13 @@ class Sequence:
def __repr__(self) -> str:
return (
-f"Request ID(request_id={self.request_id}, "
+f"(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
-f"logical block number={len(self.block_table_index)}"
+f"logical_block_number={self.block_table.shape[0]},"
f"input_len={self.input_len}),"
f"output_len={self.output_len})"
)
@@ -159,9 +154,15 @@ class BatchInfo:
sequences_set: OrderedSet["Sequence"] = None
is_prompts: bool = True
+device: torch.device = None
-@classmethod
-def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
+def __post_init__(self):
+if self.device is None:
+self.device = torch.cuda.current_device()
+if self.sequences_set is None:
+self.sequences_set = OrderedSet()
+def init_batch(self, seqs: List["Sequence"] = None):
"""
Initializes inference batches by input sentence list.
@@ -169,29 +170,29 @@ class BatchInfo:
seqs (List["Sequence"]): List of input sequence.
"""
-sequences_set = OrderedSet()
+assert len(self.sequences_set) == 0, "Sequences set has been initialized."
if seqs is not None:
if not isinstance(seqs, list):
seqs = [seqs]
for seq in seqs:
-if seq in sequences_set:
+if seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
-sequences_set.add(seq)
-return cls(sequences_set=sequences_set)
+self.sequences_set.add(seq)
def get_block_table_tensor(self) -> None:
tesnor_list = []
block_table = None
for seq in self.sequences_set:
block_table = seq.block_table
-assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
+assert (
+block_table is not None
+), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
tesnor_list.append(seq.block_table)
assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
-block_table = torch.concat(tesnor_list)
+block_table = torch.stack(tesnor_list)
return block_table
def clear_batch(self) -> None:
@@ -239,7 +240,7 @@ class BatchInfo:
seqs = [seqs]
for seq in seqs:
-if seq in self.sequences_set:
+if self.sequences_set and seq in self.sequences_set:
logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
continue
self.sequences_set.add(seq)
@@ -251,7 +252,7 @@ class BatchInfo:
"""
return not self.sequences_set
-def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
+def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
"""
Add an output token for each sentence in the batch.
@@ -259,6 +260,9 @@ class BatchInfo:
tokens (List[int]): A batch of tokens
"""
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
for seq, token in zip(self.sequences_set, tokens):
@@ -287,19 +291,25 @@ class BatchInfo:
else:
input_list.append([seq.output_token_id[-1]])
-return torch.tensor(input_list, dtype=torch.long)
+return torch.tensor(input_list, dtype=torch.long, device=self.device)
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
Flattening the input tokens.
"""
input_list = []
+input_len_list = []
for seq in self.sequences_set:
if self.is_prompts:
input_list.extend(seq.input_token_id)
+input_len_list.append(seq.sentence_len)
else:
input_list.append(seq.output_token_id[-1])
+input_len_list.append(1)
-return torch.tensor(input_list, dtype=torch.long)
+return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
+input_len_list, dtype=torch.int, device=self.device
+)
def get_sequence_lengths(self):
"""
@@ -307,5 +317,9 @@ class BatchInfo:
"""
len_list = []
for seq in self.sequences_set:
-len_list.append(seq.get_sentence_len())
+len_list.append(seq.sentence_len)
-return torch.tensor(len_list, dtype=torch.int)
+return torch.tensor(len_list, dtype=torch.int, device=self.device)
def __repr__(self) -> str:
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"

View File

@@ -1,6 +1,6 @@
import pytest
import transformers
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import InferenceConfig
@@ -11,21 +11,24 @@ from colossalai.testing import spawn
def check_inference_engine():
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
-vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-inference_config = InferenceConfig()
+inference_config = InferenceConfig(max_output_len=5)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inputs = [
-"介绍一下北京",
+"介绍一下今天的北京",
"介绍一下武汉",
]
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
-# outputs = inference_engine.generate(None)
+generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
outputs = inference_engine.generate(generation_config)
print("outputs: ", outputs)
# Engine still gets some bug