fix(train.py): fix scheduler metric hook skip error (#204)

pull/210/head
huangting4201 2023-08-16 15:47:05 +08:00 committed by GitHub
parent 5f2381af62
commit f3664bfbab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 13 deletions

View File

@ -158,13 +158,6 @@ class SchedulerMetricHook(SchedulerHook):
self._post_func = metric
self._skip = skip
if skip:
# init timer only.
timer("fwd")
timer("bwd")
timer("cal_loss")
timer("post_fn")
def before_forward(self, scheduler, inputs) -> None:
if not self._skip:
timer("fwd").start()
@ -190,8 +183,5 @@ class SchedulerMetricHook(SchedulerHook):
timer("bwd").stop()
def post_helper_func(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("post_fn").start()
if self._post_func is not None:
self._post_func(outputs, label)
timer("post_fn").stop()
if self._post_func is not None:
self._post_func(outputs, label)

View File

@ -543,7 +543,12 @@ def main(args):
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
]