[hotfix] Fix import error: colossal.kernel without triton installed (#4722)

* [hotfix] remove triton kernels from kernel init

* revise bloom/llama kernel imports for infer
Yuanheng Zhao 2023-09-14 18:03:55 +08:00 committed by GitHub
parent c7d6975d29
commit e2c0e7f92a
6 changed files with 28 additions and 30 deletions
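Note on the fix: before this patch, `colossalai/kernel/__init__.py` re-exported the triton kernels eagerly, so a plain `import colossalai.kernel` failed on any machine without triton installed. The hotfix drops those re-exports from the top-level init and has the inference code import from `colossalai.kernel.triton` directly, guarded where the dependency is optional. A minimal sketch of that guarded-import pattern (simplified; the repository files use a bare `except`, and the flag name mirrors the ones in the diff):

# Import the optional triton-backed kernel; degrade gracefully when triton is absent.
try:
    from colossalai.kernel.triton import rmsnorm_forward

    HAS_TRITON_RMSNORM = True
except ImportError:
    # triton is an optional dependency: https://github.com/openai/triton
    HAS_TRITON_RMSNORM = False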

View File

@@ -17,9 +17,7 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
 
 
 def generate_alibi(n_head, dtype=torch.float16):

View File

@@ -6,10 +6,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
 
 try:
     from vllm import layernorm_ops, pos_encoding_ops

View File

@@ -8,10 +8,10 @@ from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 from ..modeling.bloom import BloomInferenceForwards
 
 try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
 
     HAS_TRITON_NORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
     HAS_TRITON_NORM = False
 
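The `HAS_TRITON_NORM` flag set above lets the bloom policy decide at runtime whether to wire in the triton layer norm. A hedged illustration of how such a flag can gate the kernel choice; the helper name and the PyTorch fallback are assumptions for this sketch (the real policy simply skips the replacement when the flag is False), and the `(input, weight, bias, eps)` argument order for the triton `layer_norm` is assumed rather than taken from this diff:

import torch
import torch.nn.functional as F


def _norm_forward(hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor:
    # Hypothetical helper: prefer the triton kernel when it imported successfully.
    if HAS_TRITON_NORM:
        return layer_norm(hidden_states, weight, bias, eps)  # assumed signature
    # Fallback assumption for this sketch: plain PyTorch layer norm.
    return F.layer_norm(hidden_states, (hidden_states.size(-1),), weight, bias, eps)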

View File

@@ -1,18 +1,15 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
 
 try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
 
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +18,7 @@ except:
 
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
+
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
 
@@ -28,6 +26,7 @@ def get_triton_rmsnorm_forward():
     else:
         return None
 
+
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
@@ -67,4 +66,3 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
                                                         target_key=LlamaRMSNorm)
 
         return policy
-
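`get_triton_rmsnorm_forward` returns a drop-in replacement for `LlamaRMSNorm.forward` when the triton kernel imported, and `None` otherwise. The policy registers it through shardformer's method replacement (see the `target_key=LlamaRMSNorm` context line above); outside that machinery, a minimal way to consume it would be the following, shown only as an illustration:

from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Illustrative only: patch the forward globally if the triton kernel is available.
# The real policy applies the replacement per-model via shardformer instead.
_triton_forward = get_triton_rmsnorm_forward()
if _triton_forward is not None:
    LlamaRMSNorm.forward = _triton_forward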

View File

@@ -1,14 +1,7 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
 ]
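With the triton re-exports removed from `colossalai/kernel/__init__.py`, the top-level package imports cleanly on a triton-free machine, while the kernels stay reachable through the subpackage. A quick sanity check along these lines (illustrative, not a test shipped with this PR):

# Succeeds without triton: only the CUDA-native kernels are imported eagerly.
from colossalai.kernel import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention

# The triton kernels must now be imported explicitly and may still fail without triton,
# which is why the call sites in this PR guard the import.
try:
    from colossalai.kernel.triton import bloom_context_attn_fwd, token_attention_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False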

View File

@@ -2,4 +2,11 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]