
[hotfix] Fix import error: colossal.kernel without triton installed (#4722)

* [hotfix] remove triton kernels from kernel init

* revise bloom/llama kernel imports for infer
Branch: pull/4727/head
Author: Yuanheng Zhao, 1 year ago (committed by GitHub)
Commit: e2c0e7f92a
Changed files:

  1. colossalai/inference/tensor_parallel/modeling/bloom.py (4 changes)
  2. colossalai/inference/tensor_parallel/modeling/llama.py (10 changes)
  3. colossalai/inference/tensor_parallel/policies/bloom.py (4 changes)
  4. colossalai/inference/tensor_parallel/policies/llama.py (26 changes)
  5. colossalai/kernel/__init__.py (7 changes)
  6. colossalai/kernel/triton/__init__.py (7 changes)
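
The root cause is visible in the last two files: colossalai/kernel/__init__.py imported the triton kernels unconditionally, so importing colossalai.kernel failed on any machine without triton installed. The sketch below contrasts that eager import with the guarded pattern the hotfix relies on; it is illustrative only, and not_installed_dep is a hypothetical stand-in for triton.

def eager_import():
    # What colossalai/kernel/__init__.py effectively did before the fix:
    # pull in the optional dependency at package-import time.
    import not_installed_dep  # noqa: F401  -- raises ModuleNotFoundError when missing

def guarded_import():
    # The pattern the hotfix moves toward: touch the optional dependency only
    # where it is needed, behind try/except, and record a capability flag.
    try:
        import not_installed_dep  # noqa: F401
        return True
    except ImportError:
        return False

print(guarded_import())        # False, yet importing this module itself succeeded
try:
    eager_import()
except ImportError as exc:     # ModuleNotFoundError is a subclass of ImportError
    print(f"eager import fails: {exc}")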

colossalai/inference/tensor_parallel/modeling/bloom.py (4 changes)

@@ -17,9 +17,7 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd

 def generate_alibi(n_head, dtype=torch.float16):

colossalai/inference/tensor_parallel/modeling/llama.py (10 changes)

@@ -6,10 +6,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)

 try:
     from vllm import layernorm_ops, pos_encoding_ops

colossalai/inference/tensor_parallel/policies/bloom.py (4 changes)

@@ -8,10 +8,10 @@ from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 from ..modeling.bloom import BloomInferenceForwards

 try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
     HAS_TRITON_NORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
     HAS_TRITON_NORM = False
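
The HAS_TRITON_NORM flag above lets the policy fall back to stock PyTorch when triton is missing instead of crashing at import time. A minimal sketch of how such a flag is typically consumed; pick_layernorm_impl is hypothetical, and only the import/flag pattern mirrors the policy.

import torch

try:
    from colossalai.kernel.triton import layer_norm  # fused triton kernel, as in the policy above
    HAS_TRITON_NORM = True
except ImportError:
    HAS_TRITON_NORM = False

def pick_layernorm_impl():
    # Hypothetical helper: prefer the fused kernel when it imported cleanly,
    # otherwise fall back to the stock PyTorch op.
    if HAS_TRITON_NORM:
        return layer_norm
    return torch.nn.functional.layer_norm

print("using", "triton layer_norm" if HAS_TRITON_NORM else "torch.nn.functional.layer_norm")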

colossalai/inference/tensor_parallel/policies/llama.py (26 changes)

@@ -1,33 +1,32 @@
 from functools import partial

 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

 try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
     HAS_TRITON_RMSNORM = False

 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

         return _triton_rmsnorm_forward
     else:
         return None

 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

     def __init__(self) -> None:

@@ -59,12 +58,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         else:
             # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
             infer_forward = get_llama_vllm_rmsnorm_forward()

         if infer_forward is not None:
             method_replacement = {'forward': partial(infer_forward)}
             self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=LlamaRMSNorm)
+                                                      policy=policy,
+                                                      target_key=LlamaRMSNorm)

         return policy
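
Stripped of the shardformer machinery, what this policy does is rebind LlamaRMSNorm.forward to a wrapper with signature (self, hidden_states) that calls the fused kernel. A self-contained sketch of that method-replacement idea, assuming torch is available; TinyRMSNorm and reference_rmsnorm are illustrative stand-ins, not ColossalAI or transformers code.

import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):
    # Minimal stand-in for LlamaRMSNorm: a weight plus a variance epsilon.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain-PyTorch reference of what a fused RMSNorm kernel computes.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

def _replacement_forward(self: TinyRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
    # Same shape as _triton_rmsnorm_forward above: pass (input, weight, eps) to the kernel.
    return reference_rmsnorm(hidden_states, self.weight.data, self.variance_epsilon)

# The method replacement registered with the policy ultimately amounts to
# rebinding forward on the target class.
TinyRMSNorm.forward = _replacement_forward
print(TinyRMSNorm(8)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])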

colossalai/kernel/__init__.py (7 changes)

@@ -1,14 +1,7 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest

 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
 ]
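
After this change colossalai.kernel only re-exports the CUDA-native modules, so call sites that want the triton kernels import them from colossalai.kernel.triton directly and, when triton is optional for them, guard the import. A hedged before/after sketch of that call-site migration; the fallback assignment is illustrative.

# Before the fix, this line required triton even if the kernel was never called:
#   from colossalai.kernel import llama_context_attn_fwd
# After the fix, triton-backed kernels come from the triton subpackage:
try:
    from colossalai.kernel.triton import llama_context_attn_fwd, copy_kv_cache_to_dest
    HAS_TRITON_KERNELS = True
except ImportError:
    # Illustrative fallback: callers switch to non-fused code paths.
    llama_context_attn_fwd = copy_kv_cache_to_dest = None
    HAS_TRITON_KERNELS = False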

colossalai/kernel/triton/__init__.py (7 changes)

@@ -2,4 +2,11 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]
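
With __all__ defined, the public surface of colossalai.kernel.triton is explicit and star-imports stay bounded to the listed kernels. A quick introspection check, only meaningful on a machine where colossalai and triton are both installed.

import colossalai.kernel.triton as triton_kernels

# __all__ is the contract: these are the names the rest of the codebase imports.
expected = {
    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm",
    "rmsnorm_forward", "copy_kv_cache_to_dest", "rotary_embedding_fwd",
    "token_attention_fwd",
}
assert expected.issubset(set(triton_kernels.__all__))
print(sorted(triton_kernels.__all__))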
