diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py
index 11517d7e8..397060258 100644
--- a/colossalai/inference/engine/policies/llama.py
+++ b/colossalai/inference/engine/policies/llama.py
@@ -20,23 +20,17 @@ from ..modeling._utils import init_to_get_rotary
 from ..modeling.llama import LlamaInferenceForwards
 
 try:
-    from colossalai.kernel.triton import rmsnorm_forward
-
+    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
     HAS_TRITON_RMSNORM = False
 
-
-def get_triton_rmsnorm_forward():
-    if HAS_TRITON_RMSNORM:
-
+if HAS_TRITON_RMSNORM:
+    def get_triton_rmsnorm_forward():
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
-
+            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
         return _triton_rmsnorm_forward
-    else:
-        return None
 
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
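
The net effect of this hunk: the Triton RMSNorm kernel is now imported from lightllm rather than `colossalai.kernel.triton`, and `get_triton_rmsnorm_forward` is only defined when that import succeeds (instead of returning `None` at call time). Below is a minimal, self-contained sketch of the substitution the patched factory performs, assuming lightllm and a CUDA device are available; the `except ImportError` and the correctness check against eager `LlamaRMSNorm` are illustrative additions (the patch itself uses a bare `except:` and wires the replacement through the shardformer policy, not a direct call like this):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

try:
    # Same import the patch switches to: lightllm's Triton RMSNorm kernel.
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward

    HAS_TRITON_RMSNORM = True
except ImportError:  # the patch guards this with a bare `except:`
    HAS_TRITON_RMSNORM = False

if HAS_TRITON_RMSNORM:
    # Mirror the patched factory: a drop-in replacement for
    # LlamaRMSNorm.forward that reads the module's weight and epsilon
    # and dispatches to the Triton kernel.
    def triton_forward(module: LlamaRMSNorm, hidden_states: torch.Tensor):
        return lightllm_rmsnorm_forward(hidden_states, module.weight.data, module.variance_epsilon)

    # Sanity-check the kernel against the eager PyTorch implementation.
    norm = LlamaRMSNorm(4096).cuda().half()
    x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
    torch.testing.assert_close(triton_forward(norm, x), norm(x), rtol=1e-2, atol=1e-2)
```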