ColossalAI/colossalai/engine/ophooks/__init__.py

120 lines
4.1 KiB
Python
Raw Normal View History

from typing import List, Callable, Optional
import torch
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
module.applied_pre_backward = False
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
output = output.detach()
ctx.pre_backward_function = pre_backward_function
return output
@staticmethod
def backward(ctx, *args):
"""
Args:
activation_grad of the next layer.
Returns:
grad of the input activation.
"""
ctx.pre_backward_function(ctx.module)
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
# Add hooks for submodules
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
# return from flitered module
if filter_fn is not None and filter_fn(module):
return
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_fwd_exec(submodule, *args)
def _post_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
def _post_backward_module_hook(submodule, inputs):
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
module.register_forward_pre_hook(_pre_forward_module_hook)
module.register_forward_hook(_post_forward_module_hook)
module.register_forward_hook(_pre_backward_module_hook)
module.register_forward_pre_hook(_post_backward_module_hook)