@@ -168,12 +168,12 @@ def _get_grad_args(*args):
     # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
     arg_zero = args[0]
     if not isinstance(arg_zero, tuple):
-        raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.")
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
     check_grad_flag = False
     for obj in arg_zero:
         check_grad_flag |= _is_grad_tensor(obj)
     if not check_grad_flag:
-        raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.")
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
     return arg_zero, args[1:]
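
For context: the function this hunk touches unpacks the positional arguments of an intercepted torch function and refuses to proceed when none of them can trigger autograd, since, as the in-code comment notes, the backward of PreFwdPostBwd cannot fire without at least one grad tensor. Below is a minimal runnable sketch of that check, assuming `_is_grad_tensor` simply tests `requires_grad` on tensors (the real helper is defined outside this hunk, and the demo call arguments are made up for illustration):

    import torch

    def _is_grad_tensor(obj):
        # assumption: a "grad tensor" is any tensor participating in autograd
        return isinstance(obj, torch.Tensor) and obj.requires_grad

    def _get_grad_args(*args):
        # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
        arg_zero = args[0]
        if not isinstance(arg_zero, tuple):
            raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
        check_grad_flag = False
        for obj in arg_zero:
            check_grad_flag |= _is_grad_tensor(obj)
        if not check_grad_flag:
            raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
        return arg_zero, args[1:]

    # a tuple containing at least one requires_grad tensor passes the check;
    # the remaining positional args are returned untouched
    grad_args, rest = _get_grad_args((torch.randn(2, requires_grad=True), torch.zeros(2)), {"dim": 0})
    print(len(grad_args), rest)  # -> 2 ({'dim': 0},)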