add a common util for hooks registered on parameter. (#292)

pull/394/head
Jiarui Fang 2022-03-02 14:38:22 +08:00 committed by Frank Lee
parent f867365aba
commit 8d653af408
3 changed files with 120 additions and 0 deletions

View File

@ -0,0 +1,2 @@
from ._param_hookmgr import BaseParamHookMgr
__all__ = ["BaseParamHookMgr"]

View File

@ -0,0 +1,32 @@
from typing import Callable, List
import torch
import functools
class BaseParamHookMgr(object):
def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
r"""
register backward hook on every parameters of module
"""
self._param_list = param_list
self._hook_list = []
def register_backward_hooks(self, hook_call : Callable) -> None:
r"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
The hook should have the following signature:
```
hook(param, grad) -> Tensor or None
```
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self._param_list:
if p.requires_grad and not hasattr(p, '_base_param_hook'):
handle = p.register_hook(functools.partial(hook_call, p))
p._base_param_hook = handle
def remove_hooks(self):
for p in self._param_list:
if p.requires_grad and hasattr(p, '_base_param_hook'):
p._base_param_hook.remove()

View File

@ -0,0 +1,86 @@
import pytest
from colossalai.engine.paramhooks import BaseParamHookMgr
from torch import nn
import torch
import torch.nn.functional as F
import copy
class SubNet(nn.Module):
def __init__(self, out_features) -> None:
super().__init__()
self.bias = nn.Parameter(torch.zeros(out_features))
def forward(self, x, weight):
return F.linear(x, weight, self.bias)
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.sub_fc = SubNet(5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.sub_fc(x, self.fc1.weight)
x = self.fc1(x)
x = self.fc2(x)
return x
def net_data():
return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def test_base_param_hook():
torch.manual_seed(0)
model = Net(checkpoint=True).cuda()
model.train()
inputs = net_data()
def run_model(model, inputs, use_param_hook = False):
if use_param_hook:
class HooKWrapper:
def __init__(self) -> None:
self.hook_triggered_times = 0
def wrapper_func(self):
def hook(param, grad) -> torch.Tensor or None:
self.hook_triggered_times += 1
return grad
return hook
hookwrapper = HooKWrapper()
param_list = [p for p in model.parameters()]
hook_mgr = BaseParamHookMgr(param_list)
hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
model.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
y = model(*inputs)
loss = y.sum()
loss.backward()
if use_param_hook:
hook_mgr.remove_hooks()
return hookwrapper.hook_triggered_times
model_copy = copy.deepcopy(model)
run_model(model, inputs, False)
ret2 = run_model(model_copy, inputs, True)
# Make sure param hook has only be fired once in case of parameter sharing
assert ret2 == len(list(model.parameters()))
for p, p_copy in zip(model.parameters(), model_copy.parameters()):
assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
if __name__ == '__main__':
test_base_param_hook()