From a48afc4a665d4217099e08fb1949f5976347d5f6 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 3 Sep 2024 02:40:26 +0000 Subject: [PATCH] [fix] fix optim bwd; --- colossalai/interface/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index f259cddad..1afbd0806 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -58,7 +58,7 @@ class OptimizerWrapper: # def backward_by_grad(self, tensor: Tensor, grad: Tensor): # torch.autograd.backward(tensor, grad) - def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor, retain_graph: bool = False): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Performs a backward pass for dx or dw, for dx, we only calculate dx = w*dy here