mirror of https://github.com/hpcaitech/ColossalAI
[zero] added error message to handle on-the-fly import of torch Module class (#1135)
* [zero] added error message to handle on-the-fly import of torch Module class * polish codepull/1138/head
parent
e4f555f29a
commit
73ad05fc8c
|
@ -80,6 +80,10 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||||
torch.set_default_dtype(self._old_default_dtype)
|
torch.set_default_dtype(self._old_default_dtype)
|
||||||
|
|
||||||
def _disable_class(cls):
|
def _disable_class(cls):
|
||||||
|
if not hasattr(cls, '_old_init'):
|
||||||
|
raise AttributeError(
|
||||||
|
f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context."
|
||||||
|
)
|
||||||
cls.__init__ = cls._old_init
|
cls.__init__ = cls._old_init
|
||||||
|
|
||||||
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
# Replace .__init__() for all existing subclasses of torch.nn.Module
|
||||||
|
|
Loading…
Reference in New Issue