mirror of https://github.com/hpcaitech/ColossalAI
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>pull/5771/head
parent
eec77e5702
commit
5f398fc000
Binary file not shown.
|
@ -10,6 +10,7 @@ import torch
|
|||
from transformers.generation import GenerationConfig
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
|
||||
GibiByte = 1024**3
|
||||
|
||||
|
@ -313,12 +314,13 @@ class InferenceConfig(RPC_PARAM):
|
|||
|
||||
return GenerationConfig.from_dict(meta_config)
|
||||
|
||||
def to_model_inference_config(self) -> "ModelInferenceConfig":
|
||||
model_inference_config = ModelInferenceConfig(
|
||||
def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
|
||||
use_flash_attn = can_use_flash_attn2(self.dtype)
|
||||
model_inference_config = ModelShardInferenceConfig(
|
||||
dtype=self.dtype,
|
||||
use_cuda_kernel=self.use_cuda_kernel,
|
||||
use_spec_dec=self.use_spec_dec,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
return model_inference_config
|
||||
|
||||
|
@ -374,21 +376,20 @@ class InferenceConfig(RPC_PARAM):
|
|||
inference_config = cls(**inference_config_args)
|
||||
return inference_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInferenceConfig():
|
||||
class ModelShardInferenceConfig:
|
||||
"""
|
||||
Configurations used when initializing/sharding model for inference.
|
||||
Configurations used during init of module for inference modeling.
|
||||
|
||||
Args:
|
||||
dtype (torch.dtype): The data type for weights and activations.
|
||||
use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally
|
||||
use_spec_dec (bool): Indicate whether to use speculative decoding.
|
||||
use_flash_attn (bool): Indicate whether to use flash attention.
|
||||
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
|
||||
"""
|
||||
|
||||
dtype: torch.dtype = None
|
||||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
use_cuda_graph: bool = False
|
||||
|
|
@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
|||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
|
@ -72,8 +72,9 @@ class InferenceEngine:
|
|||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy)
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
@ -99,8 +100,6 @@ class InferenceEngine:
|
|||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
|
||||
# We can add a SpecDecConfig class to store these configs.
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
|
@ -112,6 +111,7 @@ class InferenceEngine:
|
|||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard model or/and Load weight
|
||||
|
@ -120,6 +120,7 @@ class InferenceEngine:
|
|||
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
|
||||
if isinstance(model_or_path, str):
|
||||
|
@ -176,6 +177,7 @@ class InferenceEngine:
|
|||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
@ -296,6 +298,7 @@ class InferenceEngine:
|
|||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
|
@ -321,6 +324,7 @@ class InferenceEngine:
|
|||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
|
@ -357,7 +361,6 @@ class InferenceEngine:
|
|||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
|
|
|
@ -1,19 +1,12 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from colossalai.inference.config import InputMetaData
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
flash_decoding_attention,
|
||||
)
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -33,7 +26,6 @@ class AttentionMetaData:
|
|||
output_tensor: torch.Tensor = None
|
||||
use_spec_dec: bool = False
|
||||
use_alibi_attn: bool = False
|
||||
use_cuda_kernel: bool = False
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
|
@ -46,7 +38,16 @@ class AttentionBackend(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class CudaAttentionBackend(AttentionBackend):
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
|
||||
`flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
token_nums = kwargs.get("token_nums", -1)
|
||||
|
||||
|
@ -69,7 +70,55 @@ class CudaAttentionBackend(AttentionBackend):
|
|||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
|
||||
output_tensor = attn_metadata.output_tensor
|
||||
inference_ops.flash_decoding_attention(
|
||||
self.inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.block_size,
|
||||
attn_metadata.kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.exp_sums,
|
||||
fd_inter_tensor.max_logits,
|
||||
attn_metadata.alibi_slopes,
|
||||
attn_metadata.sm_scale,
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class CudaAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
|
||||
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
return context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
k=attn_metadata.key_states,
|
||||
v=attn_metadata.value_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
context_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
output=attn_metadata.output_tensor,
|
||||
alibi_slopes=attn_metadata.alibi_slopes,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=True, # use new k cache layout for cuda kernels in this triton op
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
|
||||
output_tensor = attn_metadata.output_tensor
|
||||
self.inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.k_cache,
|
||||
|
@ -88,6 +137,10 @@ class CudaAttentionBackend(AttentionBackend):
|
|||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
|
||||
"""
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
return context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
|
@ -102,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):
|
|||
alibi_slopes=attn_metadata.alibi_slopes,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=attn_metadata.use_cuda_kernel,
|
||||
use_new_kcache_layout=False,
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
|
@ -126,17 +179,24 @@ class TritonAttentionBackend(AttentionBackend):
|
|||
|
||||
|
||||
def get_attention_backend(
|
||||
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
|
||||
model_shard_infer_config: ModelShardInferenceConfig,
|
||||
) -> AttentionBackend:
|
||||
"""
|
||||
Get the attention backend based on the inference configurations. Only when:
|
||||
Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend
|
||||
for attention module calculation only when:
|
||||
1. using CUDA kernel (use_cuda_kernel=True)
|
||||
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
|
||||
3. not using speculative decoding (currently cuda kernel not support speculative decoding)
|
||||
will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend.
|
||||
Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True,
|
||||
the Triton backend will use a new k cache layout for Triton kernels.
|
||||
"""
|
||||
use_flash_attn = can_use_flash_attn2(dtype)
|
||||
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
|
||||
return CudaAttentionBackend()
|
||||
else:
|
||||
# Currently only triton kernels support speculative decoding
|
||||
if model_shard_infer_config.use_spec_dec:
|
||||
return TritonAttentionBackend()
|
||||
|
||||
if model_shard_infer_config.use_cuda_kernel:
|
||||
if model_shard_infer_config.use_flash_attn:
|
||||
return FlashAttentionBackend()
|
||||
return CudaAttentionBackend()
|
||||
|
||||
return TritonAttentionBackend()
|
||||
|
|
|
@ -1,18 +1,9 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.kernel.triton import (
|
||||
copy_k_to_blocked_cache,
|
||||
decoding_fused_rotary_embedding,
|
||||
rotary_embedding,
|
||||
)
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||
|
||||
|
||||
class PreAttentionBackend(ABC):
|
||||
|
@ -25,17 +16,25 @@ class PreAttentionBackend(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class CudaPreAttentionBackend(PreAttentionBackend):
|
||||
class FlashPreAttentionBackend(PreAttentionBackend):
|
||||
"""
|
||||
FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
inference_ops.rotary_embedding(
|
||||
self.inference_ops.rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get("cos", None),
|
||||
kwargs.get("sin", None),
|
||||
kwargs.get("high_precision", False),
|
||||
)
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
self.inference_ops.context_kv_cache_memcpy(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.k_cache,
|
||||
|
@ -48,7 +47,7 @@ class CudaPreAttentionBackend(PreAttentionBackend):
|
|||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
self.inference_ops.rotary_embedding_and_cache_copy(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
|
@ -61,7 +60,50 @@ class CudaPreAttentionBackend(PreAttentionBackend):
|
|||
kwargs.get("high_precision", None),
|
||||
)
|
||||
else:
|
||||
inference_ops.decode_kv_cache_memcpy(
|
||||
self.inference_ops.decode_kv_cache_memcpy(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
)
|
||||
|
||||
|
||||
class CudaPreAttentionBackend(PreAttentionBackend):
|
||||
"""
|
||||
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
rotary_embedding(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
kwargs.get("cos", None),
|
||||
kwargs.get("sin", None),
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
self.inference_ops.rotary_embedding_and_cache_copy(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
kwargs.get("cos", None),
|
||||
kwargs.get("sin", None),
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
kwargs.get("high_precision", None),
|
||||
)
|
||||
else:
|
||||
self.inference_ops.decode_kv_cache_memcpy(
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
attn_metadata.k_cache,
|
||||
|
@ -72,6 +114,10 @@ class CudaPreAttentionBackend(PreAttentionBackend):
|
|||
|
||||
|
||||
class TritonPreAttentionBackend(PreAttentionBackend):
|
||||
"""
|
||||
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
|
||||
"""
|
||||
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_alibi_attn:
|
||||
rotary_embedding(
|
||||
|
@ -119,13 +165,18 @@ class TritonPreAttentionBackend(PreAttentionBackend):
|
|||
|
||||
|
||||
def get_pre_attention_backend(
|
||||
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
|
||||
model_shard_infer_config: ModelShardInferenceConfig,
|
||||
) -> PreAttentionBackend:
|
||||
"""
|
||||
Get the backend for pre-attention computations, including potisional encoding like RoPE and KV cache initialization.
|
||||
Get the backend for pre-attention computations, including potisional encoding like
|
||||
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend.
|
||||
"""
|
||||
use_flash_attn = can_use_flash_attn2(dtype)
|
||||
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
|
||||
return CudaPreAttentionBackend()
|
||||
else:
|
||||
if model_shard_infer_config.use_spec_dec:
|
||||
return TritonPreAttentionBackend()
|
||||
|
||||
if model_shard_infer_config.use_cuda_kernel:
|
||||
if model_shard_infer_config.use_flash_attn:
|
||||
return FlashPreAttentionBackend()
|
||||
return CudaPreAttentionBackend()
|
||||
|
||||
return TritonPreAttentionBackend()
|
||||
|
|
|
@ -1,31 +1,23 @@
|
|||
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
|
||||
import itertools
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.utils import get_alibi_slopes
|
||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
|
||||
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
|
||||
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
|
||||
from colossalai.inference.utils import get_alibi_slopes
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
copy_k_to_blocked_cache,
|
||||
decoding_fused_rotary_embedding,
|
||||
flash_decoding_attention,
|
||||
rms_layernorm,
|
||||
rotary_embedding,
|
||||
)
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
|
||||
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
@ -69,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
attn_oproj: ParallelModule = None,
|
||||
num_heads: int = None,
|
||||
hidden_size: int = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
process_group: ProcessGroup = None,
|
||||
helper_layout: Layout = None,
|
||||
):
|
||||
|
@ -93,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
|
||||
|
||||
self.helper_layout = helper_layout
|
||||
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
|
||||
self.attention_backend = get_attention_backend(model_shard_infer_config)
|
||||
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
|
||||
|
||||
self.alibi_slopes = None
|
||||
self.use_alibi_attn = False
|
||||
|
@ -122,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
attn_kproj_w = k_proj_w
|
||||
attn_vproj_w = v_proj_w
|
||||
attn_oproj = module.o_proj
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
|
||||
helper_layout = (
|
||||
module.W_pack.weight.dist_layout
|
||||
|
@ -133,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
attn_kproj_w=attn_kproj_w,
|
||||
attn_vproj_w=attn_vproj_w,
|
||||
attn_oproj=attn_oproj,
|
||||
model_shard_infer_config=model_shard_infer_config,
|
||||
num_heads=module.num_heads,
|
||||
hidden_size=module.hidden_size,
|
||||
process_group=process_group,
|
||||
|
@ -201,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
kv_seq_len: int = 0,
|
||||
output_tensor: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
use_cuda_kernel: bool = True,
|
||||
cu_seqlens: torch.Tensor = None,
|
||||
high_precision: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
@ -220,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
@ -250,33 +246,29 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
output_tensor=output_tensor,
|
||||
use_spec_dec=is_verifier,
|
||||
use_alibi_attn=self.use_alibi_attn,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
)
|
||||
|
||||
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
|
||||
if is_prompts: # prefilling stage
|
||||
pre_attention_backend.prefill(
|
||||
self.pre_attention_backend.prefill(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
high_precision=high_precision,
|
||||
)
|
||||
attn_output = attention_backend.prefill(
|
||||
attn_output = self.attention_backend.prefill(
|
||||
attn_metadata,
|
||||
token_nums=token_nums,
|
||||
)
|
||||
else: # decoding stage
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
|
||||
pre_attention_backend.decode(
|
||||
self.pre_attention_backend.decode(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = attention_backend.decode(
|
||||
attn_output = self.attention_backend.decode(
|
||||
attn_metadata,
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
q_len=q_len,
|
||||
|
|
|
@ -16,21 +16,13 @@ from transformers.models.llama.modeling_llama import (
|
|||
LlamaRMSNorm,
|
||||
)
|
||||
|
||||
from colossalai.inference.config import InputMetaData
|
||||
from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
|
||||
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
|
||||
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
copy_k_to_blocked_cache,
|
||||
decoding_fused_rotary_embedding,
|
||||
flash_decoding_attention,
|
||||
get_xine_cache,
|
||||
rms_layernorm,
|
||||
rotary_embedding,
|
||||
)
|
||||
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
|
||||
|
@ -233,7 +225,6 @@ def llama_decoder_layer_forward(
|
|||
kv_seq_len=kv_seq_len,
|
||||
output_tensor=output_tensor,
|
||||
sm_scale=sm_scale,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
high_precision=high_precision,
|
||||
)
|
||||
|
@ -397,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
attn_vproj_w: torch.Tensor = None,
|
||||
attn_oproj: ParallelModule = None,
|
||||
process_group: ProcessGroup = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
num_heads: int = None,
|
||||
hidden_size: int = None,
|
||||
num_key_value_heads: int = None,
|
||||
|
@ -428,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
|
||||
self.attention_backend = get_attention_backend(model_shard_infer_config)
|
||||
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
|
||||
|
||||
if self.num_heads == self.num_key_value_heads:
|
||||
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
|
||||
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
|
||||
|
@ -457,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
attn_vproj_w = module.v_proj.weight
|
||||
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
|
||||
attn_oproj = module.o_proj
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
|
||||
attn_layer = NopadLlamaAttention(
|
||||
config=config,
|
||||
|
@ -466,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
attn_vproj_w=attn_vproj_w,
|
||||
attn_oproj=attn_oproj,
|
||||
process_group=process_group,
|
||||
model_shard_infer_config=model_shard_infer_config,
|
||||
num_heads=module.num_heads,
|
||||
hidden_size=module.hidden_size,
|
||||
num_key_value_heads=module.num_key_value_heads,
|
||||
|
@ -544,33 +541,29 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
output_tensor=output_tensor,
|
||||
use_spec_dec=is_verifier,
|
||||
use_alibi_attn=False,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
)
|
||||
|
||||
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
|
||||
|
||||
if is_prompts: # prefilling stage
|
||||
pre_attention_backend.prefill(
|
||||
self.pre_attention_backend.prefill(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
high_precision=high_precision,
|
||||
)
|
||||
attn_output = attention_backend.prefill(
|
||||
attn_output = self.attention_backend.prefill(
|
||||
attn_metadata,
|
||||
token_nums=token_nums,
|
||||
)
|
||||
else: # decoding stage
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
|
||||
pre_attention_backend.decode(
|
||||
self.pre_attention_backend.decode(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = attention_backend.decode(
|
||||
attn_output = self.attention_backend.decode(
|
||||
attn_metadata,
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
|
@ -633,4 +626,3 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
|
||||
def extra_repr(self) -> str:
|
||||
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
|
||||
|
|
@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn",
|
||||
target_module=NopadBaichuanAttention,
|
||||
kwargs={
|
||||
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn",
|
||||
target_module=NopadLlamaAttention,
|
||||
kwargs={
|
||||
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai.testing import free_port
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import free_port
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
@ -149,12 +149,9 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
|||
Check flash attention2 availability.
|
||||
"""
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.")
|
||||
return False
|
||||
|
||||
try:
|
||||
from flash_attn import __version__
|
||||
logger.info(f"flash_attn2 version {__version__}.")
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
|
|
|
@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
|
|||
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
|
||||
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
|
||||
outputs = inference_engine.generate(generation_config=generation_config)
|
||||
else:
|
||||
if prompt_template:
|
||||
|
|
Loading…
Reference in New Issue