diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 29f63c430..074b9d0cc 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -92,7 +92,10 @@ class Engine: self._schedule = NonPipelineSchedule() if self.uses_pipeline: self._schedule.pre_processing(self) - register_ophooks_recursively(self._model, self._ophook_list) + + #register hook if any + if len(self._ophook_list) > 0: + register_ophooks_recursively(self._model, self._ophook_list) @property def ophooks(self): diff --git a/colossalai/engine/ophooks/utils.py b/colossalai/engine/ophooks/utils.py index 26d485657..a0ad50bfe 100644 --- a/colossalai/engine/ophooks/utils.py +++ b/colossalai/engine/ophooks/utils.py @@ -85,11 +85,15 @@ class PostBackwardFunction(torch.autograd.Function): def register_ophooks_recursively(module: torch.nn.Module, - ophook_list: List[BaseOpHook] = None, + ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None): r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" assert isinstance(module, torch.nn.Module) + assert isinstance(ophook_list, (list, tuple)) + assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' + for hook in ophook_list: + assert (isinstance(hook, BaseOpHook)) # Add hooks for submodules for child_name, child in module.named_children(): @@ -103,10 +107,6 @@ def register_ophooks_recursively(module: torch.nn.Module, if filter_fn is not None and filter_fn(module): return - if ophook_list is not None: - for hook in ophook_list: - assert (isinstance(hook, BaseOpHook)) - def _pre_forward_module_hook(submodule, *args): for hook in ophook_list: assert isinstance(submodule, torch.nn.Module)