mirror of https://github.com/hpcaitech/ColossalAI
[engine] fixed empty op hook check (#1096)
* [engine] fixed empty op hook check * polish codepull/1099/head
parent
14e5b11d7f
commit
7f2d2b2b5b
|
@ -92,7 +92,10 @@ class Engine:
|
|||
self._schedule = NonPipelineSchedule()
|
||||
if self.uses_pipeline:
|
||||
self._schedule.pre_processing(self)
|
||||
register_ophooks_recursively(self._model, self._ophook_list)
|
||||
|
||||
#register hook if any
|
||||
if len(self._ophook_list) > 0:
|
||||
register_ophooks_recursively(self._model, self._ophook_list)
|
||||
|
||||
@property
|
||||
def ophooks(self):
|
||||
|
|
|
@ -85,11 +85,15 @@ class PostBackwardFunction(torch.autograd.Function):
|
|||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook] = None,
|
||||
ophook_list: List[BaseOpHook],
|
||||
name: str = "",
|
||||
filter_fn: Optional[Callable] = None):
|
||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert isinstance(ophook_list, (list, tuple))
|
||||
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
|
@ -103,10 +107,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
|
|||
if filter_fn is not None and filter_fn(module):
|
||||
return
|
||||
|
||||
if ophook_list is not None:
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
def _pre_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
|
|
Loading…
Reference in New Issue