diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 374e43537..57b8fb7b2 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,8 +1,8 @@ -from .option import _set_jit_fusion_options +from .option import set_jit_fusion_options from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference from .bias_gelu import bias_gelu_impl -_set_jit_fusion_options() __all__ = [ "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", + "set_jit_fusion_options" ] diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 73c0b7b57..c21789726 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -3,8 +3,11 @@ import torch JIT_OPTIONS_SET = False -def _set_jit_fusion_options(): - """Set PyTorch JIT layer fusion options.""" +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options. + """ + # LSG: the latest pytorch and CUDA versions may not support + # the following jit settings global JIT_OPTIONS_SET if JIT_OPTIONS_SET == False: # flags required to enable jit fusion kernels