import functools
from typing import Callable, List

import torch


class BaseParamHookMgr(object):

    def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
        r"""
        Registers a backward hook on every parameter of the module.
        """
        self._param_list = param_list
        self._hook_list = []

    def register_backward_hooks(self, hook_call: Callable) -> None:
        r"""
        ``hook_call`` will be called every time a gradient with respect to a
        parameter in ``self._param_list`` is computed.

        The hook should have the following signature:
        ```
        hook(param, grad) -> Tensor or None
        ```
        """
        if not torch.is_grad_enabled():
            return  # don't register grad hooks if grad isn't enabled

        for p in self._param_list:
            if p.requires_grad and not hasattr(p, '_base_param_hook'):
                # Bind the parameter so the hook receives (param, grad)
                # instead of the bare (grad) that register_hook passes.
                handle = p.register_hook(functools.partial(hook_call, p))
                p._base_param_hook = handle

    def remove_hooks(self) -> None:
        for p in self._param_list:
            if p.requires_grad and hasattr(p, '_base_param_hook'):
                p._base_param_hook.remove()
                # Drop the marker attribute so hooks can be registered again.
                del p._base_param_hook
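

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical example of a hook matching the
# ``hook(param, grad) -> Tensor or None`` signature and of the register/remove
# life cycle. The ``scale_grad`` hook and the Linear model below are made-up
# illustrations, not APIs defined elsewhere in this codebase.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    def scale_grad(param: torch.nn.Parameter, grad: torch.Tensor) -> torch.Tensor:
        # Returning a tensor replaces the gradient that autograd accumulates.
        return grad * 0.5

    model = torch.nn.Linear(4, 2)
    mgr = BaseParamHookMgr(list(model.parameters()))
    mgr.register_backward_hooks(scale_grad)

    model(torch.randn(3, 4)).sum().backward()  # gradients pass through scale_grad

    mgr.remove_hooks()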