fix pipeline forward return tensors (#176)

pull/172/head
ver217 2022-01-21 15:46:02 +08:00 committed by GitHub
parent 6fb550acdb
commit 708404d5f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 7 deletions

View File

@ -1,19 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple, Union, Callable
import inspect
import torch.cuda
from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
from ._base_schedule import BaseSchedule
@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_output_label:
return_tensors.append(tuple((output_tensor, label)))
return_tensors.append((output_tensor, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if gpc.is_pipeline_last_stage():
if return_output_label:
return_tensors.append(tuple(output_tensor, label))
return_tensors.append((output_tensor, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())