mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Fix import error: colossal.kernel without triton installed (#4722)
* [hotfix] remove triton kernels from kernel init
* revise bloom/llama kernel imports for infer

branch: pull/4727/head
parent c7d6975d29
commit e2c0e7f92a
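In short: triton-backed kernels are no longer re-exported by the kernel package __init__, so import colossalai.kernel works on machines without triton, and call sites that do want the triton kernels import them from colossalai.kernel.triton behind a guard. A minimal sketch of that guard pattern, using only names visible in this diff (the HAS_TRITON flag and the attention_forward wrapper are illustrative, not from the commit):

try:
    from colossalai.kernel.triton import token_attention_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


def attention_forward(*args, **kwargs):
    # Fail loudly only when the triton path is actually requested,
    # not at module import time.
    if not HAS_TRITON:
        raise RuntimeError("token_attention_fwd needs triton: https://github.com/openai/triton")
    return token_attention_fwd(*args, **kwargs)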
@@ -17,9 +17,7 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
 
 
 def generate_alibi(n_head, dtype=torch.float16):
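The three per-submodule imports in the BLOOM inference forwards collapse into a single import from the colossalai.kernel.triton package. A quick sanity check of the consolidated path (only meaningful where triton is installed; the expected module names follow from the re-exports added to the triton package __init__ at the bottom of this commit):

from colossalai.kernel.triton import (
    bloom_context_attn_fwd,
    copy_kv_cache_to_dest,
    token_attention_fwd,
)

# The re-exported symbols still live in their original submodules.
print(bloom_context_attn_fwd.__module__)   # expected: colossalai.kernel.triton.context_attention
print(copy_kv_cache_to_dest.__module__)    # expected: colossalai.kernel.triton.copy_kv_cache_dest
print(token_attention_fwd.__module__)      # expected: colossalai.kernel.triton.token_attention_kernel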
@@ -6,10 +6,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
 
 try:
     from vllm import layernorm_ops, pos_encoding_ops
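The LLaMA inference forwards get the same consolidation, now as a parenthesized multi-line import, and the hunk ends at an optional dependency that is already guarded: the vllm custom ops. For reference, the same guard style applied to that import would look like the sketch below (the HAS_VLLM_KERNELS flag is illustrative and not part of this commit):

try:
    from vllm import layernorm_ops, pos_encoding_ops
    HAS_VLLM_KERNELS = True
except ImportError:
    # Fall back to the pure-PyTorch paths when vllm's compiled ops are absent.
    HAS_VLLM_KERNELS = False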
@@ -8,10 +8,10 @@ from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 from ..modeling.bloom import BloomInferenceForwards
 
 try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
     HAS_TRITON_NORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
     HAS_TRITON_NORM = False
 
 
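In the BLOOM inference policy the guarded import now goes through the package as well, and the failure message is softened. HAS_TRITON_NORM is what later decides whether the triton layer_norm gets wired in; a dispatcher in the style of get_triton_rmsnorm_forward (shown further down in this commit) might look like the sketch below. The helper name and the argument order passed to layer_norm are assumptions, not taken from this diff:

def get_triton_layernorm_forward():
    if not HAS_TRITON_NORM:
        return None

    # Argument order mirrors the rmsnorm_forward call used for LLaMA below;
    # for layer_norm it is an assumption.
    def _triton_layernorm_forward(self, hidden_states):
        return layer_norm(hidden_states, self.weight.data, self.bias.data, self.eps)

    return _triton_layernorm_forward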
@@ -1,18 +1,15 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
 
 try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +18,7 @@ except:
 
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
 
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
 

@@ -28,6 +26,7 @@ def get_triton_rmsnorm_forward():
     else:
         return None
 
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def __init__(self) -> None:

@@ -67,4 +66,3 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
             target_key=LlamaRMSNorm)
 
         return policy
-
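get_triton_rmsnorm_forward returns a drop-in forward when the guarded import succeeded and None otherwise, so the policy only registers the replacement when triton is actually usable. A caller-side sketch of that contract, outside the shardformer policy machinery (model and the manual monkey-patching below are illustrative, not how the policy wires it up):

import types

infer_forward = get_triton_rmsnorm_forward()
if infer_forward is not None:
    # model is assumed to be an already-built LLaMA module tree.
    for module in model.modules():
        if isinstance(module, LlamaRMSNorm):
            # Bind the triton-backed forward; when infer_forward is None,
            # the stock LlamaRMSNorm.forward stays in place.
            module.forward = types.MethodType(infer_forward, module)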
@@ -1,14 +1,7 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
 ]
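This is the hunk the commit title points at: the kernel package __init__ no longer re-exports anything from .triton, so importing colossalai.kernel for LayerNorm, FusedScaleMaskSoftmax, or MultiHeadAttention works even when triton is missing. Callers that used to pull triton kernels from colossalai.kernel now import them from colossalai.kernel.triton, ideally behind a guard; a migration sketch (the HAS_TRITON_SOFTMAX flag is illustrative):

# Before this commit (would fail at import time without triton):
#     from colossalai.kernel import softmax
# After this commit:
try:
    from colossalai.kernel.triton import softmax
    HAS_TRITON_SOFTMAX = True
except ImportError:
    softmax = None
    HAS_TRITON_SOFTMAX = False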
@@ -2,4 +2,11 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]
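With rotary_embedding_fwd, token_attention_fwd, and an explicit __all__ added, colossalai.kernel.triton is now the single public entry point for these kernels. A small introspection sketch that lists that surface (run only where colossalai and triton are both installed):

import colossalai.kernel.triton as triton_kernels

# __all__ enumerates every re-exported kernel in one place.
for name in triton_kernels.__all__:
    print(name, "->", getattr(triton_kernels, name).__module__)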