You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/tensor/param_op_hook.py

155 lines
4.9 KiB

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
class ColoParamOpHook(ABC):
    """Abstract interface for hooks fired around operations on ColoParameter operands.

    A concrete hook must subclass this class and override all four callbacks:
    ``pre_forward``, ``post_forward``, ``pre_backward`` and ``post_backward``.
    Each callback is handed the list of parameters involved in the operation.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Callback for the pre-forward stage; receives the op's parameters."""

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Callback for the post-forward stage; receives the op's parameters."""

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Callback for the pre-backward stage; receives the op's parameters."""

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Callback for the post-backward stage; receives the op's parameters."""
class ColoParamOpHookManager:
    """Static registry of the param op hooks that are currently in effect.

    Every member is a static method and the active hooks are stored in the
    class attribute ``hooks``. User code is expected to call only
    ``use_hooks(*hooks)``.
    """

    hooks: Tuple[ColoParamOpHook, ...] = ()

    @staticmethod
    @contextmanager
    def use_hooks(*hooks: ColoParamOpHook):
        """Install ``hooks`` as the active param op hooks for the ``with`` body.

        The hooks that were active on entry are put back on exit, even when
        the body raises, so calls may be nested.

        Example:
            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            >>>     do_something()
            >>>     with ColoParamOpHookManager.use_hooks():
            >>>         # no hooks are active inside this block
            >>>         do_something()
        """
        previous = ColoParamOpHookManager.hooks
        ColoParamOpHookManager.hooks = hooks
        try:
            yield
        finally:
            ColoParamOpHookManager.hooks = previous

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
        """Call ``pre_forward(params)`` on every active hook, in order."""
        for active_hook in ColoParamOpHookManager.hooks:
            active_hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
        """Call ``post_forward(params)`` on every active hook, in order."""
        for active_hook in ColoParamOpHookManager.hooks:
            active_hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
        """Call ``pre_backward(params)`` on every active hook, in order."""
        for active_hook in ColoParamOpHookManager.hooks:
            active_hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
        """Call ``post_backward(params)`` on every active hook, in order."""
        for active_hook in ColoParamOpHookManager.hooks:
            active_hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
        """Fire the pre-forward hooks and route grad-carrying args through ``PreFwdPostBwd``.

        Args:
            params: Parameters taking part in the operation.
            *args: The operation's arguments, possibly nested containers.

        Returns:
            The arguments rebuilt into their original nested structure, with
            the grad-carrying tensors replaced by the outputs of
            ``PreFwdPostBwd.apply``.
        """
        ColoParamOpHookManager._trigger_pre_forward(params)
        # An autograd function only recognizes torch.Tensor inputs, so the
        # (possibly nested) arguments have to be flattened first.
        # If any input requires grad, every output is treated as requiring
        # grad and receives a grad fn, even when its own input did not.
        # Hence only the tensors that require grad are extracted into a flat
        # list, passed through the function, and merged back afterwards.
        tracked, untracked, flags, spec = _flatten_grad_args(args)
        wrapped = PreFwdPostBwd.apply(params, *tracked)
        return _merge_args(wrapped, untracked, flags, spec)

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
        """Fire the post-forward hooks and pass ``arg`` through ``PostFwdPreBwd``.

        Args:
            params: Parameters taking part in the operation.
            arg: The operation's output.

        Returns:
            The result of ``PostFwdPreBwd.apply(params, arg)``.
        """
        ColoParamOpHookManager._trigger_post_forward(params)
        return PostFwdPreBwd.apply(params, arg)

    @staticmethod
    def has_hook() -> bool:
        """Return ``True`` when at least one hook is currently active."""
        return bool(ColoParamOpHookManager.hooks)
class PreFwdPostBwd(torch.autograd.Function):
    """Autograd function that returns its inputs unchanged and fires post-backward hooks.

    ``forward`` hands ``args`` back as-is and remembers ``params`` on the
    context; ``backward`` triggers the post-backward hooks with those
    parameters and forwards the incoming gradients.
    """

    @staticmethod
    def forward(ctx, params, *args):
        """Store ``params`` on ``ctx`` and return ``args`` untouched."""
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        """Trigger the post-backward hooks, then pass the gradients through.

        The leading ``None`` is the gradient slot for the ``params`` input.
        """
        ColoParamOpHookManager._trigger_post_backward(ctx.params)
        return (None, *grads)
class PostFwdPreBwd(torch.autograd.Function):
    """Autograd function that returns its input unchanged and fires pre-backward hooks.

    ``forward`` hands ``args`` back as-is and remembers ``params`` on the
    context; ``backward`` triggers the pre-backward hooks with those
    parameters and forwards the incoming gradients.
    """

    @staticmethod
    def forward(ctx, params, args):
        """Store ``params`` on ``ctx`` and return ``args`` untouched."""
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        """Trigger the pre-backward hooks, then pass the gradients through.

        The leading ``None`` is the gradient slot for the ``params`` input.
        """
        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
        return (None, *grads)
def _is_grad_tensor(obj) -> bool:
if torch.is_tensor(obj):
if obj.grad_fn is not None or obj.requires_grad:
return True
return False
def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
    """Flatten ``args`` and split the leaves by whether they carry gradients.

    Args:
        args: A possibly nested structure of arguments.

    Returns:
        A 4-tuple ``(grad_args, other_args, grad_flags, spec)``:
        the leaves for which ``_is_grad_tensor`` is true, the remaining
        leaves, one boolean per leaf (in flattened order) telling which of
        the two lists it went to, and the tree spec produced by
        ``tree_flatten``.

    Raises:
        AssertionError: If no leaf is a grad tensor.
    """
    leaves, spec = tree_flatten(args)
    grad_flags = [_is_grad_tensor(leaf) for leaf in leaves]
    grad_args = [leaf for leaf, is_grad in zip(leaves, grad_flags) if is_grad]
    other_args = [leaf for leaf, is_grad in zip(leaves, grad_flags) if not is_grad]
    assert len(grad_args) > 0
    return grad_args, other_args, grad_flags, spec
def _merge_args(grad_args, other_args, grad_flags, spec):
grad_iter = iter(grad_args)
other_iter = iter(other_args)
flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
return tree_unflatten(flat_args, spec)