fix CI bugs

pull/5258/head
yuehuayingxueluo 2024-01-09 15:18:28 +08:00 committed by FrankLeeeee
parent 2a73e828eb
commit fab294c7f4
5 changed files with 21 additions and 9 deletions


@@ -191,7 +191,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
-            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
+            max_blocks_per_sequence = (
+                self.inference_config.max_input_len
+                + self.inference_config.max_output_len
+                + self.inference_config.block_size
+                - 1
+            ) // self.inference_config.block_size
+            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
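
For reference, the new sizing is plain ceiling division: a sequence that may grow to max_input_len + max_output_len tokens needs ceil(total / block_size) KV-cache blocks, so the table length now tracks the block count rather than the configured max_seq_len. A standalone sketch with made-up config values (not values from this repository):

    # Hypothetical numbers, for illustration only.
    max_input_len = 256
    max_output_len = 128
    block_size = 16

    # Integer ceiling division: ceil(a / b) == (a + b - 1) // b
    max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
    assert max_blocks_per_sequence == 24  # 24 blocks * 16 tokens = 384 >= 256 + 128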


@@ -7,7 +7,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -104,7 +104,7 @@ class RequestHandler:
                                 f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                             )
                             self.abort_sequence(seq.request_id)
-                            remove_list.append(seq)
+                            break
                         # Try to allocate cache blocks for the sequence.
                         if self.cache_manager.check_allocation(seq):
                             # If succeed, add the sequence to running list.
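
One reading of this change: once an over-length prompt is aborted, the scan stops instead of queueing the sequence for later removal, because the updated abort_sequence below now removes it from the waiting list itself, and continuing to iterate over a list that was just mutated is error-prone. A generic illustration of that pitfall, unrelated to ColossalAI's classes:

    items = [1, 2, 3, 4]
    for x in items:
        if x == 2:
            items.remove(x)  # mutating the list being iterated skips the next element
    # items == [1, 3, 4], but the loop never visited 3

    items = [1, 2, 3, 4]
    for x in list(items):    # iterate over a snapshot instead
        if x == 2:
            items.remove(x)
    # items == [1, 3, 4], and every element was visited
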
@@ -139,9 +139,10 @@
         """
         Abort the request.
         """
-        seq, _ = self._find_sequence(request_id)
-        if seq.status.is_waiting:
+        seq, priority = self._find_sequence(request_id)
+        if seq.status == RequestStatus.WAITING:
             seq.mark_aborted()
+            self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
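
The old guard referenced the predicate without calling it; since the sibling branch calls seq.status.is_running(), is_waiting was most likely a method too, and a bare bound-method reference is always truthy in Python, so the branch would fire for every status. Comparing against the enum member removes that ambiguity. A self-contained illustration with a hypothetical RequestStatus (not ColossalAI's definition):

    from enum import Enum

    class RequestStatus(Enum):
        WAITING = 1
        RUNNING = 2

        def is_waiting(self):
            return self is RequestStatus.WAITING

    status = RequestStatus.RUNNING

    # Bug pattern: the bound method is an object and objects are truthy,
    # so "if status.is_waiting:" takes the branch regardless of the status.
    assert bool(status.is_waiting)

    # The intended call, and the explicit comparison the diff switches to:
    assert status.is_waiting() is False
    assert status != RequestStatus.WAITING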


@@ -217,6 +217,8 @@ class PagedAttention:
         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
+        padding_mask = None
         if attn_mask is not None:
             padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
@@ -279,11 +281,12 @@ class PagedAttention:
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
+        padding_mask = None
         if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
         attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
         )
         if padding_mask is not None:
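
The two corrected lines use the same quantity: q_length is the number of new query tokens in this step, seq_len counts cached plus new keys, and the causal mask has to be shifted by past_key_values_length = seq_len - q_length; initializing padding_mask = None up front also keeps the later "if padding_mask is not None:" check from touching an unbound local when no attn_mask is passed. A plain-torch sketch of the shifted causal mask (my own illustration, not the transformers helper the file calls):

    import torch

    def shifted_causal_mask(q_length: int, seq_len: int, dtype=torch.float32):
        past = seq_len - q_length                    # keys cached before this step
        q_pos = torch.arange(q_length).unsqueeze(1)  # query positions within the step
        k_pos = torch.arange(seq_len).unsqueeze(0)   # absolute key positions
        allowed = k_pos <= (q_pos + past)            # a query sees its past, never its future
        mask = torch.zeros(q_length, seq_len, dtype=dtype)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    # Decode step: 1 new token over 8 cached+new positions, nothing is masked.
    print(shifted_causal_mask(1, 8))
    # 3 new tokens with 5 cached: each later row sees one more key.
    print(shifted_causal_mask(3, 8))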


@@ -11,6 +11,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -34,7 +35,7 @@ def check_inference_engine(test_cai=False):
         "介绍一下武汉,",
     ]
 
-    output_len = 128
+    output_len = 38
     do_sample = True
     top_p = 0.5
     top_k = 50
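
The test seeds PyTorch so the sampled outputs are reproducible, and the shorter output_len presumably keeps the CI run fast. A common, fuller seeding pattern, shown as a generic sketch rather than as this test file's actual contents:

    import random

    import numpy as np
    import torch

    def setup_seed(seed: int) -> None:
        random.seed(seed)                 # Python's built-in RNG
        np.random.seed(seed)              # NumPy RNG
        torch.manual_seed(seed)           # CPU RNG
        torch.cuda.manual_seed_all(seed)  # every visible CUDA device

    setup_seed(20)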


@@ -57,7 +57,7 @@ def check_request_handler():
         block_size=16,
         eos_token_id=0,
         sample_params=None,
-        block_table=torch.tensor([0, 0]),
+        block_table=torch.tensor([-1, -1]),
     )
     request_handler.add_sequence(seq1)
     # the priority should be 1
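
The dummy block table now uses -1 entries, consistent with the engine change above where fresh tables are filled with -1; my reading of that convention (not a documented contract) is that -1 marks a logical block slot with no physical cache block assigned yet:

    import torch

    max_blocks_per_sequence = 4
    block_table = torch.full((max_blocks_per_sequence,), -1, dtype=torch.long)

    # Pretend the cache manager handed out physical blocks 7 and 9 (made-up ids).
    block_table[0], block_table[1] = 7, 9

    num_allocated = int((block_table != -1).sum())  # 2 slots now point at real blocks
    assert num_allocated == 2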