diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 8e0192c71..a9d001bd0 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -294,7 +294,7 @@ class ZeroDDP(ColoDDP): continue p.grad = None - def _pre_bacward(self): + def _pre_backward(self): # set a visit label for all parameters # the label is used to check whether the parameter is correctly reduced for param in self.param2name: @@ -318,7 +318,7 @@ class ZeroDDP(ColoDDP): self.gemini_manager.post_iter() def backward(self, loss: torch.Tensor): - self._pre_bacward() + self._pre_backward() with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): loss.backward() self._post_backward()