mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix optim bwd;
parent
591a13bf7e
commit
a48afc4a66
|
@ -58,7 +58,7 @@ class OptimizerWrapper:
|
|||
# def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
# torch.autograd.backward(tensor, grad)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor, retain_graph: bool = False):
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
Performs a backward pass for dx or dw,
|
||||
for dx, we only calculate dx = w*dy here
|
||||
|
|
Loading…
Reference in New Issue