@@ -1,19 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import List, Tuple, Union, Callable
+import inspect

-import torch.cuda
+from typing import Callable, List, Tuple, Union

 import colossalai.communication as comm
+import torch.cuda
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.logging import get_dist_logger
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.logging import get_dist_logger

 from ._base_schedule import BaseSchedule

@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append(tuple((output_tensor, label)))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
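
Note on the unchanged accumulation lines above: each microbatch loss is divided by num_microbatches before being added in place, so accum_loss ends up holding the mean loss over the whole batch once the schedule has processed every microbatch. A minimal sketch of that pattern follows; the helper name and inputs are hypothetical and stand apart from ColossalAI's actual schedule code.

    import torch

    def accumulate_microbatch_losses(micro_losses, num_microbatches):
        # Mirrors the pattern in the hunk above: scale each microbatch loss by
        # 1/num_microbatches and accumulate in place; detach so the running
        # total is for reporting only and does not extend the autograd graph.
        accum_loss = torch.zeros(1)
        for loss in micro_losses:
            loss_reduced = loss / num_microbatches
            accum_loss.add_(loss_reduced.detach())
        return accum_loss

    losses = [torch.tensor([2.0]), torch.tensor([4.0]), torch.tensor([6.0])]
    print(accumulate_microbatch_losses(losses, num_microbatches=3))  # tensor([4.])
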
@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if gpc.is_pipeline_last_stage():
             if return_output_label:
-                return_tensors.append(tuple(output_tensor, label))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
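
A quick standalone check (plain Python, no ColossalAI imports) of why this hunk fixes a real crash while the first schedule's hunk only drops a redundant wrapper: the tuple constructor takes a single iterable, so tuple(output_tensor, label) raises TypeError, whereas (output_tensor, label) and tuple((output_tensor, label)) both build the intended pair. The placeholder values below are illustrative only.

    # Placeholders standing in for the real output tensor and label objects.
    output_tensor, label = "out", "lab"

    pair = (output_tensor, label)            # what both schedules now append
    wrapped = tuple((output_tensor, label))  # old PipelineSchedule spelling: same value, redundant call
    assert pair == wrapped

    try:
        tuple(output_tensor, label)          # old InterleavedPipelineSchedule spelling
    except TypeError as exc:
        print(exc)  # tuple() accepts at most one argument, so this path raised at runtime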