@@ -82,16 +82,26 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args
 
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        colo_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res
 
     @staticmethod
     def has_hook() -> bool:
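Not part of the patch: a minimal sketch of the two return shapes the rewritten pre_op produces, assuming plain torch tensors stand in for ColoTensors and using a hypothetical helper name rewrap that mirrors only the tail of the new method.

import torch

def rewrap(update_args, rear_args):
    # mirrors the tail of the new pre_op: pass the processed tensors through as-is,
    # or rebuild the original argument layout when they were unwrapped from a tuple
    if rear_args is None:
        return update_args
    return (tuple(update_args),) + rear_args

x, y = torch.randn(2), torch.randn(2)
print(rewrap([x, y], None))     # torch.add style call: tensors stay at the top level
print(rewrap([x, y], (0,)))     # torch.stack style call: ((x, y), 0), ready to splat back into the op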
@@ -103,7 +113,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        return _unpack_args(args)
+        return args
 
     @staticmethod
     def backward(ctx, *grads):
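Not part of the patch: a self-contained sketch of why forward can return the incoming tuple unchanged. A tuple returned from torch.autograd.Function.forward comes back out of apply as a tuple even for a single element, so the rewritten pre_op can always star-expand the result; the single-element unwrapping now lives in post_op instead. Passthrough is a hypothetical stand-in for PreFwdPostBwd.

import torch

class Passthrough(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args):
        return args                  # tuple in, tuple out, even when len(args) == 1

    @staticmethod
    def backward(ctx, *grads):
        return grads

x = torch.randn(3, requires_grad=True)
rets = Passthrough.apply(x)
print(type(rets), len(rets))         # <class 'tuple'> 1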
@@ -124,10 +134,29 @@ class PostFwdPreBwd(torch.autograd.Function):
         return (None,) + grads
 
 
-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _get_grad_args(*args):
+    # return the args unchanged if any of them is already a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first argument should be a tuple of grad tensors;
+    # if it holds no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    return arg_zero, args[1:]
 
 
 def _get_colo_tensors_info(*args) -> list:
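Not part of the patch: an end-to-end sketch of the argument split the new helpers perform, assuming plain torch tensors stand in for ColoTensors; split_grad_args and is_grad_tensor are hypothetical copies of the _get_grad_args / _is_grad_tensor logic so the snippet runs without colossalai. Returning the trailing non-tensor arguments separately is what lets pre_op rebuild the original call signature after the tensors have passed through the hook.

import torch

def is_grad_tensor(obj):
    return torch.is_tensor(obj) and (obj.grad_fn is not None or obj.requires_grad)

def split_grad_args(*args):
    # same decision as _get_grad_args: if a grad tensor already sits at the top level,
    # pass args through untouched; otherwise unwrap the leading tuple of grad tensors
    if any(is_grad_tensor(obj) for obj in args):
        return args, None
    if not isinstance(args[0], tuple) or not any(is_grad_tensor(obj) for obj in args[0]):
        raise NotImplementedError("unsupported argument layout")
    return args[0], args[1:]

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

print(split_grad_args(x, y))        # torch.add style: ((x, y), None)
print(split_grad_args((x, y), 0))   # torch.stack style: ((x, y), (0,))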