[engine] fixed empty op hook check (#1096)

* [engine] fixed empty op hook check

* polish code
pull/1099/head
Frank Lee 2022-06-10 17:27:27 +08:00 committed by GitHub
parent 14e5b11d7f
commit 7f2d2b2b5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 6 deletions

View File

@ -92,7 +92,10 @@ class Engine:
self._schedule = NonPipelineSchedule()
if self.uses_pipeline:
self._schedule.pre_processing(self)
register_ophooks_recursively(self._model, self._ophook_list)
#register hook if any
if len(self._ophook_list) > 0:
register_ophooks_recursively(self._model, self._ophook_list)
@property
def ophooks(self):

View File

@ -85,11 +85,15 @@ class PostBackwardFunction(torch.autograd.Function):
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
ophook_list: List[BaseOpHook],
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
assert isinstance(ophook_list, (list, tuple))
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
# Add hooks for submodules
for child_name, child in module.named_children():
@ -103,10 +107,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
if filter_fn is not None and filter_fn(module):
return
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)