|
|
|
@ -168,12 +168,12 @@ def _get_grad_args(*args):
|
|
|
|
|
# if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered |
|
|
|
|
arg_zero = args[0] |
|
|
|
|
if not isinstance(arg_zero, tuple): |
|
|
|
|
raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") |
|
|
|
|
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") |
|
|
|
|
check_grad_flag = False |
|
|
|
|
for obj in arg_zero: |
|
|
|
|
check_grad_flag |= _is_grad_tensor(obj) |
|
|
|
|
if not check_grad_flag: |
|
|
|
|
raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") |
|
|
|
|
raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") |
|
|
|
|
return arg_zero, args[1:] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|