|
|
|
@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
|
|
|
|
|
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamOpHook(ABC):
|
|
|
|
|
class ColoParamOpHook(ABC):
|
|
|
|
|
"""Hook which is triggered by each operation when operands contain ColoParameter.
|
|
|
|
|
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
|
|
|
|
|
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
|
|
|
|
@ -32,68 +32,68 @@ class ParamOpHook(ABC):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamOpHookManager:
|
|
|
|
|
class ColoParamOpHookManager:
|
|
|
|
|
"""Manage your param op hooks. It only has static methods.
|
|
|
|
|
The only static method you should call is ``use_hooks(*hooks)``.
|
|
|
|
|
"""
|
|
|
|
|
hooks: Tuple[ParamOpHook, ...] = tuple()
|
|
|
|
|
hooks: Tuple[ColoParamOpHook, ...] = tuple()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@contextmanager
|
|
|
|
|
def use_hooks(*hooks: ParamOpHook):
|
|
|
|
|
def use_hooks(*hooks: ColoParamOpHook):
|
|
|
|
|
"""Change the param op hooks you use. Nested calling is allowed.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
>>> with ParamOpHookManager.use_hooks(*hooks):
|
|
|
|
|
>>> with ColoParamOpHookManager.use_hooks(*hooks):
|
|
|
|
|
>>> do_something()
|
|
|
|
|
>>> with ParamOpHookManager.use_hooks():
|
|
|
|
|
>>> with ColoParamOpHookManager.use_hooks():
|
|
|
|
|
>>> // clear hooks
|
|
|
|
|
>>> do_something()
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
old_param_op_hooks = ParamOpHookManager.hooks
|
|
|
|
|
ParamOpHookManager.hooks = hooks
|
|
|
|
|
old_param_op_hooks = ColoParamOpHookManager.hooks
|
|
|
|
|
ColoParamOpHookManager.hooks = hooks
|
|
|
|
|
yield
|
|
|
|
|
finally:
|
|
|
|
|
ParamOpHookManager.hooks = old_param_op_hooks
|
|
|
|
|
ColoParamOpHookManager.hooks = old_param_op_hooks
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
|
|
|
|
|
for hook in ParamOpHookManager.hooks:
|
|
|
|
|
for hook in ColoParamOpHookManager.hooks:
|
|
|
|
|
hook.pre_forward(params)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
|
|
|
|
|
for hook in ParamOpHookManager.hooks:
|
|
|
|
|
for hook in ColoParamOpHookManager.hooks:
|
|
|
|
|
hook.post_forward(params)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
|
|
|
|
|
for hook in ParamOpHookManager.hooks:
|
|
|
|
|
for hook in ColoParamOpHookManager.hooks:
|
|
|
|
|
hook.pre_backward(params)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
|
|
|
|
|
for hook in ParamOpHookManager.hooks:
|
|
|
|
|
for hook in ColoParamOpHookManager.hooks:
|
|
|
|
|
hook.post_backward(params)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
|
|
|
|
|
ParamOpHookManager._trigger_pre_forward(params)
|
|
|
|
|
ColoParamOpHookManager._trigger_pre_forward(params)
|
|
|
|
|
args_info = _get_colo_tensors_info(*args)
|
|
|
|
|
rets = PreFwdPostBwd.apply(params, *args)
|
|
|
|
|
return _update_colo_tensors(args_info, *rets)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
|
|
|
|
ParamOpHookManager._trigger_post_forward(params)
|
|
|
|
|
ColoParamOpHookManager._trigger_post_forward(params)
|
|
|
|
|
arg_info = _get_colo_tensors_info(arg)
|
|
|
|
|
ret = PostFwdPreBwd.apply(params, arg)
|
|
|
|
|
return _unpack_args(_update_colo_tensors(arg_info, ret))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def has_hook() -> bool:
|
|
|
|
|
return len(ParamOpHookManager.hooks) > 0
|
|
|
|
|
return len(ColoParamOpHookManager.hooks) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PreFwdPostBwd(torch.autograd.Function):
|
|
|
|
@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def backward(ctx, *grads):
|
|
|
|
|
ParamOpHookManager._trigger_post_backward(ctx.params)
|
|
|
|
|
ColoParamOpHookManager._trigger_post_backward(ctx.params)
|
|
|
|
|
return (None,) + grads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def backward(ctx, *grads):
|
|
|
|
|
ParamOpHookManager._trigger_pre_backward(ctx.params)
|
|
|
|
|
ColoParamOpHookManager._trigger_pre_backward(ctx.params)
|
|
|
|
|
return (None,) + grads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|