diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index 0b0b73820..ecc0cdb5a 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -80,6 +80,10 @@ class InsertPostInitMethodToModuleSubClasses(object): torch.set_default_dtype(self._old_default_dtype) def _disable_class(cls): + if not hasattr(cls, '_old_init'): + raise AttributeError( + f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." + ) cls.__init__ = cls._old_init # Replace .__init__() for all existing subclasses of torch.nn.Module