mirror of https://github.com/hpcaitech/ColossalAI
[engine] fixed empty op hook check (#1096)
* [engine] fixed empty op hook check * polish codepull/1099/head
parent
14e5b11d7f
commit
7f2d2b2b5b
|
@ -92,7 +92,10 @@ class Engine:
|
||||||
self._schedule = NonPipelineSchedule()
|
self._schedule = NonPipelineSchedule()
|
||||||
if self.uses_pipeline:
|
if self.uses_pipeline:
|
||||||
self._schedule.pre_processing(self)
|
self._schedule.pre_processing(self)
|
||||||
register_ophooks_recursively(self._model, self._ophook_list)
|
|
||||||
|
#register hook if any
|
||||||
|
if len(self._ophook_list) > 0:
|
||||||
|
register_ophooks_recursively(self._model, self._ophook_list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ophooks(self):
|
def ophooks(self):
|
||||||
|
|
|
@ -85,11 +85,15 @@ class PostBackwardFunction(torch.autograd.Function):
|
||||||
|
|
||||||
|
|
||||||
def register_ophooks_recursively(module: torch.nn.Module,
|
def register_ophooks_recursively(module: torch.nn.Module,
|
||||||
ophook_list: List[BaseOpHook] = None,
|
ophook_list: List[BaseOpHook],
|
||||||
name: str = "",
|
name: str = "",
|
||||||
filter_fn: Optional[Callable] = None):
|
filter_fn: Optional[Callable] = None):
|
||||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||||
assert isinstance(module, torch.nn.Module)
|
assert isinstance(module, torch.nn.Module)
|
||||||
|
assert isinstance(ophook_list, (list, tuple))
|
||||||
|
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
|
||||||
|
for hook in ophook_list:
|
||||||
|
assert (isinstance(hook, BaseOpHook))
|
||||||
|
|
||||||
# Add hooks for submodules
|
# Add hooks for submodules
|
||||||
for child_name, child in module.named_children():
|
for child_name, child in module.named_children():
|
||||||
|
@ -103,10 +107,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
|
||||||
if filter_fn is not None and filter_fn(module):
|
if filter_fn is not None and filter_fn(module):
|
||||||
return
|
return
|
||||||
|
|
||||||
if ophook_list is not None:
|
|
||||||
for hook in ophook_list:
|
|
||||||
assert (isinstance(hook, BaseOpHook))
|
|
||||||
|
|
||||||
def _pre_forward_module_hook(submodule, *args):
|
def _pre_forward_module_hook(submodule, *args):
|
||||||
for hook in ophook_list:
|
for hook in ophook_list:
|
||||||
assert isinstance(submodule, torch.nn.Module)
|
assert isinstance(submodule, torch.nn.Module)
|
||||||
|
|
Loading…
Reference in New Issue