mirror of https://github.com/hpcaitech/ColossalAI
parent
27e62ba0f7
commit
afe3c78d9a
|
@ -20,23 +20,17 @@ from ..modeling._utils import init_to_get_rotary
|
||||||
from ..modeling.llama import LlamaInferenceForwards
|
from ..modeling.llama import LlamaInferenceForwards
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from colossalai.kernel.triton import rmsnorm_forward
|
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
|
||||||
|
|
||||||
HAS_TRITON_RMSNORM = True
|
HAS_TRITON_RMSNORM = True
|
||||||
except:
|
except:
|
||||||
print("you should install triton from https://github.com/openai/triton")
|
print("you should install triton from https://github.com/openai/triton")
|
||||||
HAS_TRITON_RMSNORM = False
|
HAS_TRITON_RMSNORM = False
|
||||||
|
|
||||||
|
if HAS_TRITON_RMSNORM:
|
||||||
def get_triton_rmsnorm_forward():
|
def get_triton_rmsnorm_forward():
|
||||||
if HAS_TRITON_RMSNORM:
|
|
||||||
|
|
||||||
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
|
||||||
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
return _triton_rmsnorm_forward
|
return _triton_rmsnorm_forward
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
|
|
Loading…
Reference in New Issue