add lightllm rmsnorm (#5096)

Co-authored-by: cuiqing.li <lixx336@gmail.com>
pull/5097/head
Cuiqing Li (李崔卿) 2023-11-22 17:05:34 +08:00 committed by GitHub
parent 27e62ba0f7
commit afe3c78d9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 10 deletions

View File

@ -20,23 +20,17 @@ from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards
try:
from colossalai.kernel.triton import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
if HAS_TRITON_RMSNORM:
def get_triton_rmsnorm_forward():
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return _triton_rmsnorm_forward
else:
return None
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):