mirror of https://github.com/hpcaitech/ColossalAI
Rename class method of ZeroDDP (#2692)
parent
6e4ac08172
commit
c52edcf0eb
|
@ -294,7 +294,7 @@ class ZeroDDP(ColoDDP):
|
||||||
continue
|
continue
|
||||||
p.grad = None
|
p.grad = None
|
||||||
|
|
||||||
def _pre_bacward(self):
|
def _pre_backward(self):
|
||||||
# set a visit label for all parameters
|
# set a visit label for all parameters
|
||||||
# the label is used to check whether the parameter is correctly reduced
|
# the label is used to check whether the parameter is correctly reduced
|
||||||
for param in self.param2name:
|
for param in self.param2name:
|
||||||
|
@ -318,7 +318,7 @@ class ZeroDDP(ColoDDP):
|
||||||
self.gemini_manager.post_iter()
|
self.gemini_manager.post_iter()
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor):
|
def backward(self, loss: torch.Tensor):
|
||||||
self._pre_bacward()
|
self._pre_backward()
|
||||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self._post_backward()
|
self._post_backward()
|
||||||
|
|
Loading…
Reference in New Issue