mirror of https://github.com/hpcaitech/ColossalAI
Fixed a bug in the inference framework
parent 86853a37d5
commit 62fd08ee44
@@ -97,3 +97,6 @@ class InferenceConfig:
        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
        assert (
            self.max_input_len + self.max_output_len <= self.max_seq_len
        ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
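The hunk above adds a sequence-length budget check to InferenceConfig. A minimal sketch of how the new assertion would surface to a caller, assuming max_input_len, max_output_len and max_seq_len are ordinary constructor fields and that the check runs when the config is built (field names come from the hunk; the numeric values are made up):

from colossalai.inference.config import InferenceConfig

# A budget that does not fit triggers the new assertion during config validation.
try:
    InferenceConfig(max_input_len=1024, max_output_len=1536, max_seq_len=2048)
except AssertionError as err:
    print(err)  # The sum of max_input_len and max_output_len must be smaller than max_seq_len.

# A consistent budget passes validation.
config = InferenceConfig(max_input_len=512, max_output_len=512, max_seq_len=2048)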
@@ -49,6 +49,7 @@ class InferenceEngine:
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inference_config = inference_config
        self.model_config = model.config
        self.device = torch.device("cuda")

        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32

@@ -76,6 +77,7 @@ class InferenceEngine:
        self.logger = get_dist_logger(__name__)

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
        self.counter = count()

    def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:

        if prompts_token_ids is None:
            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

        assert (
            len(prompts_token_ids[0]) < self.inference_config.max_input_len
        ), "The length of input prompts must be less than max_input_len."

        prompts_num = len(prompts_token_ids)

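The change above switches batch tokenization to padding=True so every encoded prompt in a batch shares the same length. A small standalone illustration of the Hugging Face tokenizer behaviour (the prompts and the printed lengths are only for demonstration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token  # mirrors the engine's pad-token setup

prompts = ["Introduce Beijing", "Introduce Beijing on a rainy autumn day"]

unpadded = tokenizer.batch_encode_plus(prompts)["input_ids"]
padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

print([len(ids) for ids in unpadded])  # e.g. [5, 11] - ragged lengths
print([len(ids) for ids in padded])    # e.g. [11, 11] - padded to the longest prompt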
@@ -183,13 +189,14 @@ class InferenceEngine:
                prompt = None
            else:
                prompt = prompts[i]
            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                None,
                block_table,
                self.tokenizer.eos_token_id,
                self.inference_config.max_output_len,
            )
@@ -211,14 +218,15 @@ class InferenceEngine:
        self.logger.info("Running generation step")

        output_list = []
        batch, k_cache, v_cache = self.request_handler.schedule()
        batch = self.request_handler.schedule()

        logits = self.model(
            batch,
            k_cache,
            v_cache,
            self.k_cahce,
            self.v_cache,
        )
        self.request_handler.search_tokens(logits, self.generation_config)

        self.request_handler.search_tokens(self.generation_config, logits)

        finished_sequences = self.request_handler.update()

@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence

@@ -49,7 +48,7 @@ class RunningList:
    def ready_for_prefill(self):
        if not self.decoding:
            return len(self.prefill) > 0
        return len(self.prefill) / len(self.decoding) >= self.ratio
        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio

    def is_empty(self):
        return not self.decoding and not self.prefill
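ready_for_prefill now reads prefill_ratio, matching the attribute name used when the handler constructs RunningList(inference_config.prefill_ratio) below, instead of self.ratio. A tiny standalone illustration of the threshold logic (the numbers are arbitrary):

prefill_ratio = 1.2                                 # e.g. the value handed to RunningList
prefill, decoding = ["a", "b", "c"], ["d", "e"]     # three queued prefill requests, two decoding

if not decoding:
    ready = len(prefill) > 0
else:
    ready = len(prefill) / len(decoding) >= prefill_ratio

print(ready)  # 3 / 2 = 1.5 >= 1.2, so the prefill batch is considered ready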
@@ -72,8 +71,9 @@ class RequestHandler:
        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        self.running_batch = BatchInfo(is_prompts=False)
        self.prefill_batch = BatchInfo(is_prompts=True)
        device = torch.cuda.current_device()
        self.running_batch = BatchInfo(is_prompts=False, device=device)
        self.prefill_batch = BatchInfo(is_prompts=True, device=device)

    def _init_cache(self, model_config):
        self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@ class RequestHandler:
    def _has_waiting(self) -> bool:
        return any(lst for lst in self.waiting_list)

    def get_kvcache(self):
        return self.cache_manager.get_kv_cache()

    def schedule(self):
        """
        The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
            for lst in reversed(self.waiting_list):
                if lst:
                    for seq in lst:
                        if seq.prompt_len > self.inference_config.max_input_len:
                        if seq.input_len > self.inference_config.max_input_len:
                            # If the prompt length is longer than max_input_len, abort the sequence.
                            self.abort_sequence(seq.request_id)
                            break
@@ -98,9 +101,8 @@ class RequestHandler:
                        if self.cache_manager.check_allocation(seq):
                            # If succeed, add the sequence to running list.
                            self.running_list.append(seq)
                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
                            lst.remove(seq)

                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
                    lst.clear()
        if self.running_list.ready_for_prefill():
            for seq in self.running_list.prefill:
                seq.mark_running()
@@ -115,10 +117,9 @@ class RequestHandler:
        """
        assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
        assert (
            req.prompt_len < self.inference_config.max_input_len
            req.input_len < self.inference_config.max_input_len
        ), f"Sequence {req.request_id} exceeds input length limit"

        self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)

    def abort_sequence(self, request_id: str):
        """
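add_sequence keeps three waiting sub-lists and now buckets each request by input_len rather than the removed prompt_len property. A worked example of the bucket index, assuming max_input_len = 1024 (an arbitrary value for illustration); schedule later walks the lists with reversed(), so the bucket holding the longest prompts is tried first:

max_input_len = 1024

def bucket(input_len: int) -> int:
    # Same integer expression as in add_sequence; valid requests satisfy
    # input_len < max_input_len, so the result is always 0, 1 or 2.
    return input_len * 3 // max_input_len

print(bucket(100))  # 0 -> waiting_list[0]
print(bucket(500))  # 1 -> waiting_list[1]
print(bucket(900))  # 2 -> waiting_list[2]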
@@ -178,9 +179,12 @@ class RequestHandler:
        """
        # do logit processor
        # NOTE: need to decide the granularity to process logits (sequence or batch)
        for type in ["top_p", "top_k", "min_p"]:
            if type in generation_config:
                logits = logit_processor(type, logits)
        # for type in ["top_p", "top_k", "min_p"]:
        # config_dict = generation_config.to_dict()
        # if type in config_dict:
        # logits = logit_processor(type, logits, config_dict[type])

        torch.cuda.synchronize()

        # calculate probs
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -188,7 +192,10 @@ class RequestHandler:

        # sample the next tokens
        sample_tokens = self._sample(probs, logprobs, generation_config)
        self.running_batch.update_batch_tokens(sample_tokens)
        if not self.prefill_batch.is_empty:
            self.prefill_batch.update_batch_tokens(sample_tokens)
        else:
            self.running_batch.update_batch_tokens(sample_tokens)

    def update(self):
        """
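search_tokens now routes the sampled ids to the prefill batch when one is active and to the running batch otherwise, while the logit processors are temporarily commented out. A minimal, self-contained sketch of the softmax-then-sample step that the _sample call performs conceptually; this is an illustration, not the repository's implementation:

import torch

batch_size, vocab_size = 4, 32000
logits = torch.randn(batch_size, vocab_size)

# calculate probs, as in search_tokens
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

do_sample = True
if do_sample:
    sample_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # stochastic sampling
else:
    sample_tokens = torch.argmax(logprobs, dim=-1)  # greedy decoding

print(sample_tokens.shape)  # torch.Size([4]) -> one token id per sequence in the batch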
@@ -112,7 +112,7 @@ class KVCacheManager:

    def get_kv_cache(self):
        """Get k_cache and v_cache"""
        return self._kv_cache[0], self._kv_cache[1]
        return self._kv_caches[0], self._kv_caches[1]

    def get_max_blocks_per_sequence(self) -> int:
        """Get the maximum number of blocks that can be allocated for a single sequence."""
@@ -122,7 +122,7 @@ class KVCacheManager:
        return self.max_blocks_per_sequence

    def check_allocation(self, seq: Sequence) -> bool:
        num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
        num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size
        return num_blocks_needed <= self.num_available_blocks

    def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:

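check_allocation budgets KV-cache blocks from input_len plus the remaining output allowance, rounded up to whole blocks. A quick worked example with made-up numbers:

input_len = 33           # prompt tokens already known
max_output_length = 32   # tokens the sequence may still generate
block_size = 16          # KV-cache slots per block

num_blocks_needed = (input_len + max_output_length + block_size - 1) // block_size
print(num_blocks_needed)  # ceil(65 / 16) = 5

num_available_blocks = 8
print(num_blocks_needed <= num_available_blocks)  # True -> the sequence can be admitted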
@@ -70,7 +70,10 @@ def llama_model_forward(
    seq_length = input_ids.shape[1]
    device = input_ids.device

    past_key_values_length = len(block_tables.shape[1])
    if batch.is_prompts:
        past_key_values_length = 0
    else:
        past_key_values_length = sequence_lengths[0].item() - 1

    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
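llama_model_forward now derives past_key_values_length from the batch phase (zero during prefill, cached length minus one during decoding) instead of the old len(block_tables.shape[1]) expression, which would raise a TypeError on an integer. A small sketch of the resulting position_ids, with made-up lengths:

import torch

device = torch.device("cpu")

# Prefill: the whole prompt is processed at once, so positions start at 0.
seq_length, past_key_values_length = 8, 0
print(torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device))
# tensor([0, 1, 2, 3, 4, 5, 6, 7])

# Decoding: one new token per step; its position follows the tokens already in the KV cache.
sequence_lengths = torch.tensor([9])  # current length of the (single) sequence
seq_length = 1
past_key_values_length = sequence_lengths[0].item() - 1
print(torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device))
# tensor([8])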
@@ -163,26 +166,17 @@ def llama_attn_forward(
    key_states = key_states.view(-1, self.num_heads, self.head_dim)
    value_states = value_states.view(-1, self.num_heads, self.head_dim)

    block_size = k_cache.shape[-1]
    k_cache.shape[-1]

    memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size)
    # memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)

    if is_prompts:
        attn_output = context_attention_unpadded(
            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
        )
    else:
        attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
        decoding_attention(
            query_states,
            k_cache,
            v_cache,
            block_tables,
            sequence_lengths,
            attn_output,
            block_tables.shape[1],
            block_size,
        )
    # if is_prompts:
    # attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
    # else:
    # attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
    # decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)

    attn_output = query_states

    attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -190,19 +184,3 @@ def llama_attn_forward(
    attn_output = self.o_proj(attn_output)

    return attn_output


def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size):
    block_table_list = block_tables.tolist()
    batch_size, seq_len, num_heads, head_dim = key

    reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
    reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1)
    if seq_len == 1:
        for i in range(batch_size):
            k_cache[block_table_list[i][-1], :] = reshape_key[i]
            v_cache[block_table_list[i][-1], :] = reshape_value[i]
    else:
        for i in range(batch_size):
            k_cache[block_table_list[i], :] = reshape_key[i]
            v_cache[block_table_list[i], :] = reshape_value[i]

@@ -1,7 +1,165 @@
from functools import partial

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaModel,
    LlamaSdpaAttention,
)

from colossalai.inference.modeling.models.llama import (
    llama_attn_forward,
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
    llama_model_forward,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
    # The code here just for test and will be modified later.
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        decoder_attribute_replacement = {
            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
            // self.shard_config.tensor_parallel_size,
        }
        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                ],
            )

        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                ColW8A8BFP32OFP32Linear,
                RowW8A8B8O8Linear,
                RowW8A8BFP32O32LinearSiLU,
                RowW8A8BFP32OFP32Linear,
            )

            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=RowW8A8B8O8Linear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=RowW8A8B8O8Linear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=RowW8A8B8O8Linear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=ColW8A8BFP32OFP32Linear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=RowW8A8BFP32O32LinearSiLU,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=RowW8A8BFP32OFP32Linear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=ColW8A8BFP32OFP32Linear,
                        kwargs={"split_num": 1},
                    ),
                ],
            )
        self.shard_config._infer()

        infer_forward = llama_causal_lm_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
        )

        infer_forward = llama_model_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        infer_forward = llama_decoder_layer_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaAttention
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
        )

        return policy

@@ -1,6 +1,6 @@
import enum
from dataclasses import dataclass
from typing import Any, List, Union
from typing import Any, List, Tuple, Union

import torch
from ordered_set import OrderedSet
@@ -74,13 +74,6 @@ class Sequence:
        self.output_token_id = []
        self.status = RequestStatus.WAITING

    @property
    def prompt_len(self) -> int:
        """
        Get length of prompts
        """
        return len(self.input_token_id)

    @property
    def sentence_len(self) -> int:
        """
@@ -113,7 +106,7 @@ class Sequence:
            return True

        if self.output_token_id:
            if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len:
            if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len:
                self.status = RequestStatus.COMPLETED
                return True

@@ -143,11 +136,13 @@ class Sequence:

    def __repr__(self) -> str:
        return (
            f"Request ID(request_id={self.request_id}, "
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"logical block number={len(self.block_table_index)}"
            f"logical_block_number={self.block_table.shape[0]},"
            f"input_len={self.input_len}),"
            f"output_len={self.output_len})"
        )

@@ -159,9 +154,15 @@ class BatchInfo:

    sequences_set: OrderedSet["Sequence"] = None
    is_prompts: bool = True
    device: torch.device = None

    @classmethod
    def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo":
    def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device()
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()

    def init_batch(self, seqs: List["Sequence"] = None):
        """
        Initializes inference batches by input sentence list.

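BatchInfo drops the classmethod constructor in favour of dataclass fields whose defaults are resolved in __post_init__. A generic sketch of that idiom (a toy class, not the repository's BatchInfo; the "cpu" string stands in for torch.cuda.current_device()):

from dataclasses import dataclass
from typing import Optional, Set

@dataclass
class TinyBatch:
    sequences_set: Optional[Set[int]] = None
    is_prompts: bool = True
    device: Optional[str] = None

    def __post_init__(self):
        # Mutable or environment-dependent defaults are filled in after field assignment.
        if self.device is None:
            self.device = "cpu"
        if self.sequences_set is None:
            self.sequences_set = set()

batch = TinyBatch()
print(batch.device, batch.sequences_set, batch.is_prompts)  # cpu set() True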
@@ -169,29 +170,29 @@ class BatchInfo:
            seqs (List["Sequence"]): List of input sequence.
        """

        sequences_set = OrderedSet()
        assert len(self.sequences_set) == 0, "Sequences set has been initialized."

        if seqs is not None:
            if not isinstance(seqs, list):
                seqs = [seqs]
            for seq in seqs:
                if seq in sequences_set:
                if seq in self.sequences_set:
                    logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                    continue

                sequences_set.add(seq)

        return cls(sequences_set=sequences_set)
                self.sequences_set.add(seq)

    def get_block_table_tensor(self) -> None:
        tesnor_list = []
        block_table = None
        for seq in self.sequences_set:
            block_table = seq.block_table
            assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            assert (
                block_table is not None
            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            tesnor_list.append(seq.block_table)
        assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
        block_table = torch.concat(tesnor_list)
        block_table = torch.stack(tesnor_list)
        return block_table

    def clear_batch(self) -> None:
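get_block_table_tensor now stacks the per-sequence block tables instead of concatenating them, and checks `block_table is not None`, since truth-testing a multi-element tensor raises an error. A short demonstration of the shape difference (block ids are arbitrary; -1 marks unallocated slots, mirroring the engine's torch.full initialisation):

import torch

bt_a = torch.tensor([3, 7, -1, -1])   # block table of sequence A
bt_b = torch.tensor([5, 6, 9, -1])    # block table of sequence B

print(torch.concat([bt_a, bt_b]).shape)  # torch.Size([8])    - old behaviour: one flat vector
print(torch.stack([bt_a, bt_b]).shape)   # torch.Size([2, 4]) - new behaviour: one row per sequence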
@@ -239,7 +240,7 @@ class BatchInfo:
            seqs = [seqs]

        for seq in seqs:
            if seq in self.sequences_set:
            if self.sequences_set and seq in self.sequences_set:
                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                continue
            self.sequences_set.add(seq)
@@ -251,7 +252,7 @@ class BatchInfo:
        """
        return not self.sequences_set

    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None:
    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
        """
        Add an output token for each sentence in the batch.

@@ -259,6 +260,9 @@ class BatchInfo:
            tokens (List[int]): A batch of tokens
        """

        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()

        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

        for seq, token in zip(self.sequences_set, tokens):
@@ -287,19 +291,25 @@ class BatchInfo:
            else:
                input_list.append([seq.output_token_id[-1]])

        return torch.tensor(input_list, dtype=torch.long)
        return torch.tensor(input_list, dtype=torch.long, device=self.device)

    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Flattening the input tokens.
        """
        input_list = []
        input_len_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.extend(seq.input_token_id)
                input_len_list.append(seq.sentence_len)
            else:
                input_list.append(seq.output_token_id[-1])
        return torch.tensor(input_list, dtype=torch.long)
                input_len_list.append(1)

        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
            input_len_list, dtype=torch.int, device=device
        )

    def get_sequence_lengths(self):
        """
@@ -307,5 +317,9 @@ class BatchInfo:
        """
        len_list = []
        for seq in self.sequences_set:
            len_list.append(seq.get_sentence_len())
        return torch.tensor(len_list, dtype=torch.int)
            len_list.append(seq.sentence_len)

        return torch.tensor(len_list, dtype=torch.int, device=self.device)

    def __repr__(self) -> str:
        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
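The reworked get_1D_inputs above returns the flattened token ids together with a per-sequence length tensor: every prompt token during prefill, a single trailing token per sequence during decoding. A standalone sketch of that layout with toy data:

import torch

# Two already-tokenized prompts of different lengths (ids are arbitrary).
input_token_ids = [[11, 12, 13], [21, 22, 23, 24, 25]]

is_prompts = True
input_list, input_len_list = [], []
for ids in input_token_ids:
    if is_prompts:
        input_list.extend(ids)        # prefill: all prompt tokens, back to back
        input_len_list.append(len(ids))
    else:
        input_list.append(ids[-1])    # decoding: only the latest token per sequence
        input_len_list.append(1)

print(torch.tensor(input_list, dtype=torch.long))     # tensor([11, 12, 13, 21, 22, 23, 24, 25])
print(torch.tensor(input_len_list, dtype=torch.int))  # tensor([3, 5], dtype=torch.int32)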
@@ -1,6 +1,6 @@
import pytest
import transformers
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GenerationConfig

import colossalai
from colossalai.inference.config import InferenceConfig
@@ -11,21 +11,24 @@ from colossalai.testing import spawn
def check_inference_engine():
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
        )
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    inference_config = InferenceConfig()
    inference_config = InferenceConfig(max_output_len=5)
    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    inputs = [
        "介绍一下北京",
        "介绍一下今天的北京",
        "介绍一下武汉",
    ]

    inference_engine.add_request(prompts=inputs)
    assert inference_engine.request_handler._has_waiting()
    # outputs = inference_engine.generate(None)
    generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
    outputs = inference_engine.generate(generation_config)

    print("outputs: ", outputs)

    # Engine still gets some bug