mirror of https://github.com/hpcaitech/ColossalAI
fix pipeline forward return tensors (#176)
parent
6fb550acdb
commit
708404d5f8
|
@ -1,19 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import List, Tuple, Union, Callable
|
||||
import inspect
|
||||
import torch.cuda
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import colossalai.communication as comm
|
||||
import torch.cuda
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)
|
||||
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from ._base_schedule import BaseSchedule
|
||||
|
||||
|
||||
|
@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
if return_output_label:
|
||||
return_tensors.append(tuple((output_tensor, label)))
|
||||
return_tensors.append((output_tensor, label))
|
||||
if accum_loss is not None:
|
||||
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
||||
accum_loss.add_(loss_reduced.detach())
|
||||
|
@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
|
||||
if gpc.is_pipeline_last_stage():
|
||||
if return_output_label:
|
||||
return_tensors.append(tuple(output_tensor, label))
|
||||
return_tensors.append((output_tensor, label))
|
||||
if accum_loss is not None:
|
||||
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
||||
accum_loss.add_(loss_reduced.detach())
|
||||
|
|
Loading…
Reference in New Issue