diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index ec4044127..0a9b5293d 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co - POST '/chat': Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models. #### chat-template -Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported. +Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, so using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file style chat templates are supported. ### Usage #### Args for customizing your server The configuration for api server contains both serving interface and engine backend. diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 9cf9a65e6..c73ee9df4 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -10,6 +10,7 @@ import torch from transformers.generation import GenerationConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.utils import can_use_flash_attn2 GibiByte = 1024**3 @@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM): no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. - n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. + use_spec_dec (bool): Whether to use speculative decoding, defaults to False. + max_n_spec_tokens (int): The maximum number of speculative tokens, defaults to 5. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. tp_size (int): Tensor parallel size, defaults to 1.
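To make the chat-template paragraph above concrete, here is a minimal, generic sketch of how a Jinja chat template is attached to a HuggingFace tokenizer and applied. The template string and tokenizer repo are illustrative placeholders, not the project's bundled example template.

```python
from transformers import AutoTokenizer

# A deliberately simple Jinja template (illustrative only; real chat models ship their own).
SIMPLE_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant:"
)

# Any tokenizer works; this small public test tokenizer avoids gated downloads.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.chat_template = SIMPLE_TEMPLATE

messages = [{"role": "user", "content": "What is tensor parallelism?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)  # user: What is tensor parallelism?\nassistant:
```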
@@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM): ignore_eos: bool = False # speculative decoding configs + use_spec_dec: bool = False max_n_spec_tokens: int = 5 glimpse_large_kv: bool = False @@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM): return GenerationConfig.from_dict(meta_config) + def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig": + use_flash_attn = can_use_flash_attn2(self.dtype) + model_inference_config = ModelShardInferenceConfig( + dtype=self.dtype, + use_cuda_kernel=self.use_cuda_kernel, + use_spec_dec=self.use_spec_dec, + use_flash_attn=use_flash_attn, + ) + return model_inference_config + def to_rpc_param(self) -> dict: kwargs = { "dtype": str(self.dtype).split(".")[-1], @@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM): # Set the attributes from the parsed arguments. inference_config = cls(**inference_config_args) return inference_config + + +@dataclass +class ModelShardInferenceConfig: + """ + Configurations used during init of module for inference modeling. + + Args: + dtype (torch.dtype): The data type for weights and activations. + use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally + use_spec_dec (bool): Indicate whether to use speculative decoding. + use_flash_attn (bool): Indicate whether to use flash attention. + """ + + dtype: torch.dtype = None + use_cuda_kernel: bool = False + use_spec_dec: bool = False + use_flash_attn: bool = False diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 1b6e62553..d0d46d81b 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens @@ -72,8 +72,9 @@ class InferenceEngine: self.verbose = verbose self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() - self.init_model(model_or_path, model_policy) + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) self.generation_config = inference_config.to_generation_config(self.model_config) self.generation_config_dict = self.generation_config.to_dict() @@ -97,7 +98,8 @@ class InferenceEngine: self.capture_model(self.k_cache, self.v_cache) # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = False + self.use_spec_dec = self.inference_config.use_spec_dec + self.drafter_model = None self.drafter = None self.use_glide = False @@ -105,13 +107,20 @@ class InferenceEngine: self._verify_args() - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): """ Shard model or/and Load weight Args: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. 
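The new `to_model_shard_inference_config()` helper above distills the engine-level `InferenceConfig` into the per-module `ModelShardInferenceConfig`. A minimal usage sketch, relying on the defaults visible in this diff; whether `use_flash_attn` ends up True depends on the local flash-attn installation:

```python
import torch

from colossalai.inference.config import InferenceConfig

# Only the fields relevant to backend selection are set; everything else keeps its default.
inference_config = InferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=False)
shard_config = inference_config.to_model_shard_inference_config()

print(shard_config.dtype)            # torch.float16
print(shard_config.use_cuda_kernel)  # True
print(shard_config.use_spec_dec)     # False
print(shard_config.use_flash_attn)   # True only if flash-attn is installed (dtype is fp16 here)
```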
- model_policy (Policy): the policy to replace the model + model_policy (Policy): the policy to replace the model. + model_shard_infer_config (ModelShardInferenceConfig): the configuration used to initialize the modules for inference. """ if isinstance(model_or_path, str): @@ -124,6 +133,7 @@ class InferenceEngine: # the model load process in the future. model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) else: + # TODO(char-1ee): if the model is not supported, fall back to the transformers APIs to load and generate raise ValueError(f"Model {arch} is not supported.") except Exception as e: @@ -167,6 +177,7 @@ class InferenceEngine: self.model = self._shardformer( model, model_policy, + model_shard_infer_config, None, tp_group=tp_group, ) @@ -187,7 +198,7 @@ class InferenceEngine: # assert if_has_index_file, "the model path is invalid" # cpt_io.load_model(self.model, model_index_file) - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory, _ = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory if self.verbose: self.logger.info( @@ -287,6 +298,7 @@ class InferenceEngine: self, model: nn.Module, model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, stage_manager: PipelineStageManager = None, tp_group: ProcessGroupMesh = None, ) -> nn.Module: @@ -312,6 +324,7 @@ class InferenceEngine: enable_flash_attention=False, enable_jit_fused=False, enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, ) shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) @@ -348,6 +361,7 @@ class InferenceEngine: engine.clear_spec_dec() ``` """ + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: @@ -517,19 +531,19 @@ class InferenceEngine: prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, - ) -> List[str]: + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: - prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. request_ids (List[int], optional): The request ID. Defaults to None. - return_token_ids (bool): Whether to return output token ids. Defaults to False. - generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None. + prompts (List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. Returns: - List[str]: Inference result returned by one generation. + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
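The widened return annotation on `generate` corresponds to the `return_token_ids` flag; a hedged usage sketch (the checkpoint path is a placeholder and engine setup is reduced to the essentials):

```python
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model_path = "/path/to/llama-checkpoint"  # placeholder: any supported Llama-style checkpoint
engine = InferenceEngine(
    model_path,
    AutoTokenizer.from_pretrained(model_path),
    InferenceConfig(max_batch_size=1),
    verbose=True,
)

# Default: only the decoded strings come back.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing"])

# With return_token_ids=True the result is a (strings, token-id lists) tuple,
# which is what the widened Union[...] annotation describes.
outputs, token_ids = engine.generate(
    prompts=["Introduce some landmarks in Beijing"],
    return_token_ids=True,
)
```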
""" gen_config_dict = generation_config.to_dict() if generation_config is not None else {} diff --git a/colossalai/inference/modeling/backends/__init__.py b/colossalai/inference/modeling/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py new file mode 100644 index 000000000..e0a4ec33d --- /dev/null +++ b/colossalai/inference/modeling/backends/attention_backend.py @@ -0,0 +1,168 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch +from flash_attn import flash_attn_varlen_func + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention + + +@dataclass +class AttentionMetaData: + query_states: torch.Tensor + key_states: torch.Tensor + value_states: torch.Tensor + k_cache: torch.Tensor + v_cache: torch.Tensor + block_tables: torch.Tensor + block_size: int + kv_seq_len: int = None + sequence_lengths: torch.Tensor = None + cu_seqlens: torch.Tensor = None + sm_scale: int = None + alibi_slopes: torch.Tensor = None + output_tensor: torch.Tensor = None + use_spec_dec: bool = False + use_alibi_attn: bool = False + + +class AttentionBackend(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadatas: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found, + it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding. 
+ """ + + def __init__(self, use_flash_attn: bool): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + self.use_flash_attn = use_flash_attn + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if self.use_flash_attn: + token_nums = kwargs.get("token_nums", -1) + attn_output = flash_attn_varlen_func( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + cu_seqlens_q=attn_metadata.cu_seqlens, + cu_seqlens_k=attn_metadata.cu_seqlens, + max_seqlen_q=attn_metadata.kv_seq_len, + max_seqlen_k=attn_metadata.kv_seq_len, + dropout_p=0.0, + softmax_scale=attn_metadata.sm_scale, + causal=True, + alibi_slopes=attn_metadata.alibi_slopes, + ) + attn_output = attn_output.view(token_nums, -1) + else: + attn_output = context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + alibi_slopes=attn_metadata.alibi_slopes, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + use_new_kcache_layout=True, # use new k-cache layout + ) + return attn_output + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get("fd_inter_tensor", None) + output_tensor = attn_metadata.output_tensor + self.inference_ops.flash_decoding_attention( + output_tensor, + attn_metadata.query_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + attn_metadata.block_size, + attn_metadata.kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + attn_metadata.alibi_slopes, + attn_metadata.sm_scale, + ) + return output_tensor + + +class TritonAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding. 
+ """ + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + return context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + alibi_slopes=attn_metadata.alibi_slopes, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get("fd_inter_tensor", None) + return flash_decoding_attention( + q=attn_metadata.query_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + kv_seq_len=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + max_seq_len_in_batch=attn_metadata.kv_seq_len, + output=attn_metadata.output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + alibi_slopes=attn_metadata.alibi_slopes, + sm_scale=attn_metadata.sm_scale, + kv_group_num=kwargs.get("num_key_value_groups", 1), + q_len=kwargs.get("q_len", 1), + ) + + +def get_attention_backend( + model_shard_infer_config: ModelShardInferenceConfig, +) -> AttentionBackend: + """ + Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend + for attention module calculation only when: + 1. using CUDA kernel (use_cuda_kernel=True) + 2. can use flash attention (flash-attn installed and dtype is fp16 or bf16) + 3. not using speculative decoding (currently cuda kernel not support speculative decoding) + Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True, + the Triton backend will use a new k cache layout for Triton kernels. + """ + # Currently only triton kernels support speculative decoding + if model_shard_infer_config.use_spec_dec: + return TritonAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + return CudaAttentionBackend(model_shard_infer_config.use_flash_attn) + + return TritonAttentionBackend() diff --git a/colossalai/inference/modeling/backends/pre_attention_backend.py b/colossalai/inference/modeling/backends/pre_attention_backend.py new file mode 100644 index 000000000..77804429d --- /dev/null +++ b/colossalai/inference/modeling/backends/pre_attention_backend.py @@ -0,0 +1,146 @@ +from abc import ABC, abstractmethod + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding + + +class PreAttentionBackend(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaPreAttentionBackend(PreAttentionBackend): + """ + CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend. 
+ """ + + def __init__(self, use_flash_attn: bool): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + self.use_flash_attn = use_flash_attn + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if self.use_flash_attn: + if not attn_metadata.use_alibi_attn: + self.inference_ops.rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + kwargs.get("high_precision", False), + ) + self.inference_ops.context_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.cu_seqlens, + attn_metadata.block_tables, + attn_metadata.kv_seq_len, + ) + elif not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + self.inference_ops.rotary_embedding_and_cache_copy( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + kwargs.get("high_precision", None), + ) + else: + self.inference_ops.decode_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + ) + + +class TritonPreAttentionBackend(PreAttentionBackend): + """ + TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend. + """ + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn: + decoding_fused_rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.block_tables, + attn_metadata.sequence_lengths, + ) + else: # else if using speculative decoding + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + copy_k_to_blocked_cache( + attn_metadata.key_states, + attn_metadata.k_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get("q_len", 1), + ) + copy_k_to_blocked_cache( + attn_metadata.value_states, + attn_metadata.v_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get("q_len", 1), + ) + + +def get_pre_attention_backend( + model_shard_infer_config: ModelShardInferenceConfig, +) -> PreAttentionBackend: + """ + Get the backend for pre-attention computations, including potisional encoding like + RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend. 
+ """ + if model_shard_infer_config.use_spec_dec: + return TritonPreAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn) + + return TritonPreAttentionBackend() diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index b50e73d6f..f10ef6e3c 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -1,68 +1,27 @@ # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py import itertools -import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.distributed import ProcessGroup +from colossalai.inference.config import ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend +from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor -logger = get_dist_logger(__name__) - -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - -logger = get_dist_logger(__name__) - -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - inference_ops = InferenceOpsLoader().load() - logger = get_dist_logger(__name__) -# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 -def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) - slopes = torch.pow(base, powers) - if closest_power_of_2 != num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device - ) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - def baichuan_rmsnorm_forward( self, hidden_states: torch.Tensor, @@ -102,6 +61,7 @@ class 
NopadBaichuanAttention(ParallelModule): attn_oproj: ParallelModule = None, num_heads: int = None, hidden_size: int = None, + model_shard_infer_config: ModelShardInferenceConfig = None, process_group: ProcessGroup = None, helper_layout: Layout = None, ): @@ -126,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule): self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) self.helper_layout = helper_layout + self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) self.alibi_slopes = None self.use_alibi_attn = False @@ -155,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule): attn_kproj_w = k_proj_w attn_vproj_w = v_proj_w attn_oproj = module.o_proj + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) helper_layout = ( module.W_pack.weight.dist_layout @@ -166,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule): attn_kproj_w=attn_kproj_w, attn_vproj_w=attn_vproj_w, attn_oproj=attn_oproj, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, process_group=process_group, @@ -234,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len: int = 0, output_tensor: torch.Tensor = None, sm_scale: int = None, - use_cuda_kernel: bool = True, cu_seqlens: torch.Tensor = None, high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -253,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. - use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ @@ -267,121 +230,49 @@ class NopadBaichuanAttention(ParallelModule): block_size = k_cache.size(-2) - if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. 
- if not self.use_alibi_attn: - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ) - attn_output = attn_output.view(token_nums, -1) - else: - if not self.use_alibi_attn: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - alibi_slopes=self.alibi_slopes, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - use_new_kcache_layout=use_cuda_kernel, - ) - else: + attn_metadata = AttentionMetaData( + query_states=query_states, + key_states=key_states, + value_states=value_states, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + block_size=block_size, + kv_seq_len=kv_seq_len, + sequence_lengths=sequence_lengths, + sm_scale=sm_scale, + alibi_slopes=self.alibi_slopes, + cu_seqlens=cu_seqlens, + output_tensor=output_tensor, + use_spec_dec=is_verifier, + use_alibi_attn=self.use_alibi_attn, + ) + + if is_prompts: # prefilling stage + self.pre_attention_backend.prefill( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + high_precision=high_precision, + ) + attn_output = self.attention_backend.prefill( + attn_metadata, + token_nums=token_nums, + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - if use_cuda_kernel: - if not self.use_alibi_attn: - inference_ops.rotary_embedding_and_cache_copy( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - sequence_lengths, - block_tables, - high_precision, - ) - else: - inference_ops.decode_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables - ) - inference_ops.flash_decoding_attention( - output_tensor, - query_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - block_size, - kv_seq_len, - fd_inter_tensor.mid_output, - fd_inter_tensor.exp_sums, - fd_inter_tensor.max_logits, - self.alibi_slopes, - sm_scale, - ) - attn_output = output_tensor - else: - if not is_verifier and not self.use_alibi_attn: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) - else: - if not self.use_alibi_attn: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - copy_k_to_blocked_cache( - key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - copy_k_to_blocked_cache( - value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - alibi_slopes=self.alibi_slopes, - 
sm_scale=sm_scale, - q_len=q_len, - ) + self.pre_attention_backend.decode( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + q_len=q_len, + ) + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f6f160eb7..e274e7b7c 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, ) -from colossalai.inference.config import InputMetaData +from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend +from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend +from colossalai.inference.utils import can_use_flash_attn2 from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - get_xine_cache, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import get_xine_cache, rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor @@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - def llama_causal_lm_forward( self: LlamaForCausalLM, @@ -126,7 +113,7 @@ def llama_model_forward( cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) elif use_cuda_kernel: - if inputmetadata.dtype != torch.float32 and use_flash_attn2: + if can_use_flash_attn2(inputmetadata.dtype): cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) hidden_dim = self._cos_cached.size(-1) @@ -238,7 +225,6 @@ def llama_decoder_layer_forward( kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, - use_cuda_kernel=use_cuda_kernel, cu_seqlens=cu_seqlens, high_precision=high_precision, ) @@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w: torch.Tensor = None, attn_oproj: ParallelModule = None, process_group: ProcessGroup = None, + model_shard_infer_config: ModelShardInferenceConfig = None, num_heads: int = None, hidden_size: int = None, num_key_value_heads: int = None, @@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): self.rope_theta = config.rope_theta self.is_causal = True + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) + if self.num_heads == self.num_key_value_heads: qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] self.qkv_weight = 
nn.Parameter(torch.stack(qkv_weight_list, dim=0)) @@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w = module.v_proj.weight assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor" attn_oproj = module.o_proj + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) attn_layer = NopadLlamaAttention( config=config, @@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w=attn_vproj_w, attn_oproj=attn_oproj, process_group=process_group, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, num_key_value_heads=module.num_key_value_heads, @@ -533,111 +525,50 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): block_size = k_cache.size(-2) - if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) + attn_metadata = AttentionMetaData( + query_states=query_states, + key_states=key_states, + value_states=value_states, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + block_size=block_size, + kv_seq_len=kv_seq_len, + sequence_lengths=sequence_lengths, + sm_scale=sm_scale, + alibi_slopes=None, + cu_seqlens=cu_seqlens, + output_tensor=output_tensor, + use_spec_dec=is_verifier, + use_alibi_attn=False, + ) - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - ) - attn_output = attn_output.view(token_nums, -1) - else: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - use_new_kcache_layout=use_cuda_kernel, - ) - else: + if is_prompts: # prefilling stage + self.pre_attention_backend.prefill( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + high_precision=high_precision, + ) + attn_output = self.attention_backend.prefill( + attn_metadata, + token_nums=token_nums, + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - if use_cuda_kernel: - inference_ops.rotary_embedding_and_cache_copy( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - sequence_lengths, - block_tables, - high_precision, - ) - inference_ops.flash_decoding_attention( - output_tensor, - query_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - block_size, - kv_seq_len, - fd_inter_tensor.mid_output, - fd_inter_tensor.exp_sums, - fd_inter_tensor.max_logits, - None, - sm_scale, - ) - attn_output = output_tensor - else: - if is_verifier: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - copy_k_to_blocked_cache( - key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - copy_k_to_blocked_cache( - value_states, v_cache, kv_lengths=sequence_lengths, 
block_tables=block_tables, n=q_len - ) - else: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - kv_group_num=self.num_key_value_groups, - q_len=q_len, - ) + self.pre_attention_backend.decode( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + q_len=q_len, + ) + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, + num_key_value_groups=self.num_key_value_groups, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 78268d6e7..b28c2fce8 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadBaichuanAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 24cf7c740..0b6797560 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadLlamaAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 072bedec3..1374103a9 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ +import math import os import re from pathlib import Path @@ -9,8 +10,11 @@ from typing import Optional, Tuple import torch from torch import nn +from colossalai.logging import get_dist_logger from colossalai.testing import free_port +logger = get_dist_logger(__name__) + def init_to_get_rotary(self, base=10000, use_elem=False): """ @@ -113,3 +117,42 @@ def find_available_ports(num: int): print(f"An OS error occurred: {e}") raise RuntimeError("Error finding available ports") return free_ports + + +def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: + """ + Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 + + Args: + num_heads (int): The number of attention heads. + device (torch.device): The device to use. + + Returns: + torch.Tensor: The Alibi slopes. 
+ """ + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) + slopes = torch.pow(base, powers) + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +def can_use_flash_attn2(dtype: torch.dtype) -> bool: + """ + Check flash attention2 availability. + """ + if dtype not in (torch.float16, torch.bfloat16): + return False + + try: + return True + except ImportError: + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + return False diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py index 0bd398e2e..e9bf24d53 100644 --- a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py +++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask diff --git a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py index 9d76858ed..92173ac13 100644 --- a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py @@ -2,7 +2,7 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.kernel_utils import ( diff --git a/tests/test_infer/test_kernels/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py index 40a6eae58..aa2a7e2b4 100644 --- a/tests/test_infer/test_kernels/triton/test_decoding_attn.py +++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py @@ -3,7 +3,7 @@ import pytest import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_kernels.triton.kernel_utils import ( diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py index 736fab5ff..f24e1bb3f 100644 --- a/tests/test_infer/test_models/test_baichuan.py +++ b/tests/test_infer/test_models/test_baichuan.py @@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa 
assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template:
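As a closing sanity check for the helpers relocated into `colossalai/inference/utils.py` earlier in this diff: the slope values below are the standard ALiBi slopes for 8 heads, and the flash-attn check simply reports whether the optional dependency is importable for fp16/bf16 inputs.

```python
import torch

from colossalai.inference.utils import can_use_flash_attn2, get_alibi_slopes

slopes = get_alibi_slopes(8, device=torch.device("cpu"))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

print(can_use_flash_attn2(torch.float16))  # True only if flash-attn is installed
print(can_use_flash_attn2(torch.float32))  # Always False: flash-attn 2 needs fp16/bf16
```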