From 708404d5f8ff54d24698260e8719f2e1fe21c573 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Jan 2022 15:46:02 +0800 Subject: [PATCH] fix pipeline forward return tensors (#176) --- colossalai/engine/schedule/_pipeline_schedule.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 73a39e833..5bab0d524 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -1,19 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple, Union, Callable import inspect -import torch.cuda +from typing import Callable, List, Tuple, Union import colossalai.communication as comm +import torch.cuda +from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.logging import get_dist_logger +from colossalai.utils import switch_virtual_pipeline_parallel_rank from colossalai.utils.cuda import get_current_device from colossalai.zero import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from colossalai.utils import switch_virtual_pipeline_parallel_rank -from colossalai.logging import get_dist_logger + from ._base_schedule import BaseSchedule @@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule): if gpc.is_last_rank(ParallelMode.PIPELINE): if return_output_label: - return_tensors.append(tuple((output_tensor, label))) + return_tensors.append((output_tensor, label)) if accum_loss is not None: loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches accum_loss.add_(loss_reduced.detach()) @@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): if gpc.is_pipeline_last_stage(): if return_output_label: - return_tensors.append(tuple(output_tensor, label)) + return_tensors.append((output_tensor, label)) if accum_loss is not None: loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches accum_loss.add_(loss_reduced.detach())