fixed jit default setting (#154)

pull/156/head
Frank Lee 2022-01-18 13:37:20 +08:00 committed by GitHub
parent a1da3900c8
commit f3802d6b06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -1,8 +1,8 @@
from .option import _set_jit_fusion_options
from .option import set_jit_fusion_options
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_gelu import bias_gelu_impl
_set_jit_fusion_options()
__all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
"set_jit_fusion_options"
]

View File

@ -3,8 +3,11 @@ import torch
JIT_OPTIONS_SET = False
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.
"""
# LSG: the latest pytorch and CUDA versions may not support
# the following jit settings
global JIT_OPTIONS_SET
if JIT_OPTIONS_SET == False:
# flags required to enable jit fusion kernels