|
|
@ -1,19 +1,20 @@
|
|
|
|
#!/usr/bin/env python
|
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Tuple, Union, Callable
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import inspect
|
|
|
|
import torch.cuda
|
|
|
|
from typing import Callable, List, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
import colossalai.communication as comm
|
|
|
|
import colossalai.communication as comm
|
|
|
|
|
|
|
|
import torch.cuda
|
|
|
|
|
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
|
|
|
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
|
|
|
ZeroRedundancyOptimizer_Level_3)
|
|
|
|
ZeroRedundancyOptimizer_Level_3)
|
|
|
|
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
from ._base_schedule import BaseSchedule
|
|
|
|
from ._base_schedule import BaseSchedule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
|
|
|
|
|
|
|
|
|
|
|
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
|
|
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
|
|
|
if return_output_label:
|
|
|
|
if return_output_label:
|
|
|
|
return_tensors.append(tuple((output_tensor, label)))
|
|
|
|
return_tensors.append((output_tensor, label))
|
|
|
|
if accum_loss is not None:
|
|
|
|
if accum_loss is not None:
|
|
|
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
|
|
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
|
|
|
accum_loss.add_(loss_reduced.detach())
|
|
|
|
accum_loss.add_(loss_reduced.detach())
|
|
|
@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|
|
|
|
|
|
|
|
|
|
|
if gpc.is_pipeline_last_stage():
|
|
|
|
if gpc.is_pipeline_last_stage():
|
|
|
|
if return_output_label:
|
|
|
|
if return_output_label:
|
|
|
|
return_tensors.append(tuple(output_tensor, label))
|
|
|
|
return_tensors.append((output_tensor, label))
|
|
|
|
if accum_loss is not None:
|
|
|
|
if accum_loss is not None:
|
|
|
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
|
|
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
|
|
|
accum_loss.add_(loss_reduced.detach())
|
|
|
|
accum_loss.add_(loss_reduced.detach())
|
|
|
|