mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Fix import error: colossal.kernel without triton installed (#4722)
* [hotfix] remove triton kernels from kernel init
* revise bloom/llama kernel imports for infer
parent c7d6975d29
commit e2c0e7f92a

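For context: before this fix, importing `colossalai.kernel` eagerly pulled in the triton kernels, so the import could fail on machines without triton installed. The diffs below move those imports behind the `colossalai.kernel.triton` subpackage and guard them at the call sites, which then branch on a flag instead of assuming the kernel is importable. A minimal sketch of that guard pattern (illustrative only, not verbatim repository code):

# Illustrative sketch of the optional-dependency guard used at the call sites
# (not verbatim ColossalAI source). The triton-backed kernel is imported from
# the colossalai.kernel.triton subpackage, and a flag records whether it exists.
try:
    from colossalai.kernel.triton import layer_norm  # needs triton at import time
    HAS_TRITON_NORM = True
except ImportError:
    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
    HAS_TRITON_NORM = False
    layer_norm = None  # callers check HAS_TRITON_NORM and fall back
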
@@ -17,9 +17,7 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd


 def generate_alibi(n_head, dtype=torch.float16):

@@ -6,10 +6,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)

 try:
     from vllm import layernorm_ops, pos_encoding_ops

@@ -8,10 +8,10 @@ from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 from ..modeling.bloom import BloomInferenceForwards

 try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
     HAS_TRITON_NORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
     HAS_TRITON_NORM = False

@@ -1,33 +1,32 @@
 from functools import partial

 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

 try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
     HAS_TRITON_RMSNORM = False


 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

         return _triton_rmsnorm_forward
     else:
         return None


 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

     def __init__(self) -> None:

@@ -59,12 +58,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         else:
             # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
             infer_forward = get_llama_vllm_rmsnorm_forward()

         if infer_forward is not None:
             method_replacement = {'forward': partial(infer_forward)}
             self.append_or_create_method_replacement(description=method_replacement,
                                                      policy=policy,
                                                      target_key=LlamaRMSNorm)

         return policy

@@ -1,14 +1,7 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest

 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
 ]

@@ -2,4 +2,11 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
 from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
 from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]

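Taken together, the expected downstream usage after this commit looks roughly as follows (a hedged sketch drawn from the diffs above, not part of the commit itself): `colossalai.kernel` exposes only the CUDA-native ops, while the triton-backed kernels sit behind the consolidated `colossalai.kernel.triton` namespace and remain guarded where they are used.

# Sketch of how callers are expected to import kernels after this change
# (an assumption inferred from the diffs above, not verbatim repository code).
from colossalai.kernel import LayerNorm, MultiHeadAttention  # no triton required

try:
    # consolidated triton namespace; the import may fail if triton is missing
    from colossalai.kernel.triton import rmsnorm_forward, token_attention_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False  # fall back to non-triton paths (e.g. the vllm ops)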