diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 746f43e..e5d6328 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -8,6 +8,7 @@ import torch from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.train.utils import create_param_groups from internlm.utils.storage_manager import SingletonMeta OSS_NAME = os.environ.get("OSS_BUCKET_NAME")