from typing import List, Callable, Optional
import torch
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook

__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]


# Apply a torch.autograd.Function that calls ``backward_function`` to every tensor in ``outputs``.
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):
    """Identity-like Function whose backward pass runs ``pre_backward_function(module)``
    right before gradients flow back through the module's outputs."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):
    """Identity-like Function whose backward pass runs ``pre_backward_function(module)``;
    it is applied to a module's inputs, so the callback fires after that module's backward has run."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            *args: activation gradients from the next layer.

        Returns:
            Gradients of the input activations.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: Optional[List[BaseOpHook]] = None,
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks on every submodule of ``module``, for both the FWD and BWD passes."""
    assert isinstance(module, torch.nn.Module)

    # Add hooks for submodules first.
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
        return

    # Early return on filtered modules.
    if filter_fn is not None and filter_fn(module):
        return

    if ophook_list is not None:
        for hook in ophook_list:
            assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward-pass triggers are attached during the forward pass: the forward hook wraps the
    # module's outputs in PreBackwardFunction (its backward fires just before this module's backward),
    # while the forward pre-hook wraps the inputs in PostBackwardFunction (its backward fires after
    # this module's backward has finished).
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
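

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's public API).
# It assumes a BaseOpHook subclass only needs the exec methods invoked above;
# ``post_iter`` is implemented as a no-op in case the base class declares it
# abstract. The toy model and the _CountingOpHook class are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class _CountingOpHook(BaseOpHook):
        """Hypothetical hook that counts how often each phase fires."""

        def __init__(self):
            super().__init__()
            self.counts = {"pre_fwd": 0, "post_fwd": 0, "pre_bwd": 0, "post_bwd": 0}

        def pre_fwd_exec(self, module, *args):
            self.counts["pre_fwd"] += 1

        def post_fwd_exec(self, module, *args):
            self.counts["post_fwd"] += 1

        def pre_bwd_exec(self, module, inputs, output):
            self.counts["pre_bwd"] += 1

        def post_bwd_exec(self, module, inputs):
            self.counts["post_bwd"] += 1

        def post_iter(self):
            pass

    # Hooks are registered on the two Linear layers; the parameter-less ReLU
    # and the top-level Sequential are skipped by the early returns above.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    hook = _CountingOpHook()
    register_ophooks_recursively(model, [hook])

    out = model(torch.randn(2, 16))
    out.sum().backward()
    print(hook.counts)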