from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple

import torch
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten


class ColoParamOpHook(ABC):
    """
    Hook which is triggered by each operation when operands contain ColoParameter.
    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
    ``post_forward``, ``pre_backward`` and ``post_backward``.
    Each of these four methods receives a list of ColoParameter as its only argument.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Called by ``ColoParamOpHookManager.pre_op`` before the op's inputs are wrapped."""

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Called by ``ColoParamOpHookManager.post_op`` before the op's output is wrapped."""

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Called from ``PostFwdPreBwd.backward`` during the backward pass."""

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Called from ``PreFwdPostBwd.backward`` during the backward pass."""


class ColoParamOpHookManager:
    """
    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """

    # Hooks that are currently active; swapped in and out by ``use_hooks``.
    hooks: Tuple[ColoParamOpHook, ...] = ()

    @staticmethod
    @contextmanager
    def use_hooks(*hooks: ColoParamOpHook):
        """Change the param op hooks you use. Nested calling is allowed.

        The previously active hooks are restored when the ``with`` block exits,
        including when it exits through an exception.

        Example:
            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            >>>     do_something()
            >>>     with ColoParamOpHookManager.use_hooks():
            >>>         # clear hooks
            >>>         do_something()
        """
        # Save the outer hooks before entering ``try`` so that ``finally``
        # always has a bound name to restore from.
        previous_hooks = ColoParamOpHookManager.hooks
        try:
            ColoParamOpHookManager.hooks = hooks
            yield
        finally:
            ColoParamOpHookManager.hooks = previous_hooks

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
        """Call ``pre_forward(params)`` on every active hook, in registration order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
        """Call ``post_forward(params)`` on every active hook, in registration order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
        """Call ``pre_backward(params)`` on every active hook, in registration order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
        """Call ``post_backward(params)`` on every active hook, in registration order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> Tuple[Any, ...]:
        """Fire the pre-forward hooks and route grad-requiring inputs through ``PreFwdPostBwd``.

        Args:
            params: Parameters handed to every hook callback.
            *args: The op's positional arguments (arbitrarily nested containers are
                flattened with ``tree_flatten``).

        Returns:
            A tuple with the same structure as ``args`` in which every tensor that
            requires grad has been replaced by the corresponding output of
            ``PreFwdPostBwd.apply``; all other leaves are passed through untouched.

        Raises:
            AssertionError: If no leaf in ``args`` is a tensor that requires grad.
        """
        ColoParamOpHookManager._trigger_pre_forward(params)
        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
        # if one of the input requires grad, all the output will be treated as requires grad
        # and will have grad fn even the corresponding input does not require grad
        # we have to extract tensors requiring grad into flat list and then merge them back
        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
        return _merge_args(new_grad_args, other_args, grad_flags, spec)

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
        """Fire the post-forward hooks and route the op's output through ``PostFwdPreBwd``.

        Args:
            params: Parameters handed to every hook callback.
            arg: The op's output.

        Returns:
            The result of ``PostFwdPreBwd.apply(params, arg)``.
        """
        ColoParamOpHookManager._trigger_post_forward(params)
        return PostFwdPreBwd.apply(params, arg)

    @staticmethod
    def has_hook() -> bool:
        """Return ``True`` when at least one hook is currently active."""
        return bool(ColoParamOpHookManager.hooks)


class PreFwdPostBwd(torch.autograd.Function):
    """Identity autograd function whose backward fires the post-backward hooks."""

    @staticmethod
    def forward(ctx, params, *args):
        """Remember ``params`` on ``ctx`` and return ``args`` unchanged."""
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        """Fire the post-backward hooks and pass the incoming grads straight through."""
        ColoParamOpHookManager._trigger_post_backward(ctx.params)
        # ``params`` is the first forward input and receives no gradient.
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):
    """Identity autograd function whose backward fires the pre-backward hooks."""

    @staticmethod
    def forward(ctx, params, args):
        """Remember ``params`` on ``ctx`` and return ``args`` unchanged."""
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        """Fire the pre-backward hooks and pass the incoming grads straight through."""
        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
        # ``params`` is the first forward input and receives no gradient.
        return (None,) + grads


def _is_grad_tensor(obj) -> bool:
    """Return ``True`` if ``obj`` is a tensor that has a ``grad_fn`` or requires grad."""
    return torch.is_tensor(obj) and (obj.grad_fn is not None or obj.requires_grad)


def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
    """Flatten ``args`` and split its leaves by whether they take part in autograd.

    Returns:
        ``(grad_args, other_args, grad_flags, spec)`` where ``grad_flags[i]`` tells
        whether the i-th flattened leaf went into ``grad_args`` (``True``) or
        ``other_args`` (``False``), and ``spec`` is the tree spec needed to rebuild
        the original structure.

    Raises:
        AssertionError: If no leaf is a tensor that requires grad.
    """
    flat_args, spec = tree_flatten(args)
    grad_args = []
    other_args = []
    grad_flags = []
    for arg in flat_args:
        flag = _is_grad_tensor(arg)
        grad_flags.append(flag)
        if flag:
            grad_args.append(arg)
        else:
            other_args.append(arg)
    assert len(grad_args) > 0, "expected at least one input tensor that requires grad"
    return grad_args, other_args, grad_flags, spec


def _merge_args(grad_args, other_args, grad_flags, spec):
    """Inverse of ``_flatten_grad_args``: interleave the two lists and unflatten.

    ``grad_flags`` decides, leaf by leaf, whether the next value is taken from
    ``grad_args`` or from ``other_args``; the merged flat list is then rebuilt
    into the original structure with ``tree_unflatten``.
    """
    grad_iter = iter(grad_args)
    other_iter = iter(other_args)
    flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
    return tree_unflatten(flat_args, spec)