From b6696beb049b29af5d9bf43923d0cfd58511383b Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 1 Nov 2023 12:46:21 +0800 Subject: [PATCH] [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo --- colossalai/inference/__init__.py | 6 +- colossalai/inference/hybridengine/__init__.py | 3 + .../{pipeline => hybridengine}/engine.py | 98 ++++--- .../modeling/__init__.py | 0 .../modeling/_utils.py | 0 .../modeling/llama.py | 239 ++++++++++-------- .../polices}/__init__.py | 0 .../polices}/llama.py | 5 +- colossalai/inference/pipeline/__init__.py | 4 +- .../inference/pipeline/microbatch_manager.py | 17 +- .../tensor_parallel/modeling/llama.py | 56 ++-- tests/test_infer/test_pipeline_infer.py | 43 +++- 12 files changed, 268 insertions(+), 203 deletions(-) create mode 100644 colossalai/inference/hybridengine/__init__.py rename colossalai/inference/{pipeline => hybridengine}/engine.py (60%) rename colossalai/inference/{pipeline => hybridengine}/modeling/__init__.py (100%) rename colossalai/inference/{pipeline => hybridengine}/modeling/_utils.py (100%) rename colossalai/inference/{pipeline => hybridengine}/modeling/llama.py (74%) rename colossalai/inference/{pipeline/policies => hybridengine/polices}/__init__.py (100%) rename colossalai/inference/{pipeline/policies => hybridengine/polices}/llama.py (95%) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 761e48e59..d5a988cfc 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,4 +1,4 @@ -from .pipeline import PPInferEngine +from .hybridengine import CaiInferEngine +from .hybridengine.polices import LlamaModelInferPolicy - -__all__ = ['PPInferEngine'] +__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"] diff --git a/colossalai/inference/hybridengine/__init__.py b/colossalai/inference/hybridengine/__init__.py new file mode 100644 index 000000000..6377ef817 --- /dev/null +++ b/colossalai/inference/hybridengine/__init__.py @@ -0,0 +1,3 @@ +from .engine import CaiInferEngine + +__all__ = ["CaiInferEngine"] diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/hybridengine/engine.py similarity index 60% rename from colossalai/inference/pipeline/engine.py rename to colossalai/inference/hybridengine/engine.py index 480ac5dc7..bb0b4c77a 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/hybridengine/engine.py @@ -1,4 +1,5 @@ import torch +import torch.distributed as dist import torch.nn as nn from transformers.tokenization_utils_base import BatchEncoding @@ -8,23 +9,27 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from ..pipeline.microbatch_manager import MicroBatchManager from ..tensor_parallel.kvcache_manager import MemoryManager -from .microbatch_manager import MicroBatchManager + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = [ + "LlamaForCausalLM", +] -class PPInferEngine: +class CaiInferEngine: """ - PPInferEngine is a class that handles the pipeline parallel inference. + CaiInferEngine is a class that handles the pipeline parallel inference. Args: - pp_size (int): the number of pipeline stages. - pp_model (`nn.Module`): the model already in pipeline parallelism style. + tp_size (int): the size of tensor parallelism. + pp_size (int): the size of pipeline parallelism. model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - new_length (int): the new length of the input sequence. - early_stopping (bool): whether to stop early. max_batch_size (int): the maximum batch size. max_input_len (int): the maximum input length. max_output_len (int): the maximum output length. @@ -32,7 +37,7 @@ class PPInferEngine: Example: ```python - from colossalai.inference import PPInferEngine + from colossalai.inference import InferEngine from colossalai.inference.pipeline.policies import LlamaModelInferPolicy import colossalai from transformers import LlamaForCausalLM, LlamaTokenizer @@ -42,7 +47,7 @@ class PPInferEngine: model = LlamaForCausalLM.from_pretrained("your_path_to_model") tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") # assume the model is infered with 2 pipeline stages - inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy()) input = ["Introduce a landmark in China ","Introduce a landmark in China "] data = tokenizer(input, return_tensors='pt') @@ -54,12 +59,11 @@ class PPInferEngine: def __init__( self, - pp_size: int, + tp_size: int = 1, + pp_size: int = 1, dtype: str = "fp16", - pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, - new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, max_batch_size: int = 4, @@ -71,12 +75,21 @@ class PPInferEngine: do_sample: bool = False, num_beams: int = 1, ) -> None: - assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." + assert ( + tp_size * pp_size == dist.get_world_size() + ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" + assert model and model_policy, "Model with model_policy should be provided." assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - max_output_len = max(max_output_len, max_input_len + new_length) + assert max_batch_size <= 64, "Max batch size exceeds the constraint" + assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" + # TODO: support only tensor parallel inference + assert pp_size > 1, "Not support only tensor parallel inference." self.pp_size = pp_size + self.tp_size = tp_size + if dtype == "fp16": self.dtype = torch.float16 model.half() @@ -85,24 +98,29 @@ class PPInferEngine: model.to(torch.bfloat16) else: self.dtype = torch.float32 - self.pg_mesh = ProcessGroupMesh(pp_size) - self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.model = pp_model or self._shardformer(model, model_policy) - self.cache_manager_list = [ - self._init_manager(max_batch_size, max_input_len, max_output_len) - for _ in range(micro_batch_buffer_size or pp_size) - ] - self.mb_manager = MicroBatchManager( - self.stage_manager.stage, - new_length, - micro_batch_size, - micro_batch_buffer_size or pp_size, - max_input_len, - max_output_len, - self.cache_manager_list, - ) - self.verbose = verbose - self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) + + # Init pg mesh + pg_mesh = ProcessGroupMesh(pp_size, tp_size) + + stage_manager = None + if pp_size > 1: + stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True) + self.cache_manager_list = [ + self._init_manager(model, max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] + self.mb_manager = MicroBatchManager( + stage_manager.stage, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) + self.verbose = verbose + self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) + + self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS)) def inference(self, input_list): """ @@ -124,10 +142,10 @@ class PPInferEngine: else: return out - def _shardformer(self, model, model_policy): + def _shardformer(self, model, model_policy, stage_manager, tp_group): shardconfig = ShardConfig( - tensor_parallel_process_group=None, - pipeline_stage_manager=self.stage_manager, + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False, enable_fused_normalization=False, enable_all_optimization=False, @@ -139,14 +157,12 @@ class PPInferEngine: shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() - def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: max_total_token_num = max_batch_size * (max_input_len + max_output_len) - head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - head_num = self.model.config.num_attention_heads + head_dim = model.config.hidden_size // model.config.num_attention_heads + head_num = model.config.num_attention_heads num_hidden_layers = ( - self.model.config.num_hidden_layers - if hasattr(self.model.config, "num_hidden_layers") - else self.model.config.num_layers + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers ) layer_num = num_hidden_layers // self.pp_size diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/hybridengine/modeling/__init__.py similarity index 100% rename from colossalai/inference/pipeline/modeling/__init__.py rename to colossalai/inference/hybridengine/modeling/__init__.py diff --git a/colossalai/inference/pipeline/modeling/_utils.py b/colossalai/inference/hybridengine/modeling/_utils.py similarity index 100% rename from colossalai/inference/pipeline/modeling/_utils.py rename to colossalai/inference/hybridengine/modeling/_utils.py diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/hybridengine/modeling/llama.py similarity index 74% rename from colossalai/inference/pipeline/modeling/llama.py rename to colossalai/inference/hybridengine/modeling/llama.py index 9c72b02cc..34474d115 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/hybridengine/modeling/llama.py @@ -1,37 +1,25 @@ # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +import math from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import copy_kv_to_mem_cache try: - from vllm import layernorm_ops, pos_encoding_ops - - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, ) - HAS_VLLM_KERNERL = False - -try: from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -39,6 +27,14 @@ except: print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") HAS_LIGHTLLM_KERNEL = False +try: + from flash_attn import flash_attn_with_kvcache + + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -59,6 +55,75 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + class LlamaInferenceForwards: """ This class holds forwards for llama inference. @@ -144,13 +209,9 @@ class LlamaInferenceForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ): - # batch_size = input_ids.shape[0] # input_ids.shape[0] - # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}") - - # infer_state = self.infer_state - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if stage_manager is None or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -172,12 +233,10 @@ class LlamaInferenceForwards: batch_size, seq_length = input_shape device = hidden_states.device - seq_length_with_past = seq_length - past_key_values_length = 0 - - if infer_state.is_context_stage is False: - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -197,26 +256,19 @@ class LlamaInferenceForwards: infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0) - new_shape = [1] * position_ids.dim() - new_shape[0] = batch_size - position_ids = position_ids.repeat(*new_shape).view(-1, seq_length) + position_ids = position_ids.repeat(batch_size, 1) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -227,15 +279,17 @@ class LlamaInferenceForwards: infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -243,10 +297,6 @@ class LlamaInferenceForwards: ) # decoder layers - () if output_hidden_states else None - () if output_attentions else None - next_decoder_cache = () if use_cache else None - infer_state.decode_layer_id = 0 start_idx, end_idx = stage_index[0], stage_index[1] @@ -268,19 +318,15 @@ class LlamaInferenceForwards: infer_state.decode_layer_id += 1 hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage() or stage_manager.num_stages == 1: hidden_states = self.norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 - # TODO: fix this to necessary return # if not return_dict: # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -290,8 +336,7 @@ class LlamaInferenceForwards: # hidden_states=all_hidden_states, # attentions=all_self_attns, # ) - # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") - return {"hidden_states": hidden_states, "past_key_values": next_cache} + return {"hidden_states": hidden_states} @staticmethod def llama_decoder_layer_forward( @@ -307,7 +352,6 @@ class LlamaInferenceForwards: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -357,28 +401,24 @@ class LlamaInferenceForwards: # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) if infer_state.is_context_stage: - # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -387,19 +427,16 @@ class LlamaInferenceForwards: infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( + llama_triton_context_attention( query_states, key_states, value_states, attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state, + num_key_value_groups=self.num_key_value_groups, ) - else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -422,45 +459,31 @@ class LlamaInferenceForwards: infer_state.cache_manager, ) - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) + else: + self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) - # print(f"rank:{torch.distributed.get_rank()}, {attn_output}") + attn_output = self.o_proj(attn_output) # return past_key_value as None return attn_output, None, None - - -def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None diff --git a/colossalai/inference/pipeline/policies/__init__.py b/colossalai/inference/hybridengine/polices/__init__.py similarity index 100% rename from colossalai/inference/pipeline/policies/__init__.py rename to colossalai/inference/hybridengine/polices/__init__.py diff --git a/colossalai/inference/pipeline/policies/llama.py b/colossalai/inference/hybridengine/polices/llama.py similarity index 95% rename from colossalai/inference/pipeline/policies/llama.py rename to colossalai/inference/hybridengine/polices/llama.py index 9f8c93c61..992299714 100644 --- a/colossalai/inference/pipeline/policies/llama.py +++ b/colossalai/inference/hybridengine/polices/llama.py @@ -17,7 +17,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards try: from colossalai.kernel.triton import rmsnorm_forward @@ -120,9 +120,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() - else: - # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 - infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {"forward": partial(infer_forward)} diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index 41af9f3ef..f43e4a847 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ -from .engine import PPInferEngine +from .microbatch_manager import MicroBatchManager -__all__ = ["PPInferEngine"] +__all__ = ["MicroBatchManager"] diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 2bf52161d..441cf6039 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -33,10 +33,9 @@ class MicroBatchDescription: max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: self.mb_length = inputs_dict["input_ids"].shape[-1] - self.target_length = self.mb_length + new_length + self.target_length = self.mb_length + max_output_len self.infer_state = BatchInferState.init_from_batch( batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager ) @@ -77,7 +76,6 @@ class HeadMicroBatchDescription(MicroBatchDescription): Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. - new_length (int): the new length of the input sequence. """ @@ -87,9 +85,8 @@ class HeadMicroBatchDescription(MicroBatchDescription): max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None self.input_ids = inputs_dict["input_ids"] @@ -139,9 +136,8 @@ class BodyMicroBatchDescription(MicroBatchDescription): max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) @property def cur_length(self): @@ -158,7 +154,6 @@ class MicroBatchManager: Args: stage (int): stage id of current stage. - new_length (int): the new length of the input sequence. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. @@ -167,7 +162,6 @@ class MicroBatchManager: def __init__( self, stage: int, - new_length: int, micro_batch_size: int, micro_batch_buffer_size: int, max_input_len: int, @@ -175,7 +169,6 @@ class MicroBatchManager: cache_manager_list: MemoryManager, ): self.stage = stage - self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size self.max_input_len = max_input_len @@ -188,11 +181,11 @@ class MicroBatchManager: def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 8573bb965..62c2aad3c 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,10 +1,9 @@ -from typing import List, Optional, Tuple import math -import copy +from typing import List, Optional, Tuple import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd @@ -16,7 +15,9 @@ try: from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, ) - from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, + ) from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -26,6 +27,7 @@ except: try: from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True except: HAS_FLASH_KERNEL = False @@ -50,7 +52,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed -def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1): + +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): if num_key_value_groups == 1: if HAS_LIGHTLLM_KERNEL is False: llama_context_attn_fwd( @@ -87,6 +92,7 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ infer_state.max_len_in_batch, ) + def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" if num_key_value_groups == 1: @@ -265,8 +271,7 @@ class LlamaInferenceForwards: hidden_states=all_hidden_states, attentions=all_self_attns, ) - - + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -309,7 +314,6 @@ class LlamaInferenceForwards: outputs += (present_key_value,) return outputs - @staticmethod def llama_flash_attn_kvcache_forward( @@ -358,8 +362,15 @@ class LlamaInferenceForwards: infer_state.cache_manager, ) attn_output = torch.empty_like(query_states) - - llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + + llama_triton_context_attention( + query_states, + key_states, + value_states, + attn_output, + infer_state, + num_key_value_groups=self.num_key_value_groups, + ) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -381,26 +392,28 @@ class LlamaInferenceForwards: infer_state.decode_mem_index, infer_state.cache_manager, ) - - HAS_LIGHTLLM_KERNEL = False + if HAS_LIGHTLLM_KERNEL: attn_output = torch.empty_like(query_states) - llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) else: - heads_per_group = self.num_heads // self.num_key_value_heads + self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - - attn_output = flash_attn_with_kvcache(q = query_states, - k_cache = copy_cache_k, - v_cache = copy_cache_v, - softmax_scale = 1/ math.sqrt(self.head_dim), - causal = True) + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -408,4 +421,3 @@ class LlamaInferenceForwards: # return past_key_value as None return attn_output, None, None - diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 6d02f2b32..3544153da 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -5,8 +5,7 @@ import transformers from packaging import version import colossalai -from colossalai.inference.pipeline import PPInferEngine -from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") @@ -26,27 +25,43 @@ for k, v in inputs.items(): inputs[k] = v.to("cuda").repeat(*new_shape) -def pipeline_inference_test(pp_size, new_length, micro_batch_size): - model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) +def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + model = transformers.LlamaForCausalLM( + transformers.LlamaConfig( + vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 + ) + ) - engine = PPInferEngine( + engine = CaiInferEngine( + tp_size=tp_size, pp_size=pp_size, model=model, model_policy=LlamaModelInferPolicy(), - new_length=new_length, + max_output_len=max_output_len, micro_batch_size=micro_batch_size, ) output = engine.inference(inputs) if dist.get_rank() == 0: - assert len(output[0]) == new_length, f"{len(output)}, {new_length}" + assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" +@parameterize("tp_size", [1]) @parameterize("pp_size", [2]) -@parameterize("new_length", [4, 8, 16]) -@parameterize("micro_batch_size", [1, 4]) +@parameterize("max_output_len", [4]) +@parameterize("micro_batch_size", [1]) @clear_cache_before_run() -def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): - pipeline_inference_test(pp_size, new_length, micro_batch_size) +def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) + torch.cuda.empty_cache() + + +@parameterize("tp_size", [2]) +@parameterize("pp_size", [2]) +@parameterize("max_output_len", [4]) +@parameterize("micro_batch_size", [1]) +@clear_cache_before_run() +def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) torch.cuda.empty_cache() @@ -55,12 +70,18 @@ def check_pipeline_inference(rank, world_size, port): run_pipeline_inference_test() +def check_tp_pipeline_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_tp_pipeline_inference_test() + + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=2) + spawn(check_tp_pipeline_inference, nprocs=4) if __name__ == "__main__":