mirror of https://github.com/hpcaitech/ColossalAI
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Tuple

import torch


class ParamOpHook(ABC):
    """Hook whose callbacks fire around operations that use parameters.

    ``pre_forward``/``post_forward`` run before/after parameters are used in
    the forward pass; ``pre_backward``/``post_backward`` run before/after
    their gradients are produced in the backward pass.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


class _ParamOpHookWrapper:
    # Process-global registry of the hooks currently in effect.
    hooks: Tuple[ParamOpHook, ...] = tuple()


class PreFwdPostBwd(torch.autograd.Function):
    """Identity function that triggers ``pre_forward`` when its forward runs
    and ``post_backward`` when its backward runs."""

    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_forward(ctx.params)
        # Unwrap a single argument so callers get a tensor back, not a 1-tuple.
        if len(args) == 1:
            return args[0]
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_backward(ctx.params)
        # ``params`` gets no gradient; the remaining grads pass through as-is.
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):
    """Identity function that triggers ``post_forward`` when its forward runs
    and ``pre_backward`` when its backward runs."""

    @staticmethod
    def forward(ctx, params, args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_backward(ctx.params)
        # ``params`` gets no gradient; the remaining grads pass through as-is.
        return (None,) + grads


@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    """Install ``hooks`` for the duration of the ``with`` block, restoring
    the previously active hooks on exit, even if an exception is raised."""
    try:
        old_param_op_hooks = _ParamOpHookWrapper.hooks
        _ParamOpHookWrapper.hooks = hooks
        yield
    finally:
        _ParamOpHookWrapper.hooks = old_param_op_hooks
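

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal sketch of how these pieces might fit together. ``DemoHook``, the
# tensors, and the manual ``apply`` calls below are assumptions for the demo;
# in ColossalAI the autograd functions are normally applied for you around
# parameter operations rather than called by hand.
if __name__ == "__main__":

    class DemoHook(ParamOpHook):
        def pre_forward(self, params: List[torch.Tensor]) -> None:
            print(f"pre_forward: {len(params)} param(s)")

        def post_forward(self, params: List[torch.Tensor]) -> None:
            print(f"post_forward: {len(params)} param(s)")

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            print(f"pre_backward: {len(params)} param(s)")

        def post_backward(self, params: List[torch.Tensor]) -> None:
            print(f"post_backward: {len(params)} param(s)")

    weight = torch.randn(4, 4, requires_grad=True)
    x = torch.randn(2, 4)

    with use_param_op_hooks(DemoHook()):
        w = PreFwdPostBwd.apply([weight], weight)   # fires pre_forward
        out = PostFwdPreBwd.apply([weight], x @ w)  # fires post_forward
        out.sum().backward()                        # fires pre_backward, then post_backward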