mirror of https://github.com/InternLM/InternLM
fix(train.py): fix scheduler metric hook skip error (#204)
parent
5f2381af62
commit
f3664bfbab
|
@ -158,13 +158,6 @@ class SchedulerMetricHook(SchedulerHook):
|
|||
self._post_func = metric
|
||||
self._skip = skip
|
||||
|
||||
if skip:
|
||||
# init timer only.
|
||||
timer("fwd")
|
||||
timer("bwd")
|
||||
timer("cal_loss")
|
||||
timer("post_fn")
|
||||
|
||||
def before_forward(self, scheduler, inputs) -> None:
|
||||
if not self._skip:
|
||||
timer("fwd").start()
|
||||
|
@ -190,8 +183,5 @@ class SchedulerMetricHook(SchedulerHook):
|
|||
timer("bwd").stop()
|
||||
|
||||
def post_helper_func(self, scheduler, outputs, label) -> None:
|
||||
if not self._skip:
|
||||
timer("post_fn").start()
|
||||
if self._post_func is not None:
|
||||
self._post_func(outputs, label)
|
||||
timer("post_fn").stop()
|
||||
if self._post_func is not None:
|
||||
self._post_func(outputs, label)
|
||||
|
|
7
train.py
7
train.py
|
@ -543,7 +543,12 @@ def main(args):
|
|||
scheduler_hooks = [
|
||||
SchedulerMetricHook(
|
||||
metric=metric,
|
||||
skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
|
||||
skip=(
|
||||
gpc.is_using_pp()
|
||||
and hasattr(gpc.config.model, "num_chunks")
|
||||
and gpc.config.model.num_chunks > 1
|
||||
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue