From afe3c78d9af3ac4d0bb67ad25ca64452618ffa68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cuiqing=20Li=20=28=E6=9D=8E=E5=B4=94=E5=8D=BF=29?= Date: Wed, 22 Nov 2023 17:05:34 +0800 Subject: [PATCH] add lightllm rmsnorm (#5096) Co-authored-by: cuiqing.li --- colossalai/inference/engine/policies/llama.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py index 11517d7e8..397060258 100644 --- a/colossalai/inference/engine/policies/llama.py +++ b/colossalai/inference/engine/policies/llama.py @@ -20,23 +20,17 @@ from ..modeling._utils import init_to_get_rotary from ..modeling.llama import LlamaInferenceForwards try: - from colossalai.kernel.triton import rmsnorm_forward - + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") HAS_TRITON_RMSNORM = False - -def get_triton_rmsnorm_forward(): - if HAS_TRITON_RMSNORM: - +if HAS_TRITON_RMSNORM: + def get_triton_rmsnorm_forward(): def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) - + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) return _triton_rmsnorm_forward - else: - return None class LlamaModelInferPolicy(LlamaForCausalLMPolicy):