Merge pull request #5771 from char-1ee/refactor/modeling

[Inference] Refactor modeling attention layer by abstracting attention backends
pull/5791/head
Li Xingjian 2024-06-10 11:52:22 +08:00 committed by GitHub
commit 77a219a082
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 531 additions and 301 deletions

View File

@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models.
#### chat-template
Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported.
Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template bellow. Both str or file style chat template are supported.
### Usage
#### Args for customizing your server
The configuration for api server contains both serving interface and engine backend.

View File

@ -10,6 +10,7 @@ import torch
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2
GibiByte = 1024**3
@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
ignore_eos: bool = False
# speculative decoding configs
use_spec_dec: bool = False
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
@ -311,6 +314,16 @@ class InferenceConfig(RPC_PARAM):
return GenerationConfig.from_dict(meta_config)
def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_flash_attn = can_use_flash_attn2(self.dtype)
model_inference_config = ModelShardInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
)
return model_inference_config
def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
@ -362,3 +375,21 @@ class InferenceConfig(RPC_PARAM):
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config
@dataclass
class ModelShardInferenceConfig:
"""
Configurations used during init of module for inference modeling.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally
use_spec_dec (bool): Indicate whether to use speculative decoding.
use_flash_attn (bool): Indicate whether to use flash attention.
"""
dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False

View File

@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@ -72,8 +72,9 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@ -97,7 +98,8 @@ class InferenceEngine:
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
@ -105,13 +107,20 @@ class InferenceEngine:
self._verify_args()
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
@ -124,6 +133,7 @@ class InferenceEngine:
# the model load process in the future.
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
@ -167,6 +177,7 @@ class InferenceEngine:
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@ -187,7 +198,7 @@ class InferenceEngine:
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
@ -287,6 +298,7 @@ class InferenceEngine:
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@ -312,6 +324,7 @@ class InferenceEngine:
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@ -348,6 +361,7 @@ class InferenceEngine:
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
@ -517,19 +531,19 @@ class InferenceEngine:
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
request_ids (List[int], optional): The request ID. Defaults to None.
return_token_ids (bool): Whether to return output token ids. Defaults to False.
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
List[str]: Inference result returned by one generation.
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}

View File

@ -0,0 +1,168 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from flash_attn import flash_attn_varlen_func
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
@dataclass
class AttentionMetaData:
query_states: torch.Tensor
key_states: torch.Tensor
value_states: torch.Tensor
k_cache: torch.Tensor
v_cache: torch.Tensor
block_tables: torch.Tensor
block_size: int
kv_seq_len: int = None
sequence_lengths: torch.Tensor = None
cu_seqlens: torch.Tensor = None
sm_scale: int = None
alibi_slopes: torch.Tensor = None
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False
class AttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
@abstractmethod
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self, use_flash_attn: bool):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
token_nums = kwargs.get("token_nums", -1)
attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_q=attn_metadata.kv_seq_len,
max_seqlen_k=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
alibi_slopes=attn_metadata.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)
else:
attn_output = context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k-cache layout
)
return attn_output
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor
class TritonAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
return flash_decoding_attention(
q=attn_metadata.query_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
kv_seq_len=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
max_seq_len_in_batch=attn_metadata.kv_seq_len,
output=attn_metadata.output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=attn_metadata.alibi_slopes,
sm_scale=attn_metadata.sm_scale,
kv_group_num=kwargs.get("num_key_value_groups", 1),
q_len=kwargs.get("q_len", 1),
)
def get_attention_backend(
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend
for attention module calculation only when:
1. using CUDA kernel (use_cuda_kernel=True)
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
3. not using speculative decoding (currently cuda kernel not support speculative decoding)
Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True,
the Triton backend will use a new k cache layout for Triton kernels.
"""
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)
return TritonAttentionBackend()

View File

@ -0,0 +1,146 @@
from abc import ABC, abstractmethod
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
class PreAttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
@abstractmethod
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaPreAttentionBackend(PreAttentionBackend):
"""
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
"""
def __init__(self, use_flash_attn: bool):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
kwargs.get("high_precision", False),
)
self.inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.cu_seqlens,
attn_metadata.block_tables,
attn_metadata.kv_seq_len,
)
elif not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
kwargs.get("high_precision", None),
)
else:
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class TritonPreAttentionBackend(PreAttentionBackend):
"""
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
decoding_fused_rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else: # else if using speculative decoding
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
copy_k_to_blocked_cache(
attn_metadata.key_states,
attn_metadata.k_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get("q_len", 1),
)
copy_k_to_blocked_cache(
attn_metadata.value_states,
attn_metadata.v_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get("q_len", 1),
)
def get_pre_attention_backend(
model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including potisional encoding like
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend.
"""
if model_shard_infer_config.use_spec_dec:
return TritonPreAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)
return TritonPreAttentionBackend()

View File

@ -1,68 +1,27 @@
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,
@ -102,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_oproj: ParallelModule = None,
num_heads: int = None,
hidden_size: int = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
process_group: ProcessGroup = None,
helper_layout: Layout = None,
):
@ -126,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule):
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = helper_layout
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
self.alibi_slopes = None
self.use_alibi_attn = False
@ -155,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
helper_layout = (
module.W_pack.weight.dist_layout
@ -166,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj=attn_oproj,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
process_group=process_group,
@ -234,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule):
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@ -253,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule):
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
@ -267,121 +230,49 @@ class NopadBaichuanAttention(ParallelModule):
block_size = k_cache.size(-2)
if is_prompts:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
if not self.use_alibi_attn:
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
alibi_slopes=self.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
value_states=value_states,
k_cache=k_cache,
v_cache=v_cache,
block_tables=block_tables,
block_size=block_size,
kv_seq_len=kv_seq_len,
sequence_lengths=sequence_lengths,
sm_scale=sm_scale,
alibi_slopes=self.alibi_slopes,
cu_seqlens=cu_seqlens,
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=self.use_alibi_attn,
)
if is_prompts: # prefilling stage
self.pre_attention_backend.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = self.attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel:
if not self.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
else:
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
self.alibi_slopes,
sm_scale,
)
attn_output = output_tensor
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)
self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = self.attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)

View File

@ -16,18 +16,13 @@ from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
)
from colossalai.inference.config import InputMetaData
from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@ -36,14 +31,6 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
def llama_causal_lm_forward(
self: LlamaForCausalLM,
@ -126,7 +113,7 @@ def llama_model_forward(
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
if can_use_flash_attn2(inputmetadata.dtype):
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
hidden_dim = self._cos_cached.size(-1)
@ -238,7 +225,6 @@ def llama_decoder_layer_forward(
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
@ -402,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
process_group: ProcessGroup = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
num_heads: int = None,
hidden_size: int = None,
num_key_value_heads: int = None,
@ -433,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
@ -462,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
attn_layer = NopadLlamaAttention(
config=config,
@ -471,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w=attn_vproj_w,
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
@ -533,111 +525,50 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
block_size = k_cache.size(-2)
if is_prompts:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
value_states=value_states,
k_cache=k_cache,
v_cache=v_cache,
block_tables=block_tables,
block_size=block_size,
kv_seq_len=kv_seq_len,
sequence_lengths=sequence_lengths,
sm_scale=sm_scale,
alibi_slopes=None,
cu_seqlens=cu_seqlens,
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=False,
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
if is_prompts: # prefilling stage
self.pre_attention_backend.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = self.attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
None,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)
self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = self.attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
num_key_value_groups=self.num_key_value_groups,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)

View File

@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
),
],
)

View File

@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
),
],
)

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import math
import os
import re
from pathlib import Path
@ -9,8 +10,11 @@ from typing import Optional, Tuple
import torch
from torch import nn
from colossalai.logging import get_dist_logger
from colossalai.testing import free_port
logger = get_dist_logger(__name__)
def init_to_get_rotary(self, base=10000, use_elem=False):
"""
@ -113,3 +117,42 @@ def find_available_ports(num: int):
print(f"An OS error occurred: {e}")
raise RuntimeError("Error finding available ports")
return free_ports
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
"""
Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
Args:
num_heads (int): The number of attention heads.
device (torch.device): The device to use.
Returns:
torch.Tensor: The Alibi slopes.
"""
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
"""
Check flash attention2 availability.
"""
if dtype not in (torch.float16, torch.bfloat16):
return False
try:
return True
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest
import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask

View File

@ -2,7 +2,7 @@ import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

View File

@ -3,7 +3,7 @@ import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (

View File

@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template: