mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Fix bugs in testing continuous batching (#5270)
* fix bug * fix bugs * fix bugs * fix bugs and add padding * add funcs and fix bugs * fix typos * fix bugs * add func
pull/5283/head
parent
5ae9099f92
commit
9e2342bde2
|
@ -57,6 +57,9 @@ class RunningList:
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return not self.decoding and not self.prefill
|
return not self.decoding and not self.prefill
|
||||||
|
|
||||||
|
def total_seq_num(self) -> int:
    """Return the combined number of sequences in the decoding and prefill lists."""
    return len(self.decoding) + len(self.prefill)
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler:
|
class RequestHandler:
|
||||||
"""
|
"""
|
||||||
|
@ -105,7 +108,13 @@ class RequestHandler:
|
||||||
f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
|
f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
|
||||||
)
|
)
|
||||||
self.abort_sequence(seq.request_id)
|
self.abort_sequence(seq.request_id)
|
||||||
|
remove_list.append(seq)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Stop feeding new sequences into the running list once the available cache blocks no longer exceed the number of running sequences.
|
||||||
|
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
|
||||||
|
break
|
||||||
|
|
||||||
# Try to allocate cache blocks for the sequence.
|
# Try to allocate cache blocks for the sequence.
|
||||||
if (
|
if (
|
||||||
self.cache_manager.check_allocation(seq)
|
self.cache_manager.check_allocation(seq)
|
||||||
|
@ -115,7 +124,7 @@ class RequestHandler:
|
||||||
# If succeed, add the sequence to running list.
|
# If succeed, add the sequence to running list.
|
||||||
remove_list.append(seq)
|
remove_list.append(seq)
|
||||||
self.running_list.append(seq)
|
self.running_list.append(seq)
|
||||||
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
|
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
|
||||||
for seq in remove_list:
|
for seq in remove_list:
|
||||||
lst.remove(seq)
|
lst.remove(seq)
|
||||||
if self.running_list.ready_for_prefill():
|
if self.running_list.ready_for_prefill():
|
||||||
|
@ -126,7 +135,13 @@ class RequestHandler:
|
||||||
|
|
||||||
if not self.running_batch.is_empty:
|
if not self.running_batch.is_empty:
|
||||||
for seq in self.running_batch.sequences_set:
|
for seq in self.running_batch.sequences_set:
|
||||||
self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
|
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
|
||||||
|
if recycle:
|
||||||
|
seq.recycle()
|
||||||
|
self.running_batch.del_seq(seq)
|
||||||
|
self.running_list.remove(seq)
|
||||||
|
self.waiting_list[-1].append(seq)
|
||||||
|
# the recycled sequences are handled with highest priority.
|
||||||
|
|
||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
||||||
)
|
)
|
||||||
padding = seq_len - _cache.size(0)
|
padding = seq_len - _cache.size(0)
|
||||||
if padding > 0:
|
if padding > 0:
|
||||||
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
|
_cache = F.pad(_cache, (0, 0, 0, 0, 0, padding), value=pad_id)
|
||||||
padded_cache.append(_cache)
|
padded_cache.append(_cache)
|
||||||
return torch.stack(padded_cache, dim=0)
|
return torch.stack(padded_cache, dim=0)
|
||||||
|
|
||||||
|
|
|
@ -173,7 +173,10 @@ def llama_attn_forward(
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = max(sequence_lengths).item()
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
|
@ -29,6 +29,9 @@ class RequestStatus(enum.Enum):
|
||||||
COMPLETED = enum.auto()
|
COMPLETED = enum.auto()
|
||||||
LENGTH_CAPPED = enum.auto()
|
LENGTH_CAPPED = enum.auto()
|
||||||
|
|
||||||
|
# recycle status
|
||||||
|
RECYCLED = enum.auto()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_finished(status: "RequestStatus") -> bool:
|
def is_finished(status: "RequestStatus") -> bool:
|
||||||
return status in [
|
return status in [
|
||||||
|
@ -119,7 +122,9 @@ class Sequence:
|
||||||
"""
|
"""
|
||||||
Set status for prefill reqs.
|
Set status for prefill reqs.
|
||||||
"""
|
"""
|
||||||
assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
|
assert (
|
||||||
|
self.status == RequestStatus.WAITING or RequestStatus.RECYCLED
|
||||||
|
), "Sequence is not in WAITTING/RECYCLED STATUS"
|
||||||
self.status = RequestStatus.RUNNING
|
self.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
def mark_finished(self) -> None:
|
def mark_finished(self) -> None:
|
||||||
|
@ -139,10 +144,10 @@ class Sequence:
|
||||||
Recycle a running sequnce to waiitting list
|
Recycle a running sequnce to waiitting list
|
||||||
"""
|
"""
|
||||||
assert (
|
assert (
|
||||||
not self.status.is_finished and not self.status == RequestStatus.ABORTED
|
not self.check_finish() and not self.status == RequestStatus.ABORTED
|
||||||
), "The running sequence \
|
), "The running sequence \
|
||||||
is already done but it still in running list"
|
is already done but it still in running list"
|
||||||
self.status = RequestStatus.WAITING
|
self.status = RequestStatus.RECYCLED
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
@ -162,7 +167,7 @@ class BatchInfo:
|
||||||
Information to be passed and used for a batch of sequences.
|
Information to be passed and used for a batch of sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences_set: OrderedSet["Sequence"] = None
|
sequences_set: OrderedSet[Sequence] = None
|
||||||
is_prompts: bool = True
|
is_prompts: bool = True
|
||||||
device: torch.device = None
|
device: torch.device = None
|
||||||
|
|
||||||
|
@ -207,12 +212,20 @@ class BatchInfo:
|
||||||
|
|
||||||
def clear_batch(self) -> None:
|
def clear_batch(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clear sequence set and block table.
|
Clear sequence set and block table if we need to abort this batch.
|
||||||
|
Prefill: clear sequence set and move them to running batch(external)
|
||||||
|
Decoding: mark unfinished sequences as aborted.
|
||||||
"""
|
"""
|
||||||
for seq in self.sequences_set:
|
if self.is_prompts:
|
||||||
if not seq.check_finish():
|
self.sequences_set.clear()
|
||||||
seq.status = RequestStatus.ABORTED
|
|
||||||
self.sequences_set.clear()
|
else:
|
||||||
|
for seq in self.sequences_set:
|
||||||
|
seq.mark_aborted()
|
||||||
|
if seq.check_finish():
|
||||||
|
seq.mark_finished()
|
||||||
|
|
||||||
|
self.sequences_set.clear()
|
||||||
|
|
||||||
def fliter_batch(self) -> List["Sequence"]:
|
def fliter_batch(self) -> List["Sequence"]:
|
||||||
"""
|
"""
|
||||||
|
@ -255,6 +268,12 @@ class BatchInfo:
|
||||||
continue
|
continue
|
||||||
self.sequences_set.add(seq)
|
self.sequences_set.add(seq)
|
||||||
|
|
||||||
|
def del_seq(self, seq: "Sequence") -> None:
    """
    Delete sequence in batch.

    The sequence is dropped from ``self.sequences_set`` via ``discard``; nothing
    is returned.

    Args:
        seq: The sequence to remove from this batch.
    """
    self.sequences_set.discard(seq)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_empty(self) -> None:
|
def is_empty(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -297,11 +316,19 @@ class BatchInfo:
|
||||||
|
|
||||||
for seq in self.sequences_set:
|
for seq in self.sequences_set:
|
||||||
if self.is_prompts:
|
if self.is_prompts:
|
||||||
input_list.append(seq.input_token_id)
|
if seq.output_len > 0:
|
||||||
|
print(seq.output_token_id)
|
||||||
|
seq_data = seq.input_token_id + seq.output_token_id
|
||||||
|
print(seq_data)
|
||||||
|
input_list.append(seq.input_token_id + seq.output_token_id)
|
||||||
|
else:
|
||||||
|
input_list.append(seq.input_token_id)
|
||||||
else:
|
else:
|
||||||
input_list.append([seq.output_token_id[-1]])
|
input_list.append([seq.output_token_id[-1]])
|
||||||
|
|
||||||
return torch.tensor(input_list, dtype=torch.long, device=self.device)
|
max_seq_len = max(len(sub_list) for sub_list in input_list)
|
||||||
|
|
||||||
|
return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
|
||||||
|
|
||||||
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
|
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
@ -340,12 +367,27 @@ class BatchInfo:
|
||||||
for seq in self.sequences_set:
|
for seq in self.sequences_set:
|
||||||
past_values.append(seq.input_token_id + seq.output_token_id)
|
past_values.append(seq.input_token_id + seq.output_token_id)
|
||||||
|
|
||||||
attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
|
max_seq_len = max(len(sub_list) for sub_list in past_values)
|
||||||
|
attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
|
||||||
|
|
||||||
if torch.any(attn_mask == 0):
|
return attn_mask.ne(padding_id).long()
|
||||||
return attn_mask
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
|
||||||
|
|
||||||
|
|
||||||
|
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
||||||
|
assert len(x) <= max_len
|
||||||
|
return x + [pad] * (max_len - len(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tensor_with_pad(
    x: Union[List[List[int]], List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
    """Pad every row of ``x`` to ``max_len`` and build a single tensor from the rows.

    Args:
        x: Rows of token ids; each row is padded with ``_pad_to_max``.
        max_len: Length every row is padded to.
        pad: The padding value.
        dtype: dtype of the resulting tensor.
        device: Target device. Defaults to ``"cuda"``, so callers that need the
            tensor elsewhere must pass ``device`` explicitly.
        pin_memory: Request pinned memory; only honoured when ``str(device)``
            is ``"cpu"``.

    Returns:
        A tensor built from the padded rows.
    """
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    # Pinning is only requested for CPU tensors; for any other device the flag is dropped.
    use_pinned = pin_memory and str(device) == "cpu"
    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=use_pinned)
|
||||||
|
|
|
@ -95,11 +95,10 @@ def benchmark_inference(args):
|
||||||
|
|
||||||
if args.dtype == "fp16":
|
if args.dtype == "fp16":
|
||||||
model = model.half()
|
model = model.half()
|
||||||
elif args.dtype == "bf16":
|
elif args.dtype == "fp16":
|
||||||
model = model.to(torch.bfloat16)
|
model = model.to(torch.bfloat16)
|
||||||
|
|
||||||
# mbsz = args.mbsz
|
mbsz = args.mbsz
|
||||||
mbsz = args.batch_size
|
|
||||||
if args.mode == "caiinference":
|
if args.mode == "caiinference":
|
||||||
inference_config = InferenceConfig(
|
inference_config = InferenceConfig(
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.inference.config import InferenceConfig
|
from colossalai.inference.config import InferenceConfig
|
||||||
from colossalai.inference.struct import BatchInfo, Sequence
|
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,6 +41,10 @@ def check_config_and_inference():
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
max_output_len=256,
|
max_output_len=256,
|
||||||
)
|
)
|
||||||
|
sequence.mark_running()
|
||||||
|
assert sequence.status == RequestStatus.RUNNING
|
||||||
|
sequence.recycle()
|
||||||
|
assert sequence.status == RequestStatus.RECYCLED
|
||||||
|
|
||||||
assert sequence.sentence_len == 3
|
assert sequence.sentence_len == 3
|
||||||
assert sequence.input_len == 3
|
assert sequence.input_len == 3
|
||||||
|
|
Loading…
Reference in New Issue