From fab294c7f4a5db0a4e19109ac5656492ff3ca08b Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Tue, 9 Jan 2024 15:18:28 +0800
Subject: [PATCH] fix CI bugs

---
 colossalai/inference/core/engine.py               | 9 ++++++++-
 colossalai/inference/core/request_handler.py      | 9 +++++----
 colossalai/inference/modeling/layers/attention.py | 7 +++++--
 tests/test_infer/test_inference_engine.py         | 3 ++-
 tests/test_infer/test_request_handler.py          | 2 +-
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 6f582c619..eaacfe0f5 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -191,7 +191,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
-            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
+
+            max_blocks_per_sequence = (
+                self.inference_config.max_input_len
+                + self.inference_config.max_output_len
+                + self.inference_config.block_size
+                - 1
+            ) // self.inference_config.block_size
+            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 7fad20211..a83e5041d 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -7,7 +7,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -104,7 +104,7 @@ class RequestHandler:
                     f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                 )
                 self.abort_sequence(seq.request_id)
-                remove_list.append(seq)
+                break
             # Try to allocate cache blocks for the sequence.
             if self.cache_manager.check_allocation(seq):
                 # If succeed, add the sequence to running list.
@@ -139,9 +139,10 @@ class RequestHandler:
         """
        Abort the request.
         """
-        seq, _ = self._find_sequence(request_id)
-        if seq.status.is_waiting:
+        seq, priority = self._find_sequence(request_id)
+        if seq.status == RequestStatus.WAITING:
             seq.mark_aborted()
+            self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py
index d95504903..b5cb2c073 100644
--- a/colossalai/inference/modeling/layers/attention.py
+++ b/colossalai/inference/modeling/layers/attention.py
@@ -217,6 +217,8 @@ class PagedAttention:
 
         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
 
+        padding_mask = None
+
         if attn_mask is not None:
             padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
 
@@ -279,11 +281,12 @@ class PagedAttention:
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
 
+        padding_mask = None
         if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
 
         attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
         )
 
         if padding_mask is not None:
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index ede4fb18a..bf626d758 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -11,6 +11,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
+
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -34,7 +35,7 @@ def check_inference_engine(test_cai=False):
         "介绍一下武汉,",
     ]
 
-    output_len = 128
+    output_len = 38
     do_sample = True
     top_p = 0.5
     top_k = 50
diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py
index aa2cac6cb..673fcf9cf 100644
--- a/tests/test_infer/test_request_handler.py
+++ b/tests/test_infer/test_request_handler.py
@@ -57,7 +57,7 @@ def check_request_handler():
         block_size=16,
         eos_token_id=0,
         sample_params=None,
-        block_table=torch.tensor([0, 0]),
+        block_table=torch.tensor([-1, -1]),
     )
     request_handler.add_sequence(seq1)
     # the priority should be 1