ColossalAI/colossalai/elixir/kernels/__init__.py

14 lines
397 B
Python

import torch.nn.functional as F
# Lookup table from a stock torch functional to the callable registered for
# it. It starts out as an identity mapping (each function maps to itself);
# register_fused_layer_norm() below overwrites the F.layer_norm entry when
# the fused implementation can be imported.
fused_torch_functions = {fn: fn for fn in (F.layer_norm,)}
def register_fused_layer_norm():
    """Try to replace the ``F.layer_norm`` entry with the fused kernel.

    Imports ``ln_func`` from the sibling ``layernorm`` module and, on
    success, stores it in ``fused_torch_functions`` under the
    ``F.layer_norm`` key. If the import fails with an ordinary exception,
    the table is left unchanged and a hint is printed instead of raising.

    Returns:
        None. The outcome is only reported on stdout.
    """
    try:
        # Imported lazily so that merely importing this package does not
        # require the fused kernel to be available.
        from .layernorm import ln_func
    except Exception:
        # Best-effort registration: any regular import-time failure falls
        # back to the existing entry. ``Exception`` (rather than a bare
        # ``except``) lets KeyboardInterrupt / SystemExit propagate.
        print('Cannot import fused layer norm, please install apex from source.')
    else:
        # Only reached when the import succeeded, so the table is never
        # left half-updated.
        fused_torch_functions[F.layer_norm] = ln_func
        print('Register fused layer norm successfully from apex.')