feat(model/overlap_handler.py): fix overlap handler None bug

pull/456/head
huangting4201 2023-10-24 18:47:27 +08:00
parent 0d3592a53f
commit 41cfa1a10a
1 changed files with 4 additions and 2 deletions

View File

@ -312,7 +312,8 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
self._overlap_handler = overlap_handler
def before_forward(self, scheduler, inputs) -> None:
self._overlap_handler.set_forward_mode(True)
if self._overlap_handler is not None:
self._overlap_handler.set_forward_mode(True)
def after_forward(self, scheduler, outputs) -> None:
pass
@ -324,7 +325,8 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
pass
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
self._overlap_handler.set_forward_mode(False)
if self._overlap_handler is not None:
self._overlap_handler.set_forward_mode(False)
def after_backward(self, scheduler, inputs_grad) -> None:
pass