mirror of https://github.com/hpcaitech/ColossalAI
fix CI bugs
parent 2a73e828eb
commit fab294c7f4
@@ -191,7 +191,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
-            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
+
+            max_blocks_per_sequence = (
+                self.inference_config.max_input_len
+                + self.inference_config.max_output_len
+                + self.inference_config.block_size
+                - 1
+            ) // self.inference_config.block_size
+            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
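The replacement sizes each per-sequence block table by ceiling division: enough KV-cache blocks to hold max_input_len + max_output_len tokens at block_size tokens per block, instead of one slot per position up to max_seq_len. A standalone sketch of the arithmetic with illustrative numbers (not the real InferenceConfig defaults):

# Ceiling division: how many fixed-size KV-cache blocks cover the worst-case sequence.
max_input_len = 1024   # illustrative value only
max_output_len = 100   # illustrative value only
block_size = 16

max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
print(max_blocks_per_sequence)  # 71 -- 1124 tokens overflow 70 blocks (1120 slots), so round up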
@@ -7,7 +7,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)
@@ -104,7 +104,7 @@ class RequestHandler:
                         f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                     )
                     self.abort_sequence(seq.request_id)
-                    remove_list.append(seq)
+                    break
                 # Try to allocate cache blocks for the sequence.
                 if self.cache_manager.check_allocation(seq):
                     # If succeed, add the sequence to running list.
@@ -139,9 +139,10 @@ class RequestHandler:
         """
         Abort the request.
         """
-        seq, _ = self._find_sequence(request_id)
-        if seq.status.is_waiting:
+        seq, priority = self._find_sequence(request_id)
+        if seq.status == RequestStatus.WAITING:
+            seq.mark_aborted()
             self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
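Pieced together from the lines above, the corrected abort path appears to (1) look up the sequence together with its priority bucket, (2) mark a still-waiting sequence as aborted and drop it from that bucket, and (3) free the KV-cache blocks of a running sequence before removing it from the running list. A simplified, self-contained sketch of that control flow, with stand-in types (the real Sequence, RequestStatus, and cache manager live in colossalai.inference and differ in detail):

from dataclasses import dataclass, field
from enum import Enum, auto


class RequestStatus(Enum):
    # Stand-in for colossalai.inference.struct.RequestStatus; member names are assumptions.
    WAITING = auto()
    RUNNING = auto()
    ABORTED = auto()


@dataclass
class Seq:
    request_id: int
    status: RequestStatus = RequestStatus.WAITING
    block_table: list = field(default_factory=list)

    def mark_aborted(self):
        self.status = RequestStatus.ABORTED


class ToyRequestHandler:
    """Illustrative only: mirrors the abort flow visible in the diff, not the real class."""

    def __init__(self, num_priorities=2):
        self.waiting_list = [[] for _ in range(num_priorities)]
        self.running_list = []
        self.freed_tables = []  # stands in for cache_manager.free_block_table

    def _find_sequence(self, request_id):
        # Search the waiting buckets first so the priority index can be returned with the sequence.
        for priority, bucket in enumerate(self.waiting_list):
            for seq in bucket:
                if seq.request_id == request_id:
                    return seq, priority
        for seq in self.running_list:
            if seq.request_id == request_id:
                return seq, None
        raise KeyError(request_id)

    def abort_sequence(self, request_id):
        seq, priority = self._find_sequence(request_id)
        if seq.status == RequestStatus.WAITING:
            seq.mark_aborted()
            self.waiting_list[priority].remove(seq)
        elif seq.status == RequestStatus.RUNNING:
            self.freed_tables.append(seq.block_table)  # the real code frees KV-cache blocks here
            self.running_list.remove(seq)


handler = ToyRequestHandler()
handler.waiting_list[1].append(Seq(request_id=7))
handler.abort_sequence(7)  # removes the sequence from priority bucket 1 and marks it ABORTED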
@@ -217,6 +217,8 @@ class PagedAttention:

         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)

+        padding_mask = None
+
         if attn_mask is not None:
             padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)

@@ -279,11 +281,12 @@ class PagedAttention:
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")

+        padding_mask = None
         if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)

         attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
         )

         if padding_mask is not None:
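In this decoding path q_length is 1 (one new token per step) while seq_len covers the cached context, so the causal mask is built with past_key_values_length = seq_len - q_length and later combined with the expanded padding mask. A small standalone check of the shapes involved, using the same private transformers helpers the diff calls (the import path below is how recent transformers releases expose AttentionMaskConverter; the combination step at the end is one plausible reading of the `if padding_mask is not None:` branch, whose body is not shown in this excerpt):

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

bsz, q_length, seq_len = 2, 1, 8          # decoding one new token against 8 cached positions
dtype = torch.float32

# 2D padding mask: 1 = real token, 0 = padding (second sequence has 3 left-padded positions).
attn_mask = torch.tensor([[1] * 8, [0, 0, 0, 1, 1, 1, 1, 1]])

padding_mask = AttentionMaskConverter._expand_mask(attn_mask, dtype, q_length)
causal_mask = AttentionMaskConverter._make_causal_mask(
    (bsz, q_length), dtype, device=torch.device("cpu"), past_key_values_length=seq_len - q_length
)

print(padding_mask.shape)  # torch.Size([2, 1, 1, 8])
print(causal_mask.shape)   # torch.Size([2, 1, 1, 8])

# Fold padding into the causal mask before it is added to the attention scores.
combined = causal_mask.masked_fill(padding_mask.bool(), torch.finfo(dtype).min)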
@@ -11,6 +11,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import rerun_if_address_is_in_use, spawn

+
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -34,7 +35,7 @@ def check_inference_engine(test_cai=False):
         "介绍一下武汉,",
     ]

-    output_len = 128
+    output_len = 38
     do_sample = True
     top_p = 0.5
     top_k = 50
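The test's sampling knobs (output_len, do_sample, top_p, top_k) correspond to standard HuggingFace generation settings; how the test actually passes them to the engine and the reference model is not visible in this excerpt, but a transformers GenerationConfig built from the same values would look like this illustrative sketch:

from transformers import GenerationConfig

# Values taken from the updated test; the wiring into the engine is an assumption.
generation_config = GenerationConfig(
    max_new_tokens=38,   # output_len
    do_sample=True,
    top_p=0.5,
    top_k=50,
)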
@@ -57,7 +57,7 @@ def check_request_handler():
         block_size=16,
         eos_token_id=0,
         sample_params=None,
-        block_table=torch.tensor([0, 0]),
+        block_table=torch.tensor([-1, -1]),
     )
     request_handler.add_sequence(seq1)
     # the priority should be 1
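Both this test fix and the engine change above initialize block tables with -1, which appears to act as the "no physical block allocated yet" sentinel, so a freshly created sequence starts with every slot unallocated. A tiny illustration of that convention in plain torch (not the real KVCacheManager API):

import torch

max_blocks_per_sequence = 4  # illustrative size
block_table = torch.full([max_blocks_per_sequence], -1)

# -1 marks an unallocated slot; a real allocator would overwrite entries with physical block IDs.
print((block_table == -1).all().item())  # True: nothing allocated yet

block_table[0] = 17  # pretend the cache manager assigned physical block 17 to the first slot
print((block_table == -1).sum().item())  # 3 slots still unallocated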