diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index f259cddad..1afbd0806 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -58,7 +58,7 @@ class OptimizerWrapper:
     # def backward_by_grad(self, tensor: Tensor, grad: Tensor):
     #     torch.autograd.backward(tensor, grad)
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor, retain_graph: bool = False):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Performs a backward pass for dx or dw,
         for dx, we only calculate dx = w*dy here
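
For context, here is a minimal standalone sketch of what this signature change buys. It assumes the method simply forwards to `torch.autograd.backward`, as the commented-out lines above suggest; the free function `backward_by_grad` below is a hypothetical stand-in for the wrapper method, not the actual implementation. With `inputs=None` as the default, existing callers that never pass `inputs` keep working, while pipeline-style callers can still restrict the backward pass to a subset of tensors (the split dx/dw passes the docstring mentions).

```python
import torch
from torch import Tensor


def backward_by_grad(tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
    # With inputs=None, torch.autograd.backward accumulates gradients into
    # every leaf tensor in the graph; with inputs given, only those tensors
    # receive gradients (useful for computing dx and dw in separate passes).
    torch.autograd.backward(tensor, grad, inputs=inputs, retain_graph=retain_graph)


w = torch.randn(4, 4, requires_grad=True)  # weight (a leaf tensor)
x = torch.randn(2, 4, requires_grad=True)  # input (a leaf here for brevity)
y = x @ w
dy = torch.ones_like(y)

# dx-only pass: restrict the backward to x, so w.grad stays None.
backward_by_grad(y, dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# With the new default, a plain call works again and fills every leaf.
backward_by_grad(y, dy)
assert w.grad is not None
```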