2022-05-31 04:00:12 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2022-11-08 09:03:50 +00:00
|
|
|
from contextlib import contextmanager
|
|
|
|
from typing import Any, List, Tuple, Union
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2022-06-17 08:12:05 +00:00
|
|
|
from colossalai.tensor.colo_tensor import ColoTensor
|
2022-11-08 09:03:50 +00:00
|
|
|
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
2022-05-31 04:00:12 +00:00
|
|
|
|
|
|
|
|
2022-12-05 09:11:06 +00:00
|
|
|
class ColoParamOpHook(ABC):
    """
    Abstract hook fired for every operation whose operands include a ColoParameter.

    Subclass it and provide all four callbacks: ``pre_forward``, ``post_forward``,
    ``pre_backward`` and ``post_backward``. Each callback receives the list of
    ColoParameter involved in the operation as its only argument.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked from ``ColoParamOpHookManager.pre_op``."""

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked from ``ColoParamOpHookManager.post_op``."""

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked from the backward pass of ``PostFwdPreBwd``."""

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked from the backward pass of ``PreFwdPostBwd``."""
|
|
|
|
|
|
|
|
|
2022-12-05 09:11:06 +00:00
|
|
|
class ColoParamOpHookManager:
    """
    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """

    # Hooks currently in effect. ``use_hooks`` rebinds this attribute (it never
    # mutates the tuple), so an outer context can be restored exactly.
    hooks: Tuple[ColoParamOpHook, ...] = tuple()

    @staticmethod
    @contextmanager
    def use_hooks(*hooks: ColoParamOpHook):
        """Change the param op hooks you use. Nested calling is allowed.

        The hooks that were active before entering are restored when the
        ``with`` block exits, including when it exits via an exception.

        Example:
            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            ...     do_something()
            ...     with ColoParamOpHookManager.use_hooks():
            ...         # clear hooks
            ...         do_something()
        """
        # Capture the outer hooks before entering ``try`` so ``finally`` always
        # has a bound value to restore.
        old_param_op_hooks = ColoParamOpHookManager.hooks
        ColoParamOpHookManager.hooks = hooks
        try:
            yield
        finally:
            ColoParamOpHookManager.hooks = old_param_op_hooks

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
        """Call ``pre_forward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
        """Call ``post_forward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
        """Call ``pre_backward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
        """Call ``post_backward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> Union[list, tuple]:
        """Fire ``pre_forward`` hooks and route the op's inputs through ``PreFwdPostBwd``.

        Routing through ``PreFwdPostBwd`` makes its backward (which fires the
        ``post_backward`` hooks) part of the autograd graph.

        Args:
            params: parameters handed to every hook.
            *args: positional arguments of the operation.

        Returns:
            A list of the passed-through args (ColoTensors re-wrapped with their
            original class and spec) when ``_get_grad_args`` returns all args.
            Otherwise, a tuple whose first element is a tuple of the re-wrapped
            members of ``args[0]``, followed by the untouched ``args[1:]``.
        """
        ColoParamOpHookManager._trigger_pre_forward(params)
        grad_args, rear_args = _get_grad_args(*args)
        # Record ColoTensor class/spec before autograd returns plain outputs,
        # so they can be restored afterwards.
        colo_info = _get_colo_tensors_info(*grad_args)
        rets = PreFwdPostBwd.apply(params, *grad_args)
        update_args = _update_colo_tensors(colo_info, *rets)
        if rear_args is None:
            return update_args
        # Rebuild the original layout: grad tensors lived inside a tuple in args[0].
        return (tuple(update_args),) + rear_args

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
        """Fire ``post_forward`` hooks and route the op's output through ``PostFwdPreBwd``.

        ``PostFwdPreBwd``'s backward fires the ``pre_backward`` hooks.

        Args:
            params: parameters handed to every hook.
            arg: the operation's output.

        Returns:
            The output after passing through ``PostFwdPreBwd``, re-wrapped as its
            original ColoTensor class/spec if it was a ColoTensor.
        """
        ColoParamOpHookManager._trigger_post_forward(params)
        colo_info = _get_colo_tensors_info(arg)
        ret = PostFwdPreBwd.apply(params, arg)
        # ``colo_info`` describes exactly one value, so the update yields exactly
        # one element; the former multi-element branch could never be reached.
        return _update_colo_tensors(colo_info, ret)[0]

    @staticmethod
    def has_hook() -> bool:
        """Return True if at least one hook is currently active."""
        return len(ColoParamOpHookManager.hooks) > 0
|
2022-06-13 08:11:53 +00:00
|
|
|
|
2022-05-31 04:00:12 +00:00
|
|
|
|
|
|
|
class PreFwdPostBwd(torch.autograd.Function):
    """Identity autograd function whose backward fires the ``post_backward`` hooks.

    ``forward`` hands back the received ``args`` tuple unchanged. ``backward``
    returns ``None`` for ``params`` followed by the incoming gradients unmodified.
    """

    @staticmethod
    def forward(ctx, params, *args):
        # Keep the params so backward can report them to the hooks.
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        ColoParamOpHookManager._trigger_post_backward(ctx.params)
        # ``params`` receives no gradient; every other input gets its grad as-is.
        return (None, *grads)
|
|
|
|
|
|
|
|
|
|
|
|
class PostFwdPreBwd(torch.autograd.Function):
    """Identity autograd function whose backward fires the ``pre_backward`` hooks.

    ``forward`` hands back the single ``args`` value unchanged. ``backward``
    returns ``None`` for ``params`` followed by the incoming gradients unmodified.
    """

    @staticmethod
    def forward(ctx, params, args):
        # Keep the params so backward can report them to the hooks.
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
        # ``params`` receives no gradient; the remaining grads pass straight through.
        return (None, *grads)
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
2022-12-26 07:03:54 +00:00
|
|
|
def _is_grad_tensor(obj) -> bool:
|
|
|
|
if torch.is_tensor(obj):
|
|
|
|
if obj.grad_fn is not None or obj.requires_grad:
|
|
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
2023-01-06 10:37:18 +00:00
|
|
|
def _has_grad_tensor(obj) -> bool:
|
|
|
|
if isinstance(obj, tuple) or isinstance(obj, list):
|
|
|
|
for x in obj:
|
|
|
|
if _has_grad_tensor(x):
|
|
|
|
return True
|
|
|
|
return False
|
|
|
|
elif isinstance(obj, dict):
|
|
|
|
for x in obj.values():
|
|
|
|
if _has_grad_tensor(x):
|
|
|
|
return True
|
|
|
|
return False
|
|
|
|
else:
|
|
|
|
return _is_grad_tensor(obj)
|
|
|
|
|
|
|
|
|
2022-12-26 07:03:54 +00:00
|
|
|
def _get_grad_args(*args):
    """Pick out the arguments that should be routed through ``PreFwdPostBwd``.

    Returns a pair ``(grad_args, rear_args)``:

    * ``(args, None)`` when no argument contains a grad tensor (searched
      recursively), or when at least one top-level argument is a grad tensor;
    * ``(args[0], args[1:])`` when no top-level argument is a grad tensor but
      ``args[0]`` is a tuple that directly contains one.

    Raises:
        NotImplementedError: if grad tensors exist but neither layout applies.
            In that case the backward of ``PreFwdPostBwd`` could never be triggered.
    """
    # Nothing requires grad anywhere: pass everything through as-is.
    if not _has_grad_tensor(args):
        return args, None

    # A grad tensor among the top-level args is enough to hook into autograd.
    if any(_is_grad_tensor(obj) for obj in args):
        return args, None

    # Otherwise the grad tensors must sit directly inside a tuple passed as the first argument.
    head = args[0]
    if not isinstance(head, tuple) or not any(_is_grad_tensor(obj) for obj in head):
        raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.")
    return head, args[1:]
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _get_colo_tensors_info(*args) -> list:
|
|
|
|
info = []
|
|
|
|
for arg in args:
|
|
|
|
if isinstance(arg, ColoTensor):
|
2022-07-06 08:15:16 +00:00
|
|
|
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
|
2022-06-17 08:12:05 +00:00
|
|
|
else:
|
|
|
|
info.append(None)
|
|
|
|
return info
|
|
|
|
|
|
|
|
|
|
|
|
def _update_colo_tensors(info, *args) -> list:
|
|
|
|
ret = []
|
|
|
|
for t_info, arg in zip(info, args):
|
|
|
|
if t_info is not None:
|
|
|
|
t_cls, spec = t_info
|
|
|
|
arg = t_cls.from_torch_tensor(arg, spec=spec)
|
|
|
|
ret.append(arg)
|
|
|
|
return ret
|