mirror of https://github.com/hpcaitech/ColossalAI
[zero] added error message to handle on-the-fly import of torch Module class (#1135)
* [zero] added error message to handle on-the-fly import of torch Module class * polish codepull/1138/head
parent
e4f555f29a
commit
73ad05fc8c
|
@ -80,6 +80,10 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
|||
torch.set_default_dtype(self._old_default_dtype)
|
||||
|
||||
def _disable_class(cls):
|
||||
if not hasattr(cls, '_old_init'):
|
||||
raise AttributeError(
|
||||
f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context."
|
||||
)
|
||||
cls.__init__ = cls._old_init
|
||||
|
||||
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
||||
|
|
Loading…
Reference in New Issue