mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
1.7 KiB
72 lines
1.7 KiB
3 years ago
|
import torch
|
||
|
from contextlib import contextmanager
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import List, Tuple
|
||
|
|
||
|
|
||
|
class ParamOpHook(ABC):
    """Interface for callbacks that observe parameters at four op stages.

    A concrete hook must override every stage below. Each stage receives the
    parameter list handed to the autograd functions in this module and
    returns nothing.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Pre-forward stage; triggered from ``PreFwdPostBwd.forward``."""

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Post-forward stage; triggered from ``PostFwdPreBwd.forward``."""

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Pre-backward stage; triggered from ``PostFwdPreBwd.backward``."""

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Post-backward stage; triggered from ``PreFwdPostBwd.backward``."""
|
||
|
|
||
|
|
||
|
class _ParamOpHookWrapper:
    """Module-private holder for the currently active ``ParamOpHook`` objects.

    ``use_param_op_hooks`` swaps ``hooks`` in and out, and the autograd
    functions in this module read it each time they fire.
    """

    # Starts empty: no hooks fire until something (e.g. use_param_op_hooks)
    # installs a tuple of hooks here.
    hooks: Tuple[ParamOpHook, ...] = ()
|
||
|
|
||
|
|
||
|
class PreFwdPostBwd(torch.autograd.Function):
    """Identity autograd function that runs ``pre_forward`` / ``post_backward`` hooks.

    Forward: every active hook's ``pre_forward`` is called with ``params``,
    then the remaining arguments are returned unchanged. A lone argument is
    returned bare rather than as a 1-tuple.
    Backward: every active hook's ``post_backward`` is called with the same
    ``params``, and the incoming gradients are passed straight through.
    """

    @staticmethod
    def forward(ctx, params, *args):
        # Keep params on ctx so backward can hand the same list to the hooks.
        ctx.params = params
        active_hooks = _ParamOpHookWrapper.hooks
        for hook in active_hooks:
            hook.pre_forward(params)
        return args[0] if len(args) == 1 else args

    @staticmethod
    def backward(ctx, *grads):
        active_hooks = _ParamOpHookWrapper.hooks
        for hook in active_hooks:
            hook.post_backward(ctx.params)
        # The leading None is the (absent) gradient for ``params``; the other
        # gradients line up with ``*args`` and are forwarded untouched.
        return (None, *grads)
|
||
|
|
||
|
|
||
|
class PostFwdPreBwd(torch.autograd.Function):
    """Identity autograd function that runs ``post_forward`` / ``pre_backward`` hooks.

    Forward: every active hook's ``post_forward`` is called with ``params``,
    then ``args`` is returned as-is. Unlike ``PreFwdPostBwd``, ``args`` is a
    single positional input rather than ``*args``.
    Backward: every active hook's ``pre_backward`` is called with the same
    ``params``, and the incoming gradients are passed straight through.
    """

    @staticmethod
    def forward(ctx, params, args):
        # Keep params on ctx so backward can hand the same list to the hooks.
        ctx.params = params
        active_hooks = _ParamOpHookWrapper.hooks
        for hook in active_hooks:
            hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        active_hooks = _ParamOpHookWrapper.hooks
        for hook in active_hooks:
            hook.pre_backward(ctx.params)
        # None is the (absent) gradient for ``params``; the rest correspond
        # to ``args`` and are forwarded untouched.
        return (None, *grads)
|
||
|
|
||
|
|
||
|
@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    """Install ``hooks`` as the active parameter-op hooks for a ``with`` block.

    The given hooks *replace* (rather than extend) whatever was active before.
    The previous hooks are restored on exit, including when the body raises.

    NOTE(review): the active hooks live in a class attribute shared by the
    whole process, so concurrent use from several threads would interfere —
    confirm callers are single-threaded.

    Args:
        *hooks: hook instances to activate for the duration of the block.

    Yields:
        None.
    """
    previous = _ParamOpHookWrapper.hooks
    _ParamOpHookWrapper.hooks = hooks
    try:
        yield
    finally:
        _ParamOpHookWrapper.hooks = previous
|