# Mirror of https://github.com/hpcaitech/ColossalAI
# (Python, 120 lines, 4.3 KiB)
from typing import List
|
|
|
|
import torch
|
|
|
|
from ._base_ophook import BaseOpHook
|
|
from ._memtracer_ophook import MemTracerOpHook
|
|
from ._shard_grad_ophook import ShardGradHook
|
|
from ._shard_param_ophook import ShardParamHook
|
|
|
|
# Public API of this package: the names exported by ``from <package> import *``.
# (Must be the dunder ``__all__``; a plain ``all`` would shadow the builtin
# ``all()`` in this module and declare nothing.)
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
|
|
|
|
|
|
# apply torch.autograd.Function that calls a backward_function to tensors in output
|
|
def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
|
if type(outputs) is tuple:
|
|
touched_outputs = []
|
|
for output in outputs:
|
|
touched_output = _apply_to_tensors_only(module, functional,
|
|
backward_function, output)
|
|
touched_outputs.append(touched_output)
|
|
return tuple(touched_outputs)
|
|
elif type(outputs) is torch.Tensor:
|
|
return functional.apply(module, backward_function, outputs)
|
|
else:
|
|
return outputs
|
|
|
|
|
|
class PreBackwardFunction(torch.autograd.Function):
    """Identity autograd op that fires a callback when its gradient arrives.

    In the forward pass the tensor is returned detached, and autograd records
    this Function as its producer. In the backward pass
    ``pre_backward_function(module)`` is called first. The incoming gradient
    is then passed through unchanged.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        # This per-module flag is reset on every forward. Nothing in this file
        # reads it; presumably a hook elsewhere does. TODO confirm.
        module.applied_pre_backward = False
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        return outputs.detach()

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        # `module` and `pre_backward_function` get no gradient; the gradient
        # for `outputs` passes straight through.
        grads_for_non_tensor_inputs = (None, None)
        return grads_for_non_tensor_inputs + args
|
|
|
|
|
|
class PostBackwardFunction(torch.autograd.Function):
    """Identity autograd op that fires a callback when its gradient arrives.

    Note: despite its name, the ``pre_backward_function`` argument is the
    callback run in backward. The name is kept for interface compatibility.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        # A detached alias makes autograd treat this Function as the producer.
        detached = output.detach()
        return detached

    @staticmethod
    def backward(ctx, *args):
        """Run the stored callback, then pass the gradient(s) through unchanged.

        Args:
            *args: gradient(s) w.r.t. the forward ``output``, i.e. what flows
                back from the computation that consumed it.

        Returns:
            ``(None, None, *args)``. ``module`` and the callback get no
            gradient, and the incoming gradient(s) are passed on unchanged
            for ``output``.
        """
        callback = ctx.pre_backward_function
        callback(ctx.module)
        return (None, None, *args)
|
|
|
|
|
|
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = ""):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.

    Hooks are attached only to leaf modules (modules without children) that
    directly own at least one parameter or buffer. Each such module gets four
    hooks. Every op hook in ``ophook_list`` is invoked in list order:

    * a forward-pre hook calling ``hook.pre_fwd_exec(module, *args)``;
    * a forward hook calling ``hook.post_fwd_exec(module, *args)``;
    * a forward hook that wraps the output tensors in ``PreBackwardFunction``,
      so ``hook.pre_bwd_exec(module, inputs, output)`` runs when the gradient
      reaches the module's output;
    * a forward-pre hook that wraps the input tensors in
      ``PostBackwardFunction``, so ``hook.post_bwd_exec(module, inputs)`` runs
      when the gradient reaches the module's inputs. Per autograd semantics,
      this only fires for inputs that require grad.

    Args:
        module: root module to walk.
        ophook_list: op hooks to run. The list object itself is captured by
            the registered closures and read at call time. If ``None``,
            nothing is registered.
        name: name prefix; each child's name is appended on recursion. It is
            not otherwise used in this function.
    """
    assert isinstance(module, torch.nn.Module)

    # Recurse first. Any module that has children gets no hooks of its own.
    has_children = False
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)
        has_children = True

    # Only leaf modules that directly own parameters or buffers are hooked.
    owns_state = (next(module.parameters(recurse=False), None) is not None
                  or next(module.buffers(recurse=False), None) is not None)
    if has_children or not owns_state:
        return

    # With no hook list there is nothing to run. Registering the closures
    # anyway would make them iterate ``None`` and raise TypeError on the
    # first forward pass. An empty list is still registered, because the
    # caller may append hooks to it later.
    if ophook_list is None:
        return

    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        assert isinstance(submodule, torch.nn.Module)
        for hook in ophook_list:
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):
        # Registered as a forward hook. The returned (wrapped) output replaces
        # the module's output, so the callback runs once its gradient arrives.
        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction,
                                      _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):
        # Registered as a forward-pre hook. The returned (wrapped) inputs
        # replace the module's inputs, so the callback runs once gradients
        # have flowed back through the module to them.
        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction,
                                      _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
|