Merge pull request #5771 from char-1ee/refactor/modeling

[Inference] Refactor modeling attention layer by abstracting attention backends
pull/5791/head
Li Xingjian 2024-06-10 11:52:22 +08:00 committed by GitHub
commit 77a219a082
15 changed files with 531 additions and 301 deletions

View File

@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
The chat API is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding messages. Since such inputs differ greatly from plain prompts, we introduce a chat template to match the data format expected by chat models.
#### chat-template
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, so using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers as a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file chat templates are supported.
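For illustration, a chat template can be passed either as a raw string or via a file. The template below is an assumed minimal sketch for demonstration only, not the template shipped with any particular chat model.

```python
# A minimal sketch of a Jinja chat template supplied as a plain string.
# The role/content layout is an assumption for illustration; real chat models
# usually require the exact template they were trained with.
SIMPLE_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant: "
)

# File-style usage: write the template out and point the chat-template
# argument at this file instead of passing the raw string.
with open("chat_template.jinja", "w") as f:
    f.write(SIMPLE_CHAT_TEMPLATE)
```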
### Usage
#### Args for customizing your server
The configuration for the API server covers both the serving interface and the engine backend.

View File

@ -10,6 +10,7 @@ import torch
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2

GibiByte = 1024**3
@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, consecutive n-grams of that size can only appear once in the generated text.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition. Defaults to 1.0.
ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens after encountering it.
use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
max_n_spec_tokens (int): The maximum number of speculated tokens, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in the drafter model, defaults to False.
block_size (int): The number of tokens in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
ignore_eos: bool = False

# speculative decoding configs
use_spec_dec: bool = False
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM):
return GenerationConfig.from_dict(meta_config)
def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
    use_flash_attn = can_use_flash_attn2(self.dtype)
    model_inference_config = ModelShardInferenceConfig(
        dtype=self.dtype,
        use_cuda_kernel=self.use_cuda_kernel,
        use_spec_dec=self.use_spec_dec,
        use_flash_attn=use_flash_attn,
    )
    return model_inference_config
def to_rpc_param(self) -> dict:
    kwargs = {
        "dtype": str(self.dtype).split(".")[-1],
@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM):
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config
@dataclass
class ModelShardInferenceConfig:
    """
    Configurations used during init of module for inference modeling.

    Args:
        dtype (torch.dtype): The data type for weights and activations.
        use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
        use_spec_dec (bool): Indicate whether to use speculative decoding.
        use_flash_attn (bool): Indicate whether to use flash attention.
    """

    dtype: torch.dtype = None
    use_cuda_kernel: bool = False
    use_spec_dec: bool = False
    use_flash_attn: bool = False
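As a quick sketch of how the two configs relate (assuming default values for every field not shown), the shard-level config is derived from the top-level `InferenceConfig` rather than built by hand:

```python
import torch

from colossalai.inference.config import InferenceConfig

# Top-level engine config; fields not listed keep their defaults.
inference_config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=False,
)

# Per-module config handed to the attention backends during sharding.
# use_flash_attn is resolved via can_use_flash_attn2(dtype), so it reflects both
# the dtype and whether flash-attn is installed in the current environment.
model_shard_infer_config = inference_config.to_model_shard_inference_config()
print(model_shard_infer_config.use_cuda_kernel, model_shard_infer_config.use_flash_attn)
```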

View File

@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@ -72,8 +72,9 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)

self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@ -97,7 +98,8 @@ class InferenceEngine:
self.capture_model(self.k_cache, self.v_cache)

# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec

self.drafter_model = None
self.drafter = None
self.use_glide = False
@ -105,13 +107,20 @@ class InferenceEngine:
self._verify_args()

def init_model(
    self,
    model_or_path: Union[nn.Module, str],
    model_policy: Union[Policy, Type[Policy]] = None,
    model_shard_infer_config: ModelShardInferenceConfig = None,
):
""" """
Shard model or/and Load weight Shard model or/and Load weight
Args: Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
""" """
    if isinstance(model_or_path, str):
@ -124,6 +133,7 @@ class InferenceEngine:
        # the model load process in the future.
        model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
    else:
        # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
        raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
@ -167,6 +177,7 @@ class InferenceEngine:
self.model = self._shardformer(
    model,
    model_policy,
    model_shard_infer_config,
    None,
    tp_group=tp_group,
)
@ -187,7 +198,7 @@ class InferenceEngine:
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)

free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
    self.logger.info(
@ -287,6 +298,7 @@ class InferenceEngine:
    self,
    model: nn.Module,
    model_policy: Policy,
    model_shard_infer_config: ModelShardInferenceConfig = None,
    stage_manager: PipelineStageManager = None,
    tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@ -312,6 +324,7 @@ class InferenceEngine:
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@ -348,6 +361,7 @@ class InferenceEngine:
    engine.clear_spec_dec()
    ```
"""
if drafter_model is None and self.drafter is None:
    raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
@ -517,19 +531,19 @@ class InferenceEngine:
    prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
    return_token_ids: bool = False,
    generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
    """
    Executing the inference step.

    Args:
        request_ids (List[int], optional): The request ID. Defaults to None.
        prompts (Union[List[str]], optional): Input prompts. Defaults to None.
        prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
        return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
        generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

    Returns:
        Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
    """
    gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
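As a usage sketch of the updated signature (assuming `engine` is an already-initialized `InferenceEngine`; construction is elided here), the return type depends on `return_token_ids`:

```python
from transformers import GenerationConfig

# With return_token_ids=True the call returns a (texts, token_ids) tuple, matching
# the Union[List[str], Tuple[List[str], List[List[int]]]] annotation above.
texts, token_ids = engine.generate(
    prompts=["An example prompt."],
    return_token_ids=True,
    generation_config=GenerationConfig(max_new_tokens=32),
)

# With the default return_token_ids=False, only the generated strings come back.
texts = engine.generate(prompts=["Another prompt."])
```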

View File

@ -0,0 +1,168 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
class AttentionMetaData:
    query_states: torch.Tensor
    key_states: torch.Tensor
    value_states: torch.Tensor
    k_cache: torch.Tensor
    v_cache: torch.Tensor
    block_tables: torch.Tensor
    block_size: int
    kv_seq_len: int = None
    sequence_lengths: torch.Tensor = None
    cu_seqlens: torch.Tensor = None
    sm_scale: int = None
    alibi_slopes: torch.Tensor = None
    output_tensor: torch.Tensor = None
    use_spec_dec: bool = False
    use_alibi_attn: bool = False


class AttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
    """
    Attention backend when use_cuda_kernel is True. If flash-attn is not found,
    it uses the Triton op `context_attention_unpadded` for prefilling and our CUDA op `flash_decoding_attention` for decoding.
    """

    def __init__(self, use_flash_attn: bool):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            token_nums = kwargs.get("token_nums", -1)
            attn_output = flash_attn_varlen_func(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                cu_seqlens_q=attn_metadata.cu_seqlens,
                cu_seqlens_k=attn_metadata.cu_seqlens,
                max_seqlen_q=attn_metadata.kv_seq_len,
                max_seqlen_k=attn_metadata.kv_seq_len,
                dropout_p=0.0,
                softmax_scale=attn_metadata.sm_scale,
                causal=True,
                alibi_slopes=attn_metadata.alibi_slopes,
            )
            attn_output = attn_output.view(token_nums, -1)
        else:
            attn_output = context_attention_unpadded(
                q=attn_metadata.query_states,
                k=attn_metadata.key_states,
                v=attn_metadata.value_states,
                k_cache=attn_metadata.k_cache,
                v_cache=attn_metadata.v_cache,
                context_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                block_size=attn_metadata.block_size,
                output=attn_metadata.output_tensor,
                alibi_slopes=attn_metadata.alibi_slopes,
                max_seq_len=attn_metadata.kv_seq_len,
                sm_scale=attn_metadata.sm_scale,
                use_new_kcache_layout=True,  # use new k-cache layout
            )
        return attn_output

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        output_tensor = attn_metadata.output_tensor
        self.inference_ops.flash_decoding_attention(
            output_tensor,
            attn_metadata.query_states,
            attn_metadata.k_cache,
            attn_metadata.v_cache,
            attn_metadata.sequence_lengths,
            attn_metadata.block_tables,
            attn_metadata.block_size,
            attn_metadata.kv_seq_len,
            fd_inter_tensor.mid_output,
            fd_inter_tensor.exp_sums,
            fd_inter_tensor.max_logits,
            attn_metadata.alibi_slopes,
            attn_metadata.sm_scale,
        )
        return output_tensor


class TritonAttentionBackend(AttentionBackend):
    """
    Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        return context_attention_unpadded(
            q=attn_metadata.query_states,
            k=attn_metadata.key_states,
            v=attn_metadata.value_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            context_lengths=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            output=attn_metadata.output_tensor,
            alibi_slopes=attn_metadata.alibi_slopes,
            max_seq_len=attn_metadata.kv_seq_len,
            sm_scale=attn_metadata.sm_scale,
        )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
        return flash_decoding_attention(
            q=attn_metadata.query_states,
            k_cache=attn_metadata.k_cache,
            v_cache=attn_metadata.v_cache,
            kv_seq_len=attn_metadata.sequence_lengths,
            block_tables=attn_metadata.block_tables,
            block_size=attn_metadata.block_size,
            max_seq_len_in_batch=attn_metadata.kv_seq_len,
            output=attn_metadata.output_tensor,
            mid_output=fd_inter_tensor.mid_output,
            mid_output_lse=fd_inter_tensor.mid_output_lse,
            alibi_slopes=attn_metadata.alibi_slopes,
            sm_scale=attn_metadata.sm_scale,
            kv_group_num=kwargs.get("num_key_value_groups", 1),
            q_len=kwargs.get("q_len", 1),
        )


def get_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
    """
    Get the attention backend based on the inference configurations. The modeling will use the CUDA-kernel-based backend
    for attention module calculation only when:
        1. using CUDA kernel (use_cuda_kernel=True)
        2. flash attention is usable (flash-attn installed and dtype is fp16 or bf16)
        3. not using speculative decoding (currently the CUDA kernel does not support speculative decoding)
    Otherwise, use the Triton attention backend. If flash-attn is not installed while `use_cuda_kernel` is True,
    the CUDA backend falls back to the Triton prefill kernel with the new k-cache layout.
    """
    # Currently only triton kernels support speculative decoding
    if model_shard_infer_config.use_spec_dec:
        return TritonAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonAttentionBackend()
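For orientation, a hedged sketch of how a modeling layer is expected to drive this module during prefill. The query/key/value tensors, caches, block tables and related values below are placeholders assumed to be prepared by the caller (the refactored attention modules further down do exactly that), and the kernel calls require CUDA tensors with the compiled ops available:

```python
import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import (
    AttentionMetaData,
    get_attention_backend,
)

# Selection follows the rules in the docstring above: speculative decoding forces the
# Triton backend, otherwise use_cuda_kernel decides between the CUDA and Triton backends.
shard_config = ModelShardInferenceConfig(dtype=torch.float16, use_cuda_kernel=True, use_flash_attn=True)
backend = get_attention_backend(shard_config)

# The caller packs its paged-KV-cache state into AttentionMetaData and runs one step.
# All tensor names below are placeholders for values the attention module already holds.
attn_metadata = AttentionMetaData(
    query_states=query_states,
    key_states=key_states,
    value_states=value_states,
    k_cache=k_cache,
    v_cache=v_cache,
    block_tables=block_tables,
    block_size=16,
    kv_seq_len=kv_seq_len,
    sequence_lengths=sequence_lengths,
    cu_seqlens=cu_seqlens,
    sm_scale=head_dim**-0.5,
    output_tensor=output_tensor,
)
attn_output = backend.prefill(attn_metadata, token_nums=query_states.size(0))
```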

View File

@ -0,0 +1,146 @@
from abc import ABC, abstractmethod

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding


class PreAttentionBackend(ABC):
    @abstractmethod
    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        raise NotImplementedError


class CudaPreAttentionBackend(PreAttentionBackend):
    """
    CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
    """

    def __init__(self, use_flash_attn: bool):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            if not attn_metadata.use_alibi_attn:
                self.inference_ops.rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                    kwargs.get("high_precision", False),
                )
            self.inference_ops.context_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.cu_seqlens,
                attn_metadata.block_tables,
                attn_metadata.kv_seq_len,
            )
        elif not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            self.inference_ops.rotary_embedding_and_cache_copy(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
                kwargs.get("high_precision", None),
            )
        else:
            self.inference_ops.decode_kv_cache_memcpy(
                attn_metadata.key_states,
                attn_metadata.value_states,
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.sequence_lengths,
                attn_metadata.block_tables,
            )


class TritonPreAttentionBackend(PreAttentionBackend):
    """
    TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
    """

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_alibi_attn:
            rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
            )

    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
        if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
            decoding_fused_rotary_embedding(
                attn_metadata.query_states,
                attn_metadata.key_states,
                attn_metadata.value_states,
                kwargs.get("cos", None),
                kwargs.get("sin", None),
                attn_metadata.k_cache,
                attn_metadata.v_cache,
                attn_metadata.block_tables,
                attn_metadata.sequence_lengths,
            )
        else:  # else if using speculative decoding
            if not attn_metadata.use_alibi_attn:
                rotary_embedding(
                    attn_metadata.query_states,
                    attn_metadata.key_states,
                    kwargs.get("cos", None),
                    kwargs.get("sin", None),
                )
            copy_k_to_blocked_cache(
                attn_metadata.key_states,
                attn_metadata.k_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )
            copy_k_to_blocked_cache(
                attn_metadata.value_states,
                attn_metadata.v_cache,
                kv_lengths=attn_metadata.sequence_lengths,
                block_tables=attn_metadata.block_tables,
                n=kwargs.get("q_len", 1),
            )


def get_pre_attention_backend(
    model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
    """
    Get the backend for pre-attention computations, including positional encoding like
    RoPE and KV cache initialization. It adopts the same selection logic as attention_backend/get_attention_backend.
    """
    if model_shard_infer_config.use_spec_dec:
        return TritonPreAttentionBackend()

    if model_shard_infer_config.use_cuda_kernel:
        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)

    return TritonPreAttentionBackend()
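To make the division of labor concrete, a short sketch of one decode step pairing the two backends, mirroring how the refactored attention modules below call them; `shard_config`, `attn_metadata`, `cos`/`sin` and `fd_inter_tensor` are assumed to be prepared by the caller:

```python
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend

# Both selectors receive the same ModelShardInferenceConfig, so the pre-attention
# backend (RoPE + KV cache writes) always matches the chosen attention backend.
pre_attention_backend = get_pre_attention_backend(shard_config)
attention_backend = get_attention_backend(shard_config)

# One decode step: rotate q/k and append this step's KV to the paged cache,
# then run paged flash-decoding attention over the cache.
pre_attention_backend.decode(attn_metadata, cos=cos, sin=sin, q_len=1)
attn_output = attention_backend.decode(attn_metadata, fd_inter_tensor=fd_inter_tensor, q_len=1)
```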

View File

@ -1,68 +1,27 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


def baichuan_rmsnorm_forward(
    self,
    hidden_states: torch.Tensor,
@ -102,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule):
    attn_oproj: ParallelModule = None,
    num_heads: int = None,
    hidden_size: int = None,
    model_shard_infer_config: ModelShardInferenceConfig = None,
    process_group: ProcessGroup = None,
    helper_layout: Layout = None,
):
@ -126,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule):
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = helper_layout

self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

self.alibi_slopes = None
self.use_alibi_attn = False
@ -155,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

helper_layout = (
    module.W_pack.weight.dist_layout
@ -166,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule):
    attn_kproj_w=attn_kproj_w,
    attn_vproj_w=attn_vproj_w,
    attn_oproj=attn_oproj,
    model_shard_infer_config=model_shard_infer_config,
    num_heads=module.num_heads,
    hidden_size=module.hidden_size,
    process_group=process_group,
@ -234,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule):
    kv_seq_len: int = 0,
    output_tensor: torch.Tensor = None,
    sm_scale: int = None,
    cu_seqlens: torch.Tensor = None,
    high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@ -253,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule):
    kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
    output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
    sm_scale (int, optional): Used for flash attention. Defaults to None.
    cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence length.
    high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
@ -267,121 +230,49 @@ class NopadBaichuanAttention(ParallelModule):
block_size = k_cache.size(-2)

attn_metadata = AttentionMetaData(
    query_states=query_states,
    key_states=key_states,
    value_states=value_states,
    k_cache=k_cache,
    v_cache=v_cache,
    block_tables=block_tables,
    block_size=block_size,
    kv_seq_len=kv_seq_len,
    sequence_lengths=sequence_lengths,
    sm_scale=sm_scale,
    alibi_slopes=self.alibi_slopes,
    cu_seqlens=cu_seqlens,
    output_tensor=output_tensor,
    use_spec_dec=is_verifier,
    use_alibi_attn=self.use_alibi_attn,
)

if is_prompts:  # prefilling stage
    self.pre_attention_backend.prefill(
        attn_metadata,
        cos=cos_sin[0],
        sin=cos_sin[1],
        high_precision=high_precision,
    )
    attn_output = self.attention_backend.prefill(
        attn_metadata,
        token_nums=token_nums,
    )
else:  # decoding stage
    q_len = tokens_to_verify + 1 if is_verifier else 1

    self.pre_attention_backend.decode(
        attn_metadata,
        cos=cos_sin[0],
        sin=cos_sin[1],
        q_len=q_len,
    )
    attn_output = self.attention_backend.decode(
        attn_metadata,
        fd_inter_tensor=fd_inter_tensor,
        q_len=q_len,
    )

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)

View File

@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
)

from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)


def llama_causal_lm_forward(
    self: LlamaForCausalLM,
@ -126,7 +113,7 @@ def llama_model_forward(
    cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
    if can_use_flash_attn2(inputmetadata.dtype):
        cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))

    hidden_dim = self._cos_cached.size(-1)
@ -238,7 +225,6 @@ def llama_decoder_layer_forward(
    kv_seq_len=kv_seq_len,
    output_tensor=output_tensor,
    sm_scale=sm_scale,
    cu_seqlens=cu_seqlens,
    high_precision=high_precision,
)
@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
    attn_vproj_w: torch.Tensor = None,
    attn_oproj: ParallelModule = None,
    process_group: ProcessGroup = None,
    model_shard_infer_config: ModelShardInferenceConfig = None,
    num_heads: int = None,
    hidden_size: int = None,
    num_key_value_heads: int = None,
@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.rope_theta = config.rope_theta
self.is_causal = True

self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

if self.num_heads == self.num_key_value_heads:
    qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
    self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)

attn_layer = NopadLlamaAttention(
    config=config,
@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
    attn_vproj_w=attn_vproj_w,
    attn_oproj=attn_oproj,
    process_group=process_group,
    model_shard_infer_config=model_shard_infer_config,
    num_heads=module.num_heads,
    hidden_size=module.hidden_size,
    num_key_value_heads=module.num_key_value_heads,
@ -533,111 +525,50 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
block_size = k_cache.size(-2)

attn_metadata = AttentionMetaData(
    query_states=query_states,
    key_states=key_states,
    value_states=value_states,
    k_cache=k_cache,
    v_cache=v_cache,
    block_tables=block_tables,
    block_size=block_size,
    kv_seq_len=kv_seq_len,
    sequence_lengths=sequence_lengths,
    sm_scale=sm_scale,
    alibi_slopes=None,
    cu_seqlens=cu_seqlens,
    output_tensor=output_tensor,
    use_spec_dec=is_verifier,
    use_alibi_attn=False,
)

if is_prompts:  # prefilling stage
    self.pre_attention_backend.prefill(
        attn_metadata,
        cos=cos_sin[0],
        sin=cos_sin[1],
        high_precision=high_precision,
    )
    attn_output = self.attention_backend.prefill(
        attn_metadata,
        token_nums=token_nums,
    )
else:  # decoding stage
    q_len = tokens_to_verify + 1 if is_verifier else 1

    self.pre_attention_backend.decode(
        attn_metadata,
        cos=cos_sin[0],
        sin=cos_sin[1],
        q_len=q_len,
    )
    attn_output = self.attention_backend.decode(
        attn_metadata,
        fd_inter_tensor=fd_inter_tensor,
        num_key_value_groups=self.num_key_value_groups,
        q_len=q_len,
    )

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)

View File

@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
    SubModuleReplacementDescription(
        suffix="self_attn",
        target_module=NopadBaichuanAttention,
        kwargs={
            "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
        },
    ),
],
)

View File

@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
    SubModuleReplacementDescription(
        suffix="self_attn",
        target_module=NopadLlamaAttention,
        kwargs={
            "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
        },
    ),
],
)

View File

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import math
import os import os
import re import re
from pathlib import Path from pathlib import Path
@ -9,8 +10,11 @@ from typing import Optional, Tuple
import torch
from torch import nn

from colossalai.logging import get_dist_logger
from colossalai.testing import free_port

logger = get_dist_logger(__name__)
def init_to_get_rotary(self, base=10000, use_elem=False):
    """
@ -113,3 +117,42 @@ def find_available_ports(num: int):
print(f"An OS error occurred: {e}") print(f"An OS error occurred: {e}")
raise RuntimeError("Error finding available ports") raise RuntimeError("Error finding available ports")
return free_ports return free_ports
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Check flash attention2 availability.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False

    try:
        from flash_attn import flash_attn_varlen_func  # noqa

        return True
    except ImportError:
        logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest
import torch

from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask

View File

@ -2,7 +2,7 @@ import pytest
import torch
from packaging import version

from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

View File

@ -3,7 +3,7 @@ import pytest
import torch
from packaging import version

from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

View File

@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
    assert inference_engine.generation_config.max_new_tokens == output_len
    inference_engine.add_request(prompts=inputs)
    assert inference_engine.request_handler._has_waiting()
    generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
    outputs = inference_engine.generate(generation_config=generation_config)
else:
    if prompt_template: