Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
|
|||
- POST '/chat':
|
||||
The chat API is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding utterances. Since such inputs differ greatly from plain completion inputs, we introduce a chat template to match the data format expected by chat models.
|
||||
#### chat-template
|
||||
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers as a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file style chat templates are supported.
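For illustration, here is a hedged sketch of how such a template can be set and applied with a `transformers` tokenizer. The template string and model name below are placeholders, not the project's bundled example:

```python
from transformers import AutoTokenizer

# Placeholder Jinja template: the real template must match the format the
# target chat model was trained with.
chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.chat_template = chat_template

history = [
    {"role": "user", "content": "What is ColossalAI?"},
    {"role": "assistant", "content": "A system for large-scale model training and inference."},
]
# Renders the conversation history into a single prompt string.
prompt = tokenizer.apply_chat_template(history, tokenize=False)
print(prompt)
```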
|
||||
### Usage
|
||||
#### Args for customizing your server
|
||||
The configuration for the API server covers both the serving interface and the engine backend.
|
||||
|
|
|
@ -169,7 +169,8 @@ class InferenceConfig(RPC_PARAM):
|
|||
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, n-grams of that size can only appear once in the generated text.
|
||||
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition. Defaults to 1.0.
|
||||
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
|
||||
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
|
||||
use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
|
||||
max_n_spec_tokens (int): The maximum number of speculative tokens, defaults to 5.
|
||||
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
|
||||
block_size (int): The number of token slots in one KV-cache block, defaults to 16.
|
||||
tp_size (int): Tensor parallel size, defaults to 1.
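As a hedged illustration of these fields (field names follow the docstring above; the values are arbitrary and other defaults are left untouched):

```python
from colossalai.inference.config import InferenceConfig

# Sketch only: enable a mild repetition penalty and forbid repeated 2-grams.
config = InferenceConfig(
    block_size=16,
    tp_size=1,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2,
    ignore_eos=False,
    use_spec_dec=False,      # speculative decoding off by default
    max_n_spec_tokens=5,     # only relevant when use_spec_dec is True
)
```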
|
||||
|
@ -214,6 +215,7 @@ class InferenceConfig(RPC_PARAM):
|
|||
ignore_eos: bool = False
|
||||
|
||||
# speculative decoding configs
|
||||
use_spec_dec: bool = False
|
||||
max_n_spec_tokens: int = 5
|
||||
glimpse_large_kv: bool = False
|
||||
|
||||
|
@ -311,6 +313,15 @@ class InferenceConfig(RPC_PARAM):
|
|||
|
||||
return GenerationConfig.from_dict(meta_config)
|
||||
|
||||
def to_model_inference_config(self) -> "ModelInferenceConfig":
|
||||
model_inference_config = ModelInferenceConfig(
|
||||
dtype=self.dtype,
|
||||
use_cuda_kernel=self.use_cuda_kernel,
|
||||
use_spec_dec=self.use_spec_dec,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return model_inference_config
|
||||
|
||||
def to_rpc_param(self) -> dict:
|
||||
kwargs = {
|
||||
"dtype": str(self.dtype).split(".")[-1],
|
||||
|
@ -362,3 +373,22 @@ class InferenceConfig(RPC_PARAM):
|
|||
# Set the attributes from the parsed arguments.
|
||||
inference_config = cls(**inference_config_args)
|
||||
return inference_config
|
||||
|
||||
@dataclass
|
||||
class ModelInferenceConfig:
|
||||
"""
|
||||
Configurations used when initializing/sharding model for inference.
|
||||
|
||||
Args:
|
||||
dtype (torch.dtype): The data type for weights and activations.
|
||||
use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
|
||||
use_spec_dec (bool): Indicate whether to use speculative decoding.
|
||||
use_flash_attn (bool): Indicate whether to use flash attention.
|
||||
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graphs are disabled and the model always runs in eager mode. If True, CUDA graph execution is combined with eager execution in a hybrid fashion.
|
||||
"""
|
||||
dtype: torch.dtype = None
|
||||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
use_cuda_graph: bool = False
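A hedged sketch of how the engine derives this model-level config from an `InferenceConfig` via the `to_model_inference_config()` helper added above (assuming the fields shown in this diff):

```python
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(use_cuda_kernel=True, use_spec_dec=False)
model_inference_config = inference_config.to_model_inference_config()

# Carries only what model sharding/initialization needs.
print(model_inference_config.dtype, model_inference_config.use_cuda_kernel)
```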
|
||||
|
|
@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
|||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
|
@ -72,8 +72,9 @@ class InferenceEngine:
|
|||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_inference_config = inference_config.to_model_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy)
|
||||
self.init_model(model_or_path, model_policy, self.model_inference_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
@ -97,7 +98,10 @@ class InferenceEngine:
|
|||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = False
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
|
||||
# We can add a SpecDecConfig class to store these configs.
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
|
@ -105,13 +109,19 @@ class InferenceEngine:
|
|||
|
||||
self._verify_args()
|
||||
|
||||
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_inference_config: ModelInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load its weights.
|
||||
|
||||
Args:
|
||||
model_or_path (Union[nn.Module, str]): Path to the checkpoint, or a model in transformers format.
|
||||
model_policy (Policy): the policy to replace the model
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config (ModelInferenceConfig): the configuration used to initialize the model for inference.
|
||||
"""
|
||||
|
||||
if isinstance(model_or_path, str):
|
||||
|
@ -124,6 +134,7 @@ class InferenceEngine:
|
|||
# the model load process in the future.
|
||||
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
|
@ -167,6 +178,7 @@ class InferenceEngine:
|
|||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_inference_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
@ -187,7 +199,7 @@ class InferenceEngine:
|
|||
# assert if_has_index_file, "the model path is invalid"
|
||||
# cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
|
@ -287,6 +299,7 @@ class InferenceEngine:
|
|||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_inference_config: ModelInferenceConfig,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
|
@ -348,6 +361,8 @@ class InferenceEngine:
|
|||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
self.logger.warning("This method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
|
@ -517,19 +532,19 @@ class InferenceEngine:
|
|||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> List[str]:
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
return_token_ids (bool): Whether to return output token ids. Defaults to False.
|
||||
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
prompts (List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (GenerationConfig, optional): HuggingFace GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[str]: Inference result returned by one generation.
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
|
|
|
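A hedged end-to-end sketch of the engine flow touched by this diff. The constructor arguments and the checkpoint name are assumptions, and distributed initialization via `colossalai.launch` is omitted:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inference_config = InferenceConfig(dtype="fp16", max_batch_size=4, use_cuda_kernel=True)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# generate() returns List[str], or (List[str], List[List[int]]) when
# return_token_ids=True, as documented above.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing."])
print(outputs[0])
```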
@ -0,0 +1,146 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
import torch
|
||||
|
||||
from colossalai.inference.config import InputMetaData
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
flash_decoding_attention,
|
||||
)
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
@dataclass
|
||||
class AttentionMetaData:
|
||||
query_states: torch.Tensor
|
||||
key_states: torch.Tensor
|
||||
value_states: torch.Tensor
|
||||
k_cache: torch.Tensor
|
||||
v_cache: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
block_size: int
|
||||
kv_seq_len: int = None
|
||||
sequence_lengths: torch.Tensor = None
|
||||
cu_seqlens: torch.Tensor = None
|
||||
sm_scale: float = None
|
||||
alibi_slopes: torch.Tensor = None
|
||||
output_tensor: torch.Tensor = None
|
||||
use_spec_dec: bool = False
|
||||
use_alibi_attn: bool = False
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
@abstractmethod
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CudaAttentionBackend(AttentionBackend):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_spec_dec:
|
||||
token_nums = kwargs.get('token_nums', -1)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
cu_seqlens_q=attn_metadata.cu_seqlens,
|
||||
cu_seqlens_k=attn_metadata.cu_seqlens,
|
||||
max_seqlen_q=attn_metadata.kv_seq_len,
|
||||
max_seqlen_k=attn_metadata.kv_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=attn_metadata.sm_scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output.view(token_nums, -1)
|
||||
else:
|
||||
attn_output = context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
k=attn_metadata.key_states,
|
||||
v=attn_metadata.value_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
context_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
output=attn_metadata.output_tensor,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=True,
|
||||
)
|
||||
return attn_output
|
||||
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
|
||||
output_tensor = attn_metadata.output_tensor
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.block_size,
|
||||
attn_metadata.kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.exp_sums,
|
||||
fd_inter_tensor.max_logits,
|
||||
attn_metadata.alibi_slopes,
|
||||
attn_metadata.sm_scale,
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
return context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
k=attn_metadata.key_states,
|
||||
v=attn_metadata.value_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
context_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
output=attn_metadata.output_tensor,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=False,
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
|
||||
return flash_decoding_attention(
|
||||
q=attn_metadata.query_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
kv_seq_len=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
max_seq_len_in_batch=attn_metadata.kv_seq_len,
|
||||
output=attn_metadata.output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
kv_group_num=kwargs.get('num_key_value_groups', 0),
|
||||
q_len=kwargs.get('q_len', 1),
|
||||
)
|
||||
|
||||
|
||||
def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
|
||||
use_flash_attn = can_use_flash_attn2(dtype)
|
||||
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
|
||||
return CudaAttentionBackend()
|
||||
else:
|
||||
return TritonAttentionBackend()
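A hedged sketch of the selection rule implemented by `get_attention_backend` (the import path follows this new file; the flags are arbitrary):

```python
import torch

from colossalai.inference.modeling.backends.attention_backend import get_attention_backend

# The flash-attn based CUDA backend is only chosen when CUDA kernels are
# requested, the dtype is FP16/BF16 with flash-attn 2 installed, and
# speculative decoding is off; otherwise the Triton backend is used.
backend = get_attention_backend(use_spec_dec=False, use_cuda_kernel=True, dtype=torch.float16)
print(type(backend).__name__)
```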
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.kernel.triton import (
|
||||
copy_k_to_blocked_cache,
|
||||
decoding_fused_rotary_embedding,
|
||||
rotary_embedding,
|
||||
)
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
|
||||
class AttentionContext(ABC):
|
||||
@abstractmethod
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CudaAttentionContext(AttentionContext):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_spec_dec:
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
inference_ops.rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
kwargs.get('high_precision', False),
|
||||
)
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.cu_seqlens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.kv_seq_len,
|
||||
)
|
||||
else:
|
||||
rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
kwargs.get('high_precision', False),
|
||||
)
|
||||
else:
|
||||
inference_ops.decode_kv_cache_memcpy(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
)
|
||||
|
||||
|
||||
class TritonAttentionContext(AttentionContext):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
|
||||
decoding_fused_rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.sequence_lengths,
|
||||
)
|
||||
else:
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get('cos', None),
|
||||
kwargs.get('sin', None),
|
||||
)
|
||||
copy_k_to_blocked_cache(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.k_cache,
|
||||
kv_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
n=kwargs.get('q_len', 1)
|
||||
)
|
||||
copy_k_to_blocked_cache(
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.v_cache,
|
||||
kv_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
n=kwargs.get('q_len', 1)
|
||||
)
|
||||
|
||||
|
||||
def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext:
|
||||
use_flash_attn = can_use_flash_attn2(dtype)
|
||||
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
|
||||
return CudaAttentionContext()
|
||||
else:
|
||||
return TritonAttentionContext()
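Taken together, an attention layer first lets the context object handle rotary embedding and KV-cache writes, then lets the backend compute attention. A hedged sketch of that dispatch follows; `metadata` stands for a populated `AttentionMetaData`, and the helper function itself is illustrative, not part of this diff:

```python
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.attention_context import get_attention_context


def run_attention(metadata, *, is_prompts, dtype, use_cuda_kernel, cos, sin,
                  token_nums, fd_inter_tensor, high_precision=False,
                  q_len=1, num_key_value_groups=1):
    backend = get_attention_backend(metadata.use_spec_dec, use_cuda_kernel, dtype)
    context = get_attention_context(metadata.use_spec_dec, use_cuda_kernel, dtype)

    if is_prompts:  # prefill: RoPE + KV-cache write, then varlen attention
        context.prefill(metadata, cos=cos, sin=sin, high_precision=high_precision)
        return backend.prefill(metadata, token_nums=token_nums)

    # decode: fused RoPE / KV-cache copy, then flash decoding
    context.decode(metadata, cos=cos, sin=sin, q_len=q_len)
    return backend.decode(metadata, fd_inter_tensor=fd_inter_tensor,
                          num_key_value_groups=num_key_value_groups, q_len=q_len)
```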
|
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.utils import get_alibi_slopes
|
||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
|
@ -47,22 +48,6 @@ inference_ops = InferenceOpsLoader().load()
|
|||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
|
||||
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
|
||||
slopes = torch.pow(base, powers)
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
def baichuan_rmsnorm_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
|
@ -18,6 +18,9 @@ from transformers.models.llama.modeling_llama import (
|
|||
|
||||
from colossalai.inference.config import InputMetaData
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
|
||||
from colossalai.inference.modeling.backends.attention_context import get_attention_context
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
|
@ -36,14 +39,6 @@ inference_ops = InferenceOpsLoader().load()
|
|||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
use_flash_attn2 = True
|
||||
except ImportError:
|
||||
use_flash_attn2 = False
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
|
||||
|
||||
def llama_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
|
@ -126,7 +121,7 @@ def llama_model_forward(
|
|||
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
|
||||
|
||||
elif use_cuda_kernel:
|
||||
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
|
||||
if inputmetadata.dtype != torch.float32 and can_use_flash_attn2(inputmetadata.dtype):
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
|
||||
|
||||
hidden_dim = self._cos_cached.size(-1)
|
||||
|
@ -533,109 +528,51 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
if is_prompts:
|
||||
if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
|
||||
# flash attn 2 currently only supports FP16/BF16.
|
||||
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
|
||||
)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=kv_seq_len,
|
||||
max_seqlen_k=kv_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=sm_scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output.view(token_nums, -1)
|
||||
else:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
attn_output = context_attention_unpadded(
|
||||
q=query_states,
|
||||
k=key_states,
|
||||
v=value_states,
|
||||
attn_metadata = AttentionMetaData(
|
||||
query_states=query_states,
|
||||
key_states=key_states,
|
||||
value_states=value_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
context_lengths=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
output=output_tensor,
|
||||
max_seq_len=kv_seq_len,
|
||||
kv_seq_len=kv_seq_len,
|
||||
sequence_lengths=sequence_lengths,
|
||||
sm_scale=sm_scale,
|
||||
use_new_kcache_layout=use_cuda_kernel,
|
||||
alibi_slopes=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_tensor=output_tensor,
|
||||
use_spec_dec=is_verifier,
|
||||
use_alibi_attn=False,
|
||||
)
|
||||
else:
|
||||
|
||||
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
|
||||
if is_prompts: # prefilling stage
|
||||
attention_context.prefill(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
high_precision=high_precision,
|
||||
)
|
||||
attn_output = attention_backend.prefill(
|
||||
attn_metadata,
|
||||
token_nums=token_nums,
|
||||
)
|
||||
else: # decoding stage
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
|
||||
if use_cuda_kernel:
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cos_sin[0],
|
||||
cos_sin[1],
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
high_precision,
|
||||
attention_context.decode(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
q_len=q_len,
|
||||
)
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
query_states,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.exp_sums,
|
||||
fd_inter_tensor.max_logits,
|
||||
None,
|
||||
sm_scale,
|
||||
)
|
||||
attn_output = output_tensor
|
||||
else:
|
||||
if is_verifier:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
copy_k_to_blocked_cache(
|
||||
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||
)
|
||||
copy_k_to_blocked_cache(
|
||||
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||
)
|
||||
else:
|
||||
decoding_fused_rotary_embedding(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cos_sin[0],
|
||||
cos_sin[1],
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
sequence_lengths,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=self.num_key_value_groups,
|
||||
attn_output = attention_backend.decode(
|
||||
attn_metadata,
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
q_len=q_len,
|
||||
)
|
||||
|
||||
|
@ -695,3 +632,4 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
|
||||
def extra_repr(self) -> str:
|
||||
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
|
||||
|
|
@ -3,6 +3,7 @@ Utils for model inference
|
|||
"""
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
@ -10,6 +11,9 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from colossalai.testing import free_port
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||
|
@ -113,3 +117,45 @@ def find_available_ports(num: int):
|
|||
print(f"An OS error occurred: {e}")
|
||||
raise RuntimeError("Error finding available ports")
|
||||
return free_ports
|
||||
|
||||
|
||||
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
|
||||
|
||||
Args:
|
||||
num_heads (int): The number of attention heads.
|
||||
device (torch.device): The device to use.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The Alibi slopes.
|
||||
"""
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
|
||||
slopes = torch.pow(base, powers)
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
||||
"""
|
||||
Check flash attention2 availability.
|
||||
"""
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
logger.warning("Flash attn2 currently only supports float16 and bfloat16.")
|
||||
return False
|
||||
|
||||
try:
|
||||
from flash_attn import __version__
|
||||
logger.info(f"flash_attn2 version {__version__}.")
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
return False
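Hedged usage sketches for the two helpers above, with values chosen to exercise the non-power-of-two branch of `get_alibi_slopes` and the dtype check of `can_use_flash_attn2`:

```python
import torch

from colossalai.inference.utils import can_use_flash_attn2, get_alibi_slopes

# 12 heads is not a power of two, so the extra-slope branch is exercised.
slopes = get_alibi_slopes(12, device=torch.device("cpu"))
print(slopes.shape)   # torch.Size([12])
print(slopes[:4])     # tensor([0.5000, 0.2500, 0.1250, 0.0625])

# FP32 is rejected up front; FP16/BF16 additionally require flash-attn 2 to be importable.
print(can_use_flash_attn2(torch.float32))  # False
print(can_use_flash_attn2(torch.float16))  # True only if flash-attn 2 is installed
```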
|