diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 5cef92f..35d8a59 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -312,7 +312,8 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         self._overlap_handler = overlap_handler
 
     def before_forward(self, scheduler, inputs) -> None:
-        self._overlap_handler.set_forward_mode(True)
+        if self._overlap_handler is not None:
+            self._overlap_handler.set_forward_mode(True)
 
     def after_forward(self, scheduler, outputs) -> None:
         pass
@@ -324,7 +325,8 @@ class FSTPOverlapSchedulerHook(SchedulerHook):
         pass
 
     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        self._overlap_handler.set_forward_mode(False)
+        if self._overlap_handler is not None:
+            self._overlap_handler.set_forward_mode(False)
 
     def after_backward(self, scheduler, inputs_grad) -> None:
         pass