[fix] fix optim bwd;

pull/6034/head
duanjunwen 2024-09-03 02:40:26 +00:00
parent 591a13bf7e
commit a48afc4a66
1 changed file with 1 addition and 1 deletion

View File

@ -58,7 +58,7 @@ class OptimizerWrapper:
# def backward_by_grad(self, tensor: Tensor, grad: Tensor):
# torch.autograd.backward(tensor, grad)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor, retain_graph: bool = False):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here