[zero] added error message to handle on-the-fly import of torch Module class (#1135)

* [zero] added error message to handle on-the-fly import of torch Module class

* polish code
pull/1138/head
Frank Lee 2022-06-20 11:24:27 +08:00 committed by GitHub
parent e4f555f29a
commit 73ad05fc8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -80,6 +80,10 @@ class InsertPostInitMethodToModuleSubClasses(object):
torch.set_default_dtype(self._old_default_dtype)
def _disable_class(cls):
if not hasattr(cls, '_old_init'):
raise AttributeError(
f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context."
)
cls.__init__ = cls._old_init
# Replace .__init__() for all existing subclasses of torch.nn.Module