From e2c0e7f92abf6b93ccb331298880a41553df0cc7 Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Thu, 14 Sep 2023 18:03:55 +0800
Subject: [PATCH] [hotfix] Fix import error: colossal.kernel without triton installed (#4722)

* [hotfix] remove triton kernels from kernel init

* revise bloom/llama kernel imports for infer
---
 .../tensor_parallel/modeling/bloom.py         |  4 +--
 .../tensor_parallel/modeling/llama.py         | 10 ++++---
 .../tensor_parallel/policies/bloom.py         |  4 +--
 .../tensor_parallel/policies/llama.py         | 26 +++++++++----------
 colossalai/kernel/__init__.py                 |  7 -----
 colossalai/kernel/triton/__init__.py          |  7 +++++
 6 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index 9768fc425..ba5eadc92 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -17,9 +17,7 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
 
 
 def generate_alibi(n_head, dtype=torch.float16):
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 219cd1ae0..07b73a6f4 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -6,10 +6,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
-from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton import (
+    copy_kv_cache_to_dest,
+    llama_context_attn_fwd,
+    rotary_embedding_fwd,
+    token_attention_fwd,
+)
 
 try:
     from vllm import layernorm_ops, pos_encoding_ops
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
index 63791fe27..cae43aa20 100644
--- a/colossalai/inference/tensor_parallel/policies/bloom.py
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -8,10 +8,10 @@ from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 from ..modeling.bloom import BloomInferenceForwards
 
 try:
-    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    from colossalai.kernel.triton import layer_norm
     HAS_TRITON_NORM = True
 except:
-    print("you should install triton from https://github.com/openai/triton")
+    print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
     HAS_TRITON_NORM = False
 
 
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index e819f2a88..4844415d6 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,33 +1,32 @@
 from functools import partial
+
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    LlamaRMSNorm
-)
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
 
 try:
-    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    from colossalai.kernel.triton import rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
     HAS_TRITON_RMSNORM = False
-    
+
 
 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:
+
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
-        
+
         return _triton_rmsnorm_forward
     else:
         return None
-    
+
+
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def __init__(self) -> None:
@@ -59,12 +58,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         else:
             # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
             infer_forward = get_llama_vllm_rmsnorm_forward()
-        
+
         if infer_forward is not None:
             method_replacement = {'forward': partial(infer_forward)}
             self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=LlamaRMSNorm)
+                                                      policy=policy,
+                                                      target_key=LlamaRMSNorm)
 
         return policy
-
diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
index a99cb497c..8933fc0a3 100644
--- a/colossalai/kernel/__init__.py
+++ b/colossalai/kernel/__init__.py
@@ -1,14 +1,7 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
-from .triton import softmax
-from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
-    "llama_context_attn_fwd",
-    "bloom_context_attn_fwd",
-    "softmax",
-    "copy_kv_cache_to_dest",
 ]
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index eb0335c01..5840ad291 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -2,4 +2,11 @@ from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
+from .token_attention_kernel import token_attention_fwd
+
+__all__ = [
+    "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
+    "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+]
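Note on the pattern behind this hotfix (illustrative, not part of the diff): after the change, colossalai/kernel/__init__.py eagerly imports only the CUDA-native kernels, while every triton kernel is reached through colossalai.kernel.triton behind a try/except guard in the inference policies, so `import colossalai.kernel` no longer fails on machines without triton. Below is a minimal sketch of that guarded-import pattern; the names my_project.triton_kernels, fused_rmsnorm, and _rmsnorm_fallback are hypothetical placeholders, not ColossalAI APIs.

    import torch

    try:
        # Succeeds only when triton (and the triton-backed kernel) is importable.
        from my_project.triton_kernels import fused_rmsnorm  # hypothetical triton kernel
        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False
        print("Some of our kernels require triton. You might want to install triton "
              "from https://github.com/openai/triton")


    def _rmsnorm_fallback(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Plain PyTorch fallback, roughly what LlamaRMSNorm computes.
        variance = x.pow(2).mean(-1, keepdim=True)
        return weight * x * torch.rsqrt(variance + eps)


    def get_rmsnorm_forward():
        # Like get_triton_rmsnorm_forward() in the policy above: hand back the fast
        # kernel when available, otherwise a fallback, so importing never raises.
        return fused_rmsnorm if HAS_TRITON else _rmsnorm_fallback

With the triton imports removed from colossalai/kernel/__init__.py, this guard only runs where a triton kernel is actually requested (the inference modeling and policy modules), which is what lets the rest of the package import cleanly without triton installed.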