fix pipeline forward return tensors (#176)

pull/172/head
Authored by ver217 3 years ago, committed by GitHub
parent 6fb550acdb
commit 708404d5f8

@@ -1,19 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from typing import List, Tuple, Union, Callable
 import inspect
-import torch.cuda
+from typing import Callable, List, Tuple, Union
+
 import colossalai.communication as comm
+import torch.cuda
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.logging import get_dist_logger
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.logging import get_dist_logger
 
 from ._base_schedule import BaseSchedule
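This hunk only regroups the imports: plain `import` statements come before `from ... import` statements, modules are alphabetized, and standard-library, third-party, and local imports are separated by blank lines. This is the ordering a formatter such as isort produces by default (whether the author actually ran isort is an assumption); a minimal sketch of the convention, using names from this file:

```python
# Group 1: standard library; `import` lines before `from` lines, alphabetized.
import inspect
from typing import Callable

# Group 2: third-party packages, same ordering rules within the group.
import torch.cuda
from colossalai.core import global_context as gpc

# Group 3: local (relative) imports, set off by a blank line.
from ._base_schedule import BaseSchedule
```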
@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append(tuple((output_tensor, label)))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
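In `PipelineSchedule` the old call was redundant rather than broken: calling `tuple()` on an existing tuple just returns an equal tuple, so swapping in the tuple literal is a behavior-preserving cleanup. A quick check, with strings standing in for the real tensors:

```python
output_tensor, label = "out", "lbl"  # stand-ins for the real objects

# tuple() applied to a tuple yields an equal tuple, so the old and
# new forms append the same value; the literal is simply clearer.
assert tuple((output_tensor, label)) == (output_tensor, label)
```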
@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if gpc.is_pipeline_last_stage():
             if return_output_label:
-                return_tensors.append(tuple(output_tensor, label))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
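This hunk is the actual bug the commit title refers to: `tuple(output_tensor, label)` passes two positional arguments to `tuple()`, which accepts at most one (an iterable), so the last pipeline stage of the interleaved schedule would raise a `TypeError` whenever `return_output_label` was set. A minimal reproduction, again with string stand-ins for the tensors:

```python
# tuple() takes at most one argument, so the old call failed at runtime.
try:
    tuple("out", "lbl")  # stand-ins for output_tensor and label
except TypeError as e:
    print(e)  # tuple expected at most 1 argument, got 2

# The fix builds the pair with a tuple literal instead:
pair = ("out", "lbl")
```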
