diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index 50a75a110..0b0b73820 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -72,6 +72,7 @@ class InsertPostInitMethodToModuleSubClasses(object): torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) self._pre_context_exec() + return self def __exit__(self, exc_type, exc_value, traceback):