diff --git a/internlm/model/moe.py b/internlm/model/moe.py index e1e2edf..b09b319 100644 --- a/internlm/model/moe.py +++ b/internlm/model/moe.py @@ -1,5 +1,6 @@ import torch +import internlm.moe # noqa # pylint: disable=W0611 from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.linear import FeedForward diff --git a/internlm/moe/__init__.py b/internlm/moe/__init__.py index fdda02c..343fffc 100644 --- a/internlm/moe/__init__.py +++ b/internlm/moe/__init__.py @@ -1,3 +1,3 @@ -from internlm.moe.sharded_moe import GShardMOELayer +from .sharded_moe import GShardMOELayer __all__ = ["GShardMOELayer"]