Refactor modeling by adding attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>
pull/5771/head
char-1ee 2024-06-03 01:51:21 +00:00
parent 73e88a5553
commit 04386d9eff
9 changed files with 439 additions and 145 deletions

View File

@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
Chat api is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding utterances. Since such input data differ greatly from normal prompts, we introduce Chat-Template to match the data format expected by chat models.
#### chat-template
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate below. Both string and file-style chat templates are supported.
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file-style chat templates are supported.
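For illustration, applying such a template through the `transformers` tokenizer API looks like the following (a minimal sketch; the model name is only an example):
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [{"role": "user", "content": "What is tensor parallelism?"}]
# Render the conversation into the single formatted string the model was trained on.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
```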
### Usage
#### Args for customizing your server
The configuration for api server contains both serving interface and engine backend.

View File

@ -169,7 +169,8 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, ngrams of that size can only appear once in the generated text.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition. Defaults to 1.0.
ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
max_n_spec_tokens (int): The maximum number of tokens to speculate per step, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
@ -214,6 +215,7 @@ class InferenceConfig(RPC_PARAM):
ignore_eos: bool = False
# speculative decoding configs
use_spec_dec: bool = False
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False
@ -311,6 +313,15 @@ class InferenceConfig(RPC_PARAM):
return GenerationConfig.from_dict(meta_config)
def to_model_inference_config(self) -> "ModelInferenceConfig":
model_inference_config = ModelInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_cuda_graph=self.use_cuda_graph,
)
return model_inference_config
def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
@ -362,3 +373,22 @@ class InferenceConfig(RPC_PARAM):
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config
@dataclass
class ModelInferenceConfig:
"""
Configurations used when initializing/sharding model for inference.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
use_spec_dec (bool): Whether to use speculative decoding.
use_flash_attn (bool): Whether to use flash attention.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graphs are disabled and the model always executes in eager mode. If True, CUDA graph execution is mixed with eager mode (hybrid execution).
"""
dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
use_cuda_graph: bool = False
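# Usage sketch (illustration only, assuming defaults for the remaining fields):
#
#     inference_config = InferenceConfig(use_cuda_kernel=True, use_spec_dec=True)
#     model_inference_config = inference_config.to_model_inference_config()
#     assert model_inference_config.use_spec_dec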

View File

@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@ -72,8 +72,9 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_inference_config = inference_config.to_model_inference_config()
self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_inference_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@ -97,7 +98,10 @@ class InferenceEngine:
self.capture_model(self.k_cache, self.v_cache)
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.use_spec_dec = self.inference_config.use_spec_dec
# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
# We can add a SpecDecConfig class to store these configs.
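# For illustration only, such a SpecDecConfig might look like the following
# (hypothetical sketch, not part of this commit):
#
#     @dataclass
#     class SpecDecConfig:
#         drafter_model_or_path: Union[nn.Module, str] = None
#         max_n_spec_tokens: int = 5
#         use_glide: bool = False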
self.drafter_model = None
self.drafter = None
self.use_glide = False
@ -105,13 +109,19 @@ class InferenceEngine:
self._verify_args()
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_inference_config: ModelInferenceConfig = None,
):
"""
Shard model and/or load weights.
Args:
model_or_path (Union[nn.Module, str]): The model or the path to a checkpoint in transformers format.
model_policy (Policy): the policy to replace the model
model_policy (Policy): the policy to replace the model.
model_inference_config (ModelInferenceConfig): the configuration for model initialization at inference time.
"""
if isinstance(model_or_path, str):
@ -124,6 +134,7 @@ class InferenceEngine:
# the model load process in the future.
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model is not supported, fall back to transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
@ -167,6 +178,7 @@ class InferenceEngine:
self.model = self._shardformer(
model,
model_policy,
model_inference_config,
None,
tp_group=tp_group,
)
@ -187,7 +199,7 @@ class InferenceEngine:
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
@ -287,6 +299,7 @@ class InferenceEngine:
self,
model: nn.Module,
model_policy: Policy,
model_inference_config: ModelInferenceConfig,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@ -348,6 +361,8 @@ class InferenceEngine:
engine.clear_spec_dec()
```
"""
self.logger.warning("This method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
@ -517,19 +532,19 @@ class InferenceEngine:
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
request_ids (List[int], optional): The request ID. Defaults to None.
return_token_ids (bool): Whether to return output token ids. Defaults to False.
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
prompts (List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
List[str]: Inference result returned by one generation.
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}

View File

@ -0,0 +1,146 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func
import torch
from colossalai.inference.config import InputMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.logging import get_dist_logger
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
flash_decoding_attention,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
@dataclass
class AttentionMetaData:
query_states: torch.Tensor
key_states: torch.Tensor
value_states: torch.Tensor
k_cache: torch.Tensor
v_cache: torch.Tensor
block_tables: torch.Tensor
block_size: int
kv_seq_len: int = None
sequence_lengths: torch.Tensor = None
cu_seqlens: torch.Tensor = None
sm_scale: float = None
alibi_slopes: torch.Tensor = None
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False
class AttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
@abstractmethod
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaAttentionBackend(AttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec:
token_nums = kwargs.get('token_nums', -1)
attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_q=attn_metadata.kv_seq_len,
max_seqlen_k=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
attn_output = context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True,
)
return attn_output
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
output_tensor = attn_metadata.output_tensor
inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor
class TritonAttentionBackend(AttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=False,
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
return flash_decoding_attention(
q=attn_metadata.query_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
kv_seq_len=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
max_seq_len_in_batch=attn_metadata.kv_seq_len,
output=attn_metadata.output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=attn_metadata.sm_scale,
kv_group_num=kwargs.get('num_key_value_groups', 0),
q_len=kwargs.get('q_len', 1),
)
def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionBackend()
else:
return TritonAttentionBackend()
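# Usage sketch (illustration only): the modeling code builds an AttentionMetaData
# from its forward state and dispatches on the stage; `attn_metadata`, `token_nums`,
# `fd_inter_tensor`, `num_key_value_groups` and `q_len` are assumed to be prepared
# by the caller, as in nopadding_llama below.
#
#     backend = get_attention_backend(use_spec_dec=False, use_cuda_kernel=True, dtype=torch.float16)
#     if is_prompts:  # prefill stage
#         attn_output = backend.prefill(attn_metadata, token_nums=token_nums)
#     else:  # decode stage
#         attn_output = backend.decode(
#             attn_metadata,
#             fd_inter_tensor=fd_inter_tensor,
#             num_key_value_groups=num_key_value_groups,
#             q_len=q_len,
#         )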

View File

@ -0,0 +1,134 @@
from abc import ABC, abstractmethod
import torch
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.logging import get_dist_logger
from colossalai.kernel.triton import (
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
rotary_embedding,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
class AttentionContext(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
@abstractmethod
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError
class CudaAttentionContext(AttentionContext):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec:
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
kwargs.get('high_precision', False),
)
inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.cu_seqlens,
attn_metadata.block_tables,
attn_metadata.kv_seq_len,
)
else:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:  # rotary models: fused RoPE + KV-cache copy
inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
kwargs.get('high_precision', False),
)
else:
inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class TritonAttentionContext(AttentionContext):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
decoding_fused_rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else:
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get('cos', None),
kwargs.get('sin', None),
)
copy_k_to_blocked_cache(
attn_metadata.key_states,
attn_metadata.k_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get('q_len', 1)
)
copy_k_to_blocked_cache(
attn_metadata.value_states,
attn_metadata.v_cache,
kv_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
n=kwargs.get('q_len', 1)
)
def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext:
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionContext()
else:
return TritonAttentionContext()
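# Note: get_attention_context() uses the same predicate as get_attention_backend(),
# so the path that writes K/V into the blocked cache always matches the kernel that
# later reads it. Decode-stage usage sketch (illustration only; metadata and rotary
# tables are assumed to be prepared by the caller):
#
#     context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
#     context.decode(attn_metadata, cos=cos_sin[0], sin=cos_sin[1], q_len=q_len)  # RoPE + KV-cache write
#     attn_output = backend.decode(attn_metadata, fd_inter_tensor=fd_inter_tensor, q_len=q_len)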

View File

@ -8,6 +8,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
@ -47,22 +48,6 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,

View File

@ -18,6 +18,9 @@ from transformers.models.llama.modeling_llama import (
from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
from colossalai.inference.modeling.backends.attention_context import get_attention_context
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
@ -36,14 +39,6 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
def llama_causal_lm_forward(
self: LlamaForCausalLM,
@ -126,7 +121,7 @@ def llama_model_forward(
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
elif use_cuda_kernel:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
if inputmetadata.dtype != torch.float32 and can_use_flash_attn2(inputmetadata.dtype):
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
hidden_dim = self._cos_cached.size(-1)
@ -533,109 +528,51 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
block_size = k_cache.size(-2)
if is_prompts:
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
attn_metadata = AttentionMetaData(
query_states=query_states,
key_states=key_states,
value_states=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
kv_seq_len=kv_seq_len,
sequence_lengths=sequence_lengths,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
alibi_slopes=None,
cu_seqlens=cu_seqlens,
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=False,
)
else:
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
if is_prompts: # prefilling stage
attention_context.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
if use_cuda_kernel:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
high_precision,
attention_context.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
None,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
attn_output = attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
num_key_value_groups=self.num_key_value_groups,
q_len=q_len,
)
@ -695,3 +632,4 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"

View File

@ -3,6 +3,7 @@ Utils for model inference
"""
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple
@ -10,6 +11,9 @@ import torch
from torch import nn
from colossalai.testing import free_port
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
def init_to_get_rotary(self, base=10000, use_elem=False):
@ -113,3 +117,45 @@ def find_available_ports(num: int):
print(f"An OS error occurred: {e}")
raise RuntimeError("Error finding available ports")
return free_ports
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
"""
Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
Args:
num_heads (int): The number of attention heads.
device (torch.device): The device to use.
Returns:
torch.Tensor: The Alibi slopes.
"""
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
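# Quick worked example (sanity check of the formula above): for a power-of-two head
# count n, the slopes follow 2 ** (-8 * i / n) for i = 1..n.
#
#     get_alibi_slopes(8, torch.device("cpu"))
#     # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])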
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
"""
Check flash attention2 availability.
"""
if dtype not in (torch.float16, torch.bfloat16):
logger.warning("Flash attn2 currently only supports float16 and bfloat16.")
return False
try:
from flash_attn import __version__
logger.info(f"flash_attn2 version {__version__}.")
return True
except ImportError:
logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False