2022-03-02 10:28:29 +00:00
from typing import List
import torch
2022-01-25 14:20:54 +00:00
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
2022-03-02 10:28:29 +00:00
from ._shard_grad_ophook import ShardGradHook
2022-03-01 10:17:01 +00:00
from ._shard_param_ophook import ShardParamHook
2022-01-25 14:20:54 +00:00
2022-03-02 10:28:29 +00:00
# Public API of this package. It must be named ``__all__`` (not ``all``) for
# ``from ... import *`` to honour it. Otherwise every public module-level name is
# star-exported, including a variable called ``all`` that shadows the builtin ``all()``.
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
2022-01-25 14:20:54 +00:00
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
2022-03-08 10:18:06 +00:00
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
2022-01-25 14:20:54 +00:00
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
return outputs
class PreBackwardFunction(torch.autograd.Function):
    """Identity autograd op whose backward runs a callback before passing gradients through.

    In forward, the op returns a detached view of the input tensor, so it appears as its own
    node in the autograd graph. In backward, it calls ``pre_backward_function(module)`` and
    then forwards the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        """Store ``module`` and the callback on ``ctx`` and return ``outputs`` detached.

        Args:
            ctx: autograd context.
            module: object handed to the callback during backward; its
                ``applied_pre_backward`` attribute is reset to ``False`` here.
            pre_backward_function: callable invoked as ``pre_backward_function(module)``
                during backward.
            outputs: the tensor to wrap.
        """
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        # This flag is not read in this class. Presumably it is consumed by hooks
        # elsewhere — NOTE(review): confirm.
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Invoke the stored callback, then pass gradients through unchanged.

        Returns:
            ``(None, None, *args)``: ``module`` and the callback get no gradient; the
            incoming gradient is forwarded as the gradient of ``outputs``.
        """
        # Previously the callback was stored but never called, so the hook never fired.
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
    """Identity autograd op whose backward runs a callback when gradients reach the wrapped tensor.

    In forward, the op returns a detached view of the input tensor, so it appears as its own
    node in the autograd graph. In backward, it calls the stored callback with ``module``
    and then forwards the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        """Store ``module`` and the callback on ``ctx`` and return ``output`` detached.

        Args:
            ctx: autograd context.
            module: object handed to the callback during backward.
            pre_backward_function: callable invoked as ``pre_backward_function(module)``
                during backward. The parameter name is historical and is kept for
                interface compatibility.
            output: the tensor to wrap.
        """
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """Invoke the stored callback, then pass gradients through unchanged.

        Args:
            *args: gradient w.r.t. this op's output, i.e. the activation grad of the next layer.

        Returns:
            ``(None, None, *args)``: ``module`` and the callback get no gradient; the
            incoming gradient is forwarded as the grad of the input activation.
        """
        # Previously the callback was stored but never called, so the hook never fired.
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args
2022-03-08 10:18:06 +00:00
def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
2022-01-25 14:20:54 +00:00
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
has_children = False
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name)
has_children = True
# Early return on modules with no parameters or buffers that
# are not in their children.
2022-03-08 10:18:06 +00:00
if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
2022-01-25 14:20:54 +00:00
# return if the module has not childern.
if has_children:
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_fwd_exec(submodule, *args)
def _post_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
2022-03-08 10:18:06 +00:00
2022-01-25 14:20:54 +00:00
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
2022-03-08 10:18:06 +00:00
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
2022-01-25 14:20:54 +00:00
def _post_backward_module_hook(submodule, inputs):
2022-03-08 10:18:06 +00:00
2022-01-25 14:20:54 +00:00
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
2022-03-08 10:18:06 +00:00
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
2022-01-25 14:20:54 +00:00