diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 1ee62cd51..a94120a20 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -214,9 +214,6 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """
 
-        if self.verbose:
-            self.logger.info("Running generation step")
-
         output_list = []
         batch = self.request_handler.schedule()
 
@@ -224,6 +221,7 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
+            padding_id=self.tokenizer.pad_token_id,
         )
 
         logits = logits[:, -1, :]
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 1754a8862..7c2752a0d 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -110,6 +110,10 @@ class RequestHandler:
             self.prefill_batch.init_batch(self.running_list.prefill)
             return self.prefill_batch
 
+        if not self.running_batch.is_empty:
+            for seq in self.running_batch.sequences_set:
+                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+
         return self.running_batch
 
     def add_sequence(self, req: Sequence):
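Note on the request_handler change above: every decoding step now reserves one more KV-cache slot for each running sequence before the batch is returned. A small standalone sketch of the block/slot arithmetic this relies on, using made-up lengths and the paged layout assumed elsewhere in this patch (block_size tokens per cache block):

import torch

block_size = 16
sentence_len = torch.tensor([16, 17, 32, 33])  # lengths including the token being decoded this step
needed_blocks = (sentence_len + block_size - 1) // block_size  # ceil division, as in the cache helpers
slot_idx = (sentence_len + block_size - 1) % block_size  # same as (sentence_len - 1) % block_size
print(needed_blocks - 1)  # tensor([0, 1, 1, 2]) -> index of the block that receives the new token
print(slot_idx)           # tensor([15, 0, 15, 0]) -> slot inside that block

The same expressions appear in the decoding branch of copy_to_cache below, which is why the scheduler has to allocate the slot beforehand.
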
diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py
index 0a9f8566e..4619e8c45 100644
--- a/colossalai/inference/modeling/layers/attention.py
+++ b/colossalai/inference/modeling/layers/attention.py
@@ -29,47 +29,50 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
             for block_idx in range(block_num - 1):
                 cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
                 token_id += block_size
-            cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
+            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
+                1, 2, 0
+            )
     elif type == "decoding":
         assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
-            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i].permute(0, 1)
+            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
 
     return cache
 
 
-def convert_kvcache(source, cache, lengths, block_tables):
+def convert_kvcache(cache, lengths, block_tables):
     """
     Func: convert key/value cache for calculation
 
-    Args:   key/value(source): shape [bsz, 1, num_heads, head_size]
-            cache: shape [num_blocks, num_heads, head_size, block_size]
+    Args:   cache: shape [num_blocks, num_heads, head_size, block_size]
             lengths: key/value length
             block_tables
     """
     num_blocks, num_heads, head_size, block_size = cache.shape
     needed_blocks = (lengths + block_size - 1) // block_size
-    num_remaing_tokens = (lengths - 1) % block_size
+    num_remaing_tokens = lengths % block_size
+    num_remaing_tokens[num_remaing_tokens == 0] += block_size
     bsz = block_tables.shape[0]
     seq_len = max(lengths)
     padded_cache = []
     for i in range(bsz):
+        cache1 = cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size)
+        cache2 = cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1)
+
         _cache = torch.cat(
             (
-                cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
-                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
+                cache1,
+                cache2,
             ),
             dim=0,
         )
-        concat_cache = torch.cat((_cache, source[i]), dim=0)
-        padding = seq_len - concat_cache.size(0)
+        padding = seq_len - _cache.size(0)
         if padding > 0:
-            concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
-        padded_cache.append(concat_cache)
-
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1))
+        padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)
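For orientation, a self-contained toy sketch (not taken from the patch) of the paged cache layout that copy_to_cache and convert_kvcache assume: blocks of shape [num_heads, head_size, block_size], filled block by block at prefill and read back into [seq_len, num_heads, head_size]. The shapes mirror the permutes above; all values are made up:

import torch

num_blocks, num_heads, head_size, block_size = 4, 2, 8, 16
cache = torch.zeros(num_blocks, num_heads, head_size, block_size)

seq_len = 20  # occupies one full block plus 4 slots of a second block
source = torch.randn(seq_len, num_heads, head_size)

# prefill-style write: full block first, then the partially filled last block
cache[0] = source[:block_size].permute(1, 2, 0)
cache[1, :, :, : seq_len - block_size] = source[block_size:].permute(1, 2, 0)

# read-back in the style of convert_kvcache: gather blocks, restore [seq_len, num_heads, head_size]
full_part = cache[[0]].permute(0, 3, 1, 2).reshape(-1, num_heads, head_size)
tail_part = cache[1, :, :, : seq_len - block_size].permute(2, 0, 1)
restored = torch.cat((full_part, tail_part), dim=0)
assert torch.allclose(restored, source)
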
diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py
index b4246d947..b17ced6e6 100644
--- a/colossalai/inference/modeling/models/llama.py
+++ b/colossalai/inference/modeling/models/llama.py
@@ -1,11 +1,22 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+import math
 from typing import List, Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+import torch.nn as nn
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaModel,
+    repeat_kv,
+)
 
+from colossalai.inference.modeling.layers.attention import convert_kvcache, copy_to_cache
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded
+
+from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
 
 
 def rotate_half(x):
@@ -27,24 +38,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed
 
 
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
+    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
@@ -52,6 +51,7 @@
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
+        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -62,13 +62,20 @@ def llama_model_forward(
     self: LlamaModel,
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
+    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
     sequence_lengths = batch.get_sequence_lengths()
-    # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
-    position_ids = generate_padding_position_id(input_ids)
+    attention_mask = batch.get_attn_mask(padding_id)
+
+    if batch.is_prompts:
+        # Here, position_ids are generated from the attention mask, which keeps the outputs aligned with the transformers implementation.
+        position_ids = generate_padding_position_id(attention_mask)
+    else:
+        position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
+
     hidden_states = self.embed_tokens(input_ids)
 
     for layer_id, decoder_layer in enumerate(self.layers):
@@ -80,6 +87,7 @@
             v_cache=v_caches[layer_id],
             is_prompts=batch.is_prompts,
             sequence_lengths=sequence_lengths,
+            attention_mask=attention_mask,
         )
 
     hidden_states = self.norm(hidden_states)
@@ -96,6 +104,7 @@ def llama_decoder_layer_forward(
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
     sequence_lengths: int = None,
+    attention_mask: torch.Tensor = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
 
@@ -109,6 +118,7 @@
         v_cache=v_cache,
         is_prompts=is_prompts,
         sequence_lengths=sequence_lengths,
+        attention_mask=attention_mask,
     )
 
     hidden_states = residual + hidden_states
@@ -132,6 +142,7 @@
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
     sequence_lengths: torch.Tensor = None,
+    attention_mask: torch.Tensor = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
@@ -139,9 +150,7 @@
     key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    kv_seq_len = key_states.shape[-2]
-    if not is_prompts:
-        kv_seq_len = kv_seq_len + sequence_lengths[0].item()
+    kv_seq_len = sequence_lengths[0].item()
 
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -153,20 +162,26 @@
     key_states = key_states.transpose(1, 2)
     value_states = value_states.transpose(1, 2)
 
-    query_states = query_states.view(-1, self.num_heads, self.head_dim)
-    key_states = key_states.view(-1, self.num_heads, self.head_dim)
-    value_states = value_states.view(-1, self.num_heads, self.head_dim)
-
-    _, _, _, block_size = k_cache.shape
-
-    # NOTE: context_attention_unpadded is used for testing accuracy and we can only use aligned inputs.
-    # The code below will be uncommented after the development of attention-related kernel is completed.
     if is_prompts:
-        attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        attn_output = pad_context_forward(
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
+        )
+    else:
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_output = pad_decoding_forward(
+            query_states,
+            key_states,
+            value_states,
+            k_cache,
+            v_cache,
+            sequence_lengths,
+            block_tables,
+            attention_mask,
+            self.layer_idx,
+            self.attention_dropout,
+            self.training,
         )
-    # else:
-    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
 
     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -175,13 +190,129 @@
     return attn_output
 
 
-def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
-    # Replace this code and use a more flexible method to obtain padding_id, avoiding directly setting padding_id like this.
-    padding_id = 2
-    attention_mask = input_ids.ne(padding_id).long()
+def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     return position_ids
 
-# def unpad_inputs(input_ids: torch.Tensor):
-
+
+def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
+    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
+    q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
+    return (q, k, v, indices, seqlens)
+
+
+def pad_decoding_forward(
+    query: torch.Tensor,  # [bsz, 1, num_heads, head_size]
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
+    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
+    attn_mask: torch.Tensor = None,
+    layer_id: int = 0,
+    attention_dropout: float = None,
+    training: bool = False,
+):
+    bsz, query_length, num_heads, head_size = query.shape
+    seq_len = max(lengths)
+
+    copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
+    copy_to_cache(value, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")
+
+    key = convert_kvcache(k_cache, lengths, block_tables)  # [bsz, seq_len, num_heads, head_size]
+    value = convert_kvcache(v_cache, lengths, block_tables)
+
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
+    if attn_weights.size() != (bsz, num_heads, 1, seq_len):
+        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
+
+    if attn_mask is not None:
+        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, query.dtype, query_length)
+
+    attn_mask = AttentionMaskConverter._make_causal_mask(
+        (bsz, query_length), query.dtype, query.device, past_key_values_length=seq_len - query_length
+    )
+
+    if padding_mask is not None:
+        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(query.dtype).min)
+
+    attn_weights += attn_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
+    attn_output = torch.matmul(attn_weights, value)
+
+    if attn_output.size() != (bsz, num_heads, 1, head_size):
+        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
+    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
+
+    return attn_output
+
+
+def pad_context_forward(
+    q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
+    k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
+    v: torch.Tensor,
+    k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    context_lengths: torch.Tensor,  # [num_seqs]
+    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
+    attn_mask: torch.Tensor = None,
+):
+    # First, do shape verification
+    bsz, seq_len, num_heads, head_size = q.shape
+    num_kv_heads = k.shape[-2]
+    assert num_heads % num_kv_heads == 0, "num_heads should be divisible by num_kv_heads"
+    num_kv_groups = num_heads // num_kv_heads
+    block_size = k_cache.shape[-1]
+    assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
+    block_tables.shape[-1] * block_size
+    shape = (bsz, seq_len, num_heads, head_size)
+    input_shape = shape[:2]
+
+    # Copy kv to memory (rotary embedded)
+    copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
+    copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)
+
+    q = q.transpose(1, 2)
+    k = repeat_kv(k.transpose(1, 2), num_kv_groups)
+    v = repeat_kv(v.transpose(1, 2), num_kv_groups)
+
+    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
+
+    if attn_mask is not None:
+        padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
+
+    attn_mask = AttentionMaskConverter._make_causal_mask(
+        (bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len
+    )
+
+    if padding_mask is not None:
+        attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min)
+
+    if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
+        raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
+    if attn_mask is not None:
+        attn_weights += attn_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weights, v)
+
+    if attn_output.size() != (bsz, num_heads, seq_len, head_size):
+        raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")
+
+    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
+
+    del attn_weights
+
+    return attn_output
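The two attention paths above add a causal mask plus a padding mask to the raw scores before softmax; the patch builds this with transformers' AttentionMaskConverter helpers. A plain-torch sketch of an equivalent additive mask, with toy sizes and a made-up left-padded batch:

import torch

dtype = torch.float32
bsz, num_heads, seq_len = 2, 1, 5
attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])  # 0 marks left padding

min_val = torch.finfo(dtype).min
causal = torch.full((seq_len, seq_len), min_val, dtype=dtype).triu(1)  # block future key positions
padded_keys = (attention_mask == 0)[:, None, None, :]  # [bsz, 1, 1, seq_len]
additive_mask = causal.expand(bsz, num_heads, seq_len, seq_len).masked_fill(padded_keys, min_val)

scores = torch.randn(bsz, num_heads, seq_len, seq_len) + additive_mask
probs = scores.softmax(dim=-1)  # ~zero weight on padded and future positions
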
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index ec0bb442f..ef07b7ff9 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -321,5 +321,13 @@ class BatchInfo:
 
         return torch.tensor(len_list, dtype=torch.int, device=self.device)
 
+    def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+        past_values = []
+
+        for seq in self.sequences_set:
+            past_values.append(seq.input_token_id + seq.output_token_id)
+
+        return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
+
     def __repr__(self) -> str:
         return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 5315c7811..5fab016e5 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, GenerationConfig
 import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 def setup_seed(seed):
@@ -24,21 +24,24 @@ def check_inference_engine(test_cai=False):
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = transformers.LlamaForCausalLM(
         transformers.LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()
 
     inputs = [
-        "介绍一下北京,",
+        "介绍一下今天的北京,",
         "介绍一下武汉,",
     ]
 
+    output_len = 16
+    do_sample = True
+
     if test_cai:
-        inference_config = InferenceConfig(max_output_len=1)
+        inference_config = InferenceConfig(max_output_len=output_len)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
        outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
@@ -46,7 +49,7 @@ def check_inference_engine(test_cai=False):
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
-            do_sample=True, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
+            do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
         )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -64,6 +67,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_inference_engine():
     spawn(run_dist, 1)
 
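To close, a toy illustration (not taken from the patch) of how the new BatchInfo.get_attn_mask interacts with generate_padding_position_id and the decoding-step position in llama.py; the token ids and pad id below are made up:

import torch

padding_id = 2  # stands in for tokenizer.pad_token_id, which the engine now passes down
past_values = [[2, 2, 31, 45, 99], [7, 11, 13, 17, 19]]  # per-sequence input + generated token ids
attention_mask = torch.tensor(past_values, dtype=torch.int).ne(padding_id).long()
# -> [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]

# prefill positions (generate_padding_position_id): cumsum over the mask, padded slots set to 1
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# -> [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]

# decoding position: the token being decoded sits at (#non-pad tokens - 1)
next_position = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
# -> [[2], [4]]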