Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
pull/5771/head
char-1ee 2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

Binary file not shown.

View File

@ -10,6 +10,7 @@ import torch
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2
GibiByte = 1024**3
@ -313,12 +314,13 @@ class InferenceConfig(RPC_PARAM):
return GenerationConfig.from_dict(meta_config)
def to_model_inference_config(self) -> "ModelInferenceConfig":
model_inference_config = ModelInferenceConfig(
def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_flash_attn = can_use_flash_attn2(self.dtype)
model_inference_config = ModelShardInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_cuda_graph=self.use_cuda_graph,
use_flash_attn=use_flash_attn,
)
return model_inference_config
@ -374,21 +376,20 @@ class InferenceConfig(RPC_PARAM):
inference_config = cls(**inference_config_args)
return inference_config
@dataclass
class ModelInferenceConfig():
class ModelShardInferenceConfig:
"""
Configurations used when initializing/sharding model for inference.
Configurations used during init of module for inference modeling.
Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally
use_spec_dec (bool): Indicate whether to use speculative decoding.
use_flash_attn (bool): Indicate whether to use flash attention.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
"""
dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
use_cuda_graph: bool = False

View File

@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@ -72,8 +72,9 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@ -99,8 +100,6 @@ class InferenceEngine:
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
# We can add a SpecDecConfig class to store these configs.
self.drafter_model = None
self.drafter = None
self.use_glide = False
@ -112,6 +111,7 @@ class InferenceEngine:
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
@ -120,6 +120,7 @@ class InferenceEngine:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
@ -176,6 +177,7 @@ class InferenceEngine:
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@ -296,6 +298,7 @@ class InferenceEngine:
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@ -321,6 +324,7 @@ class InferenceEngine:
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@ -357,7 +361,6 @@ class InferenceEngine:
engine.clear_spec_dec()
```
"""
self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")

View File

@ -1,19 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func
import torch
from flash_attn import flash_attn_varlen_func
from colossalai.inference.config import InputMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.logging import get_dist_logger
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
flash_decoding_attention,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
@dataclass
@ -33,7 +26,6 @@ class AttentionMetaData:
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False
use_cuda_kernel: bool = False
class AttentionBackend(ABC):
@ -46,7 +38,16 @@ class AttentionBackend(ABC):
raise NotImplementedError
class CudaAttentionBackend(AttentionBackend):
class FlashAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
`flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
token_nums = kwargs.get("token_nums", -1)
@ -69,7 +70,55 @@ class CudaAttentionBackend(AttentionBackend):
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
inference_ops.flash_decoding_attention(
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor
class CudaAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k cache layout for cuda kernels in this triton op
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
@ -88,6 +137,10 @@ class CudaAttentionBackend(AttentionBackend):
class TritonAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
@ -102,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=attn_metadata.use_cuda_kernel,
use_new_kcache_layout=False,
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
@ -126,17 +179,24 @@ class TritonAttentionBackend(AttentionBackend):
def get_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. Only when:
Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend
for attention module calculation only when:
1. using CUDA kernel (use_cuda_kernel=True)
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
3. not using speculative decoding (currently cuda kernel not support speculative decoding)
will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend.
Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True,
the Triton backend will use a new k cache layout for Triton kernels.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionBackend()
else:
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
if model_shard_infer_config.use_flash_attn:
return FlashAttentionBackend()
return CudaAttentionBackend()
return TritonAttentionBackend()

View File

@ -1,18 +1,9 @@
from abc import ABC, abstractmethod
import torch
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.logging import get_dist_logger
from colossalai.kernel.triton import (
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
rotary_embedding,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
class PreAttentionBackend(ABC):
@ -25,17 +16,25 @@ class PreAttentionBackend(ABC):
raise NotImplementedError
class CudaPreAttentionBackend(PreAttentionBackend):
class FlashPreAttentionBackend(PreAttentionBackend):
"""
FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding(
self.inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
kwargs.get("high_precision", False),
)
inference_ops.context_kv_cache_memcpy(
self.inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
@ -48,7 +47,7 @@ class CudaPreAttentionBackend(PreAttentionBackend):
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
@ -61,7 +60,50 @@ class CudaPreAttentionBackend(PreAttentionBackend):
kwargs.get("high_precision", None),
)
else:
inference_ops.decode_kv_cache_memcpy(
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class CudaPreAttentionBackend(PreAttentionBackend):
"""
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
kwargs.get("high_precision", None),
)
else:
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
@ -72,6 +114,10 @@ class CudaPreAttentionBackend(PreAttentionBackend):
class TritonPreAttentionBackend(PreAttentionBackend):
"""
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
@ -119,13 +165,18 @@ class TritonPreAttentionBackend(PreAttentionBackend):
def get_pre_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including potisional encoding like RoPE and KV cache initialization.
Get the backend for pre-attention computations, including potisional encoding like
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaPreAttentionBackend()
else:
if model_shard_infer_config.use_spec_dec:
return TritonPreAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
if model_shard_infer_config.use_flash_attn:
return FlashPreAttentionBackend()
return CudaPreAttentionBackend()
return TritonPreAttentionBackend()

View File

@ -1,31 +1,23 @@
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import itertools
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
@ -69,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_oproj: ParallelModule = None,
num_heads: int = None,
hidden_size: int = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
process_group: ProcessGroup = None,
helper_layout: Layout = None,
):
@ -93,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule):
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = helper_layout
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
self.alibi_slopes = None
self.use_alibi_attn = False
@ -122,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_kproj_w = k_proj_w
attn_vproj_w = v_proj_w
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
helper_layout = (
module.W_pack.weight.dist_layout
@ -133,6 +130,7 @@ class NopadBaichuanAttention(ParallelModule):
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj=attn_oproj,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
process_group=process_group,
@ -201,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule):
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@ -220,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule):
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
@ -250,33 +246,29 @@ class NopadBaichuanAttention(ParallelModule):
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=self.use_alibi_attn,
use_cuda_kernel=use_cuda_kernel,
)
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
if is_prompts: # prefilling stage
pre_attention_backend.prefill(
self.pre_attention_backend.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = attention_backend.prefill(
attn_output = self.attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
pre_attention_backend.decode(
self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = attention_backend.decode(
attn_output = self.attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
q_len=q_len,

View File

@ -16,21 +16,13 @@ from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
)
from colossalai.inference.config import InputMetaData
from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
get_xine_cache,
rms_layernorm,
rotary_embedding,
)
from colossalai.kernel.triton import get_xine_cache, rms_layernorm
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
@ -233,7 +225,6 @@ def llama_decoder_layer_forward(
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
@ -397,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
process_group: ProcessGroup = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
num_heads: int = None,
hidden_size: int = None,
num_key_value_heads: int = None,
@ -428,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
@ -457,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
attn_layer = NopadLlamaAttention(
config=config,
@ -466,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
attn_vproj_w=attn_vproj_w,
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
@ -544,33 +541,29 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
output_tensor=output_tensor,
use_spec_dec=is_verifier,
use_alibi_attn=False,
use_cuda_kernel=use_cuda_kernel,
)
attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
if is_prompts: # prefilling stage
pre_attention_backend.prefill(
self.pre_attention_backend.prefill(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
high_precision=high_precision,
)
attn_output = attention_backend.prefill(
attn_output = self.attention_backend.prefill(
attn_metadata,
token_nums=token_nums,
)
else: # decoding stage
q_len = tokens_to_verify + 1 if is_verifier else 1
pre_attention_backend.decode(
self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = attention_backend.decode(
attn_output = self.attention_backend.decode(
attn_metadata,
fd_inter_tensor=fd_inter_tensor,
num_key_value_groups=self.num_key_value_groups,
@ -633,4 +626,3 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"

View File

@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
),
],
)

View File

@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
),
],
)

View File

@ -1,17 +1,17 @@
"""
Utils for model inference
"""
import math
import os
import re
import math
from pathlib import Path
from typing import Optional, Tuple
import torch
from torch import nn
from colossalai.testing import free_port
from colossalai.logging import get_dist_logger
from colossalai.testing import free_port
logger = get_dist_logger(__name__)
@ -149,12 +149,9 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
Check flash attention2 availability.
"""
if dtype not in (torch.float16, torch.bfloat16):
logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.")
return False
try:
from flash_attn import __version__
logger.info(f"flash_attn2 version {__version__}.")
return True
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

View File

@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template: