From 96780e6ee4c73dc9c6724075e8e6306b765830d0 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 30 Dec 2021 15:56:46 +0800 Subject: [PATCH] Optimize pipeline schedule (#94) * add pipeline shared module wrapper and update load batch * added model parallel process group for amp and clip grad (#86) * added model parallel process group for amp and clip grad * update amp and clip with model parallel process group * remove pipeline_prev/next group (#88) * micro batch offload * optimize pipeline gpu memory usage * pipeline can receive tensor shape (#93) * optimize pipeline gpu memory usage * fix grad accumulation step counter * rename classes and functions Co-authored-by: Frank Lee --- colossalai/amp/naive_amp/_fp16_optimizer.py | 7 +- colossalai/amp/torch_amp/_grad_scaler.py | 13 +- colossalai/builder/pipeline.py | 16 +- colossalai/constants.py | 3 +- colossalai/context/parallel_context.py | 3 + colossalai/context/parallel_mode.py | 6 +- .../process_group_initializer/__init__.py | 3 +- .../initializer_model.py | 43 +++ .../initializer_pipeline.py | 24 -- colossalai/engine/_base_engine.py | 7 +- .../engine/gradient_handler/__init__.py | 4 +- .../_pipeline_parallel_gradient_handler.py | 41 +++ colossalai/engine/schedule/_base_schedule.py | 60 ++-- .../engine/schedule/_non_pipeline_schedule.py | 27 +- .../engine/schedule/_pipeline_schedule.py | 274 +++++++++--------- colossalai/initialize.py | 13 + .../nn/layer/colossalai_layer/_utils.py | 4 +- colossalai/nn/layer/wrapper/__init__.py | 3 +- .../nn/layer/wrapper/pipeline_wrapper.py | 40 +++ colossalai/trainer/_trainer.py | 25 +- colossalai/utils/common.py | 20 +- .../_gradient_accumulation.py | 3 +- docs/add_your_parallel.md | 2 - docs/add_your_parallel_zh.md | 2 - tests/test_context/test_2d_init.py | 7 + tests/test_context/test_2p5d_init.py | 7 + tests/test_context/test_3d_init.py | 7 + .../test_cifar_with_data_pipeline_tensor.py | 2 +- tests/test_trainer/test_pipeline/test_p2p.py | 47 +-- 29 files changed, 423 insertions(+), 290 deletions(-) create mode 100644 colossalai/context/process_group_initializer/initializer_model.py create mode 100644 colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py create mode 100644 colossalai/nn/layer/wrapper/pipeline_wrapper.py diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index d34143aec..b1fc621c2 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -359,12 +359,7 @@ class FP16Optimizer(Optimizer): # Update across all model parallel instances. torch.distributed.all_reduce(self.found_inf, op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.TENSOR)) - - if is_using_pp(): - torch.distributed.all_reduce(self.found_inf, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PIPELINE)) + group=gpc.get_group(ParallelMode.MODEL)) # Check for nan. 
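+        # NOTE: a single MAX all-reduce over the MODEL group (the union of the
+        # tensor and pipeline parallel ranks) replaces the former pair of
+        # reductions over TENSOR and PIPELINE, so self.found_inf is already
+        # synchronized across the whole model parallel group at this point.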
found_inf_flag = (self.found_inf.item() > 0) diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py index 7e79ecab8..b3ad5c084 100644 --- a/colossalai/amp/torch_amp/_grad_scaler.py +++ b/colossalai/amp/torch_amp/_grad_scaler.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple from colossalai.context import ParallelMode import torch.distributed as dist from colossalai.core import global_context as gpc +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors class _MultiDeviceReplicator(object): @@ -247,10 +248,14 @@ class GradScaler(object): device), per_device_inv_scale.get(device)) # For tensor parallel paramters it should be all-reduced over tensor parallel process group - if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1: - for tensor in per_device_found_inf._per_device_tensors.values(): - dist.all_reduce(tensor, op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.TENSOR)) + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + vals = [val for val in per_device_found_inf._per_device_tensors.values()] + coalesced = _flatten_dense_tensors(vals) + dist.all_reduce(coalesced, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL)) + for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)): + buf.copy_(synced) return per_device_found_inf._per_device_tensors def unscale_(self, optimizer): diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py index 3e545ebb3..39f2414f7 100644 --- a/colossalai/builder/pipeline.py +++ b/colossalai/builder/pipeline.py @@ -112,7 +112,7 @@ def _binary_search(weights, num): return intervals -def _partition_uniform(num_items, pipeline_parallel_size, num_chunks): +def partition_uniform(num_items, pipeline_parallel_size, num_chunks): assert num_items % num_chunks == 0, \ "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" @@ -134,11 +134,11 @@ def _partition_uniform(num_items, pipeline_parallel_size, num_chunks): return parts -def _partition_balanced(weights, pipeline_parallel_size, num_chunks): +def partition_balanced(weights, pipeline_parallel_size, num_chunks): num_total = pipeline_parallel_size * num_chunks num_items = len(weights) if num_items <= num_total: - return _partition_uniform(num_items, pipeline_parallel_size, num_chunks) + return partition_uniform(num_items, pipeline_parallel_size, num_chunks) intervals = _binary_search(weights, num_total) @@ -151,7 +151,7 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks): return parts -def _count_layer_params(layers): +def count_layer_params(layers): """Count the number of parameters in each layer """ param_counts = [0] * len(layers) @@ -201,11 +201,11 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: # Make a partition if method == 'layer': num_layers = len(layers) - parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks) + parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks) elif method == 'parameter': - param_counts = _count_layer_params(layers) + param_counts = count_layer_params(layers) # print_rank_0(param_counts) - parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks) + parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks) else: raise ValueError("Method should be a pre-set string in [layer, parameter]") @@ -250,7 
+250,7 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo """ pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks) + partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks) module_list = [] for start, end in partitions[pipeline_rank]: module_list.append(nn.Sequential(*layers[start:end])) diff --git a/colossalai/constants.py b/colossalai/constants.py index 58a94437a..2ba535f43 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -14,7 +14,8 @@ INITIALIZER_MAPPING = { '2d': 'Initializer_2D', '2.5d': 'Initializer_2p5D', '3d': 'Initializer_3D', - 'sequence': 'Initializer_Sequence' + 'sequence': 'Initializer_Sequence', + 'model': 'Initializer_Model' } # 1D parallel diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index f76f4d60e..5bad70f00 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -394,6 +394,9 @@ class ParallelContext: # LSG: init data parallel process group for compatibility with other parallel module such as zero pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + # LSG: init model parallel process group for compatibility with amp and clip grad + pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + if self.pipeline_parallel_size > 1: pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py index f51ed8ecf..440526eae 100644 --- a/colossalai/context/parallel_mode.py +++ b/colossalai/context/parallel_mode.py @@ -14,10 +14,12 @@ class ParallelMode(Enum): # common parallel DATA = 'data' + # model parallel - containing tensor and pipeline parallel groups + # this is added to facilitate amp and grad clipping in hybrid parallel + MODEL = 'model' + # pipeline parallel PIPELINE = 'pipe' - PIPELINE_PREV = 'pipe_prev' - PIPELINE_NEXT = 'pipe_next' # containing all ranks in tensor parallel TENSOR = 'tensor' diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py index c7db5d39c..b98b64310 100644 --- a/colossalai/context/process_group_initializer/__init__.py +++ b/colossalai/context/process_group_initializer/__init__.py @@ -6,10 +6,11 @@ from .initializer_data import Initializer_Data from .initializer_pipeline import Initializer_Pipeline from .initializer_sequence import Initializer_Sequence from .initializer_tensor import Initializer_Tensor +from .initializer_model import Initializer_Model from .process_group_initializer import ProcessGroupInitializer __all__ = [ 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D', - 'Initializer_1D', 'ProcessGroupInitializer' + 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' ] diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py new file mode 100644 index 000000000..d5b50ac28 --- /dev/null +++ b/colossalai/context/process_group_initializer/initializer_model.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch.distributed as dist + +from colossalai.context import Config +from 
colossalai.registry import DIST_GROUP_INITIALIZER +from .process_group_initializer import ProcessGroupInitializer +from ..parallel_mode import ParallelMode + + +@DIST_GROUP_INITIALIZER.register_module +class Initializer_Model(ProcessGroupInitializer): + '''A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel groups). + ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size + self.num_group = self.world_size // self.model_parallel_size + + def init_dist_group(self): + '''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. + + :return: (local_rank, group_world_size, process_group, ranks_in_group, mode) + :rtype: tuple + ''' + local_rank = None + ranks_in_group = None + process_group = None + group_world_size = None + mode = ParallelMode.MODEL + + for i in range(self.num_group): + ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] + group = dist.new_group(ranks) + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + ranks_in_group = ranks + return local_rank, group_world_size, process_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py index d66c6f9af..f014c2486 100644 --- a/colossalai/context/process_group_initializer/initializer_pipeline.py +++ b/colossalai/context/process_group_initializer/initializer_pipeline.py @@ -36,28 +36,4 @@ class Initializer_Pipeline(ProcessGroupInitializer): process_group, ranks_in_group, ParallelMode.PIPELINE))) - for k in range(pipe_group_size): - first = pipe_ranks[k] - second = pipe_ranks[(k + 1) % pipe_group_size] - ranks = [first, second] - group = dist.new_group(ranks) - if self.rank == first: - local_rank = 0 - group_world_size = 2 - process_group = group - ranks_in_group = ranks - dist_settings.append( - tuple((local_rank, group_world_size, - process_group, ranks_in_group, - ParallelMode.PIPELINE_NEXT))) - elif self.rank == second: - local_rank = 1 - group_world_size = 2 - process_group = group - ranks_in_group = ranks - dist_settings.append( - tuple((local_rank, group_world_size, - process_group, ranks_in_group, - ParallelMode.PIPELINE_PREV))) - return dist_settings diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 90a7f0730..985c9d422 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -2,15 +2,12 @@ # -*- encoding: utf-8 -*- -import torch from typing import List from torch.nn import Module from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from colossalai.builder import build_gradient_handler from colossalai.logging import get_dist_logger -from colossalai.utils import is_using_ddp, is_using_pp from torch import Tensor @@ -84,7 +81,7 @@ class Engine: def backward(self, loss: Tensor): """Start backward propagation given the loss value computed by a loss function - + :param loss: loss value computed by a loss function :type loss: :class:`torch.Tensor` """ @@ -92,7 +89,7 @@ class Engine: def backward_by_grad(self, tensor, grad): """Start backward propagation given the gradient of the output tensor - + :param loss: output tensor :type loss: :class:`torch.Tensor` :param grad: gradient passed back to the output diff --git 
a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py index 3f896baa5..b2fd2d442 100644 --- a/colossalai/engine/gradient_handler/__init__.py +++ b/colossalai/engine/gradient_handler/__init__.py @@ -1,5 +1,7 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler +from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler -__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler'] +__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', + 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler'] diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py new file mode 100644 index 000000000..458a11509 --- /dev/null +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from ._base_gradient_handler import BaseGradientHandler +from collections import defaultdict + + +@GRADIENT_HANDLER.register_module +class PipelineSharedModuleGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in sub parallel groups. + A all-reduce collective communication will be operated in + :func:`handle_gradient` among all sub pipeline parallel groups. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + """ + + def handle_gradient(self): + """A method running a all-reduce operation in sub pipeline parallel groups. + """ + if gpc.pipeline_parallel_size > 1: + # bucketize and all-reduce + buckets = defaultdict(lambda: defaultdict(list)) + # Pack the buckets. + for param in self._model.parameters(): + group = getattr(param, 'pipeline_shared_module_pg', None) + if param.requires_grad and param.grad is not None and group is not None: + tp = param.data.type() + buckets[group][tp].append(param) + + # For each bucket, all-reduce and copy all-reduced grads. 
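+            # Flattening all grads of one (process group, dtype) bucket into a
+            # single contiguous buffer means one collective per bucket instead
+            # of one all-reduce per shared parameter; the unflatten/copy step
+            # below writes the summed values back into the original grad tensors.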
+ for group, group_buckets in buckets.items(): + for tp, bucket in group_buckets.items(): + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index 411f1861b..76c550144 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -5,8 +5,7 @@ from abc import ABC, abstractmethod import torch -from torch import Tensor -from typing import Iterable, Union, List, Callable +from typing import Iterable, Callable from .._base_engine import Engine from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device @@ -32,18 +31,17 @@ class BaseSchedule(ABC): return element def _move_to_device(self, data): - if isinstance(data, (tuple, list)): - data = tuple([self._move_tensor(d) for d in data]) - elif torch.is_tensor(data): - data = data.to(get_current_device()).detach() + if isinstance(data, dict): + data = {k: self._move_tensor(v) for k, v in data.items()} + else: + data = self._move_tensor(data) return data - def _to_list(self, data): - if torch.is_tensor(data): - return [data] - return data + @staticmethod + def _check_sanity(data, tag): + assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict' - def load_batch(self, data_iter): + def load_batch(self, data_iter, to_gpu=True): """Loads a batch from data iterator. It returns the data and labels which are already in the same GPU as where the model's. @@ -58,13 +56,17 @@ class BaseSchedule(ABC): data, label = self.batch_data_process_func(batch_data) else: data, label = batch_data - - if isinstance(label, (tuple, list)): - self.batch_size = label[0].size(0) + self._check_sanity(data, 'data') + self._check_sanity(label, 'label') + if isinstance(data, torch.Tensor): + self.batch_size = data.size(0) else: - self.batch_size = label.size(0) - data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label)) - return self._move_to_device(data), self._move_to_device(label) + self.batch_size = next(iter(data.values())).size(0) + data, label = split_batch(data), split_batch(label) + if to_gpu: + return self._move_to_device(data), self._move_to_device(label) + return data, label + def pre_processing(self, engine: Engine): """To perform actions before running the schedule. @@ -76,7 +78,8 @@ class BaseSchedule(ABC): engine: Engine, data_iter: Iterable, forward_only: bool, - return_loss: bool = True + return_loss: bool = True, + return_output_label: bool = True ): """The process function over a batch of dataset for training or evaluation. 
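Note on the new batch handling: after this refactor a batch element may be either a single tensor or a dict of named tensors, and `_call_engine` (added in the following hunk) dispatches accordingly, calling `engine(tensor)` for tensors and `engine(**dict)` for dicts. A minimal sketch of what that dispatch amounts to, assuming a hypothetical two-input model and illustrative field names (nothing below is part of the patch itself):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # hypothetical model taking keyword inputs, matching the keys of a dict batch
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    def forward(self, input_ids, attention_mask):
        return self.proj(input_ids * attention_mask)

# a batch yielded by the dataloader: keyword dict plus a label tensor
data = {'input_ids': torch.randn(8, 16), 'attention_mask': torch.ones(8, 16)}
label = torch.randint(0, 4, (8,))

model = ToyModel()
# what BaseSchedule._call_engine effectively does for this batch
output = model(**data) if isinstance(data, dict) else model(data)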
@@ -85,5 +88,24 @@ class BaseSchedule(ABC): :param labels: ground truth :param forward_only: If True, the process won't include backward :param return_loss: If False, the loss won't be returned + :param return_output_label: If False, the output and label won't be returned """ - pass \ No newline at end of file + pass + + @staticmethod + def _call_engine(engine, inputs): + if isinstance(inputs, torch.Tensor): + return engine(inputs) + else: + return engine(**inputs) + + @staticmethod + def _call_engine_criterion(engine, outputs, labels): + assert isinstance(outputs, (torch.Tensor, list, tuple) + ), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + if isinstance(outputs, torch.Tensor): + outputs = (outputs, ) + if isinstance(labels, torch.Tensor): + return engine.criterion(*outputs, labels) + else: + return engine.criterion(*outputs, **labels) diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py index 0d8ee8c69..4ad07d6b4 100644 --- a/colossalai/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/engine/schedule/_non_pipeline_schedule.py @@ -5,9 +5,7 @@ from typing import Iterable import torch -import torch.nn as nn from colossalai.engine import Engine -from torch.optim import Optimizer from ._base_schedule import BaseSchedule from colossalai.utils import conditional_context @@ -27,18 +25,21 @@ class NonPipelineSchedule(BaseSchedule): engine: Engine, data_iter: Iterable, forward_only: bool = False, - return_loss: bool = True): + return_loss: bool = True, + return_output_label: bool = True): """The process function that loads loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. :param engine: Model for training and inference :param data_iter: Data iterator of the dataloader, e.g. 
iter(dataloader) :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed :param return_loss: Loss will be returned if True + :param return_output_label: Output and label will be returned if True :type engine: Iterator :type data_iter: Iterator :type forward_only: bool, optional :type return_loss: bool, optional - + :type return_output_label: bool, optional + :return: (output, label, loss) :rtype: Tuple[:class:`torch.Tensor`] """ @@ -48,16 +49,20 @@ class NonPipelineSchedule(BaseSchedule): # forward with conditional_context(torch.no_grad(), enable=forward_only): - output = engine(*data) - if not isinstance(output, (tuple, list)): - output = (output,) + output = self._call_engine(engine, data) if return_loss: - loss = engine.criterion(*output, *label) + loss = self._call_engine_criterion(engine, output, label) if not forward_only: engine.backward(loss) - if return_loss: - return output, label, loss + if return_output_label: + if return_loss: + return output, label, loss + else: + return output, label, None else: - return output, None, None + if return_loss: + return None, None, loss + else: + return None, None, None diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index ad1aba27b..e619e90f2 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -1,19 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Union - +from typing import List, Tuple, Union, Callable +import inspect import torch.cuda -import torch.distributed as dist from torch import Tensor from colossalai.communication import * from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.amp.naive_amp import NaiveAMPModel +from colossalai.utils.cuda import get_current_device from colossalai.zero import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) -from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank +from colossalai.utils import switch_virtual_pipeline_parallel_rank from ._base_schedule import BaseSchedule @@ -30,102 +30,79 @@ class PipelineSchedule(BaseSchedule): :class:`NonPipelineSchedule`. 
:param num_microbatches: The number of microbatches - :param amp_type: The type of automatic mixed precision - :param amp_config: The configuration of automatic mixed procision - :param sync_data: If set to `True`, will sync data every batch over pipeline stages :type num_microbatches: int - :type amp_type: AMP_TYPE - :type amp_config: dict - :type sync_data: bool + :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch` + :type batch_data_process_func: Callable """ def __init__(self, num_microbatches, - sync_data: bool = True): - super().__init__() - + batch_data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None): + super().__init__(batch_data_process_func=batch_data_process_func) self.num_microbatches = num_microbatches - self.sync_data = sync_data self.dtype = torch.float + self.tensor_shape = tensor_shape - def _move_to_device(self, data): - if isinstance(data, ( - tuple, - list, - )): - assert len(data) == 1, "Data tuple's length in pipeline should be 1" - data = data[0] - assert torch.is_tensor(data), "Data in pipeline should be tensor" - data = data.to(get_current_device()).detach() - return data - - def _sync_data(self): - reqs = [] - if gpc.is_first_rank(ParallelMode.PIPELINE): - src_rank = gpc.get_global_rank() - reqs.append(dist.broadcast( - tensor=self.batch_data, - src=src_rank, - group=gpc.get_group(ParallelMode.PIPELINE_PREV), - async_op=True - )) - reqs.append(dist.broadcast( - tensor=self.batch_label, - src=src_rank, - group=gpc.get_group(ParallelMode.PIPELINE_PREV), - async_op=True - )) - if gpc.is_last_rank(ParallelMode.PIPELINE): - src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - reqs.append(dist.broadcast( - tensor=self.batch_data, - src=src_rank, - group=gpc.get_group(ParallelMode.PIPELINE_NEXT), - async_op=True - )) - reqs.append(dist.broadcast( - tensor=self.batch_label, - src=src_rank, - group=gpc.get_group(ParallelMode.PIPELINE_NEXT), - async_op=True - )) - for req in reqs: - req.wait() - - # Pipeline schedule just puts data in memory def load_batch(self, data_iter): - if data_iter is None: - raise RuntimeError('Dataloader is not defined.') - self.batch_pos = 0 - data, label = next(data_iter) - self.batch_data, self.batch_label = \ - self._move_to_device(data), self._move_to_device(label) - batch_size = self.batch_data.shape[0] - assert batch_size % self.num_microbatches == 0, \ + # Pipeline schedule just puts data in memory + self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False) + self.microbatch_offset = 0 + assert self.batch_size % self.num_microbatches == 0, \ "Batch size should divided by the number of microbatches" - self.microbatch_size = batch_size // self.num_microbatches - if self.sync_data: - self._sync_data() + self.microbatch_size = self.batch_size // self.num_microbatches - def _get_data_slice(self, tensor): - return tensor[self.batch_pos: self.batch_pos + self.microbatch_size] + def _get_data_slice(self, data, offset): + if isinstance(data, torch.Tensor): + return data[offset: offset + self.microbatch_size] + else: + return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} def load_micro_batch(self): - data = self._get_data_slice(self.batch_data) - label = self._get_data_slice(self.batch_label) - self.batch_pos += self.microbatch_size - return (data,), (label,) + data = self._get_data_slice(self.batch_data, self.microbatch_offset) + label = 
self._get_data_slice(self.batch_label, self.microbatch_offset) + self.microbatch_offset += self.microbatch_size + return self._move_to_device(data), self._move_to_device(label) def pre_processing(self, engine): if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): raise TypeError( "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" ) - - if isinstance(engine.model, NaiveAMPModel): + model = engine.model + if isinstance(model, NaiveAMPModel): self.dtype = torch.half + model = model.model + sig = inspect.signature(model.forward) + for p in sig.parameters.values(): + assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' - def forward_step(self, engine, input_tensor, return_tensors, return_loss=True): + @staticmethod + def _call_engine(model, input_tensor, batch_data): + if isinstance(model, NaiveAMPModel): + sig = inspect.signature(model.model.forward) + else: + sig = inspect.signature(model.forward) + if isinstance(batch_data, torch.Tensor): + if input_tensor is None: + return model(batch_data) + elif len(sig.parameters) > 1: + return model(input_tensor, batch_data) + else: + return model(input_tensor) + else: + filter_batch = True + for p in sig.parameters.values(): + if p.kind == inspect.Parameter.VAR_KEYWORD: + filter_batch = False + if filter_batch: + batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters} + if input_tensor is None: + return model(**batch_data) + else: + return model(input_tensor, **batch_data) + + def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_tensor is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -140,26 +117,19 @@ class PipelineSchedule(BaseSchedule): :return: output or the loss value of the current pipeline stage :rtype: :class:`torch.Tensor` """ - - if input_tensor is None: - input_tensor, label = self.load_micro_batch() - input_tensor = squeeze(input_tensor) - output_tensor = engine(input_tensor) + data, label = self.load_micro_batch() + output_tensor = self._call_engine(engine.model, input_tensor, data) output_tensor = squeeze(output_tensor) if gpc.is_last_rank(ParallelMode.PIPELINE): - if return_loss: - input_tensor, label = self.load_micro_batch() - loss_reduced = engine.criterion(output_tensor, *label) \ - / self.num_microbatches - - return_tensors.append( - tuple((output_tensor, label[0], loss_reduced))) + if return_output_label: + return_tensors.append(tuple((output_tensor, label))) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) return loss_reduced else: - return_tensors.append(output_tensor) return output_tensor - else: return output_tensor @@ -203,7 +173,8 @@ class PipelineSchedule(BaseSchedule): engine, data_iter, forward_only=False, - return_loss=True): + return_loss=True, + return_output_label=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -215,6 +186,8 @@ class PipelineSchedule(BaseSchedule): :type forward_only: bool :param return_loss: whether returns the loss value. Default is true. 
:type return_loss: bool + :param return_output_label: If False, the output and label won't be returned + :type return_output_label: bool :return: (output, label, loss) :rtype: Tuple[:class:`torch.Tensor`] @@ -238,11 +211,14 @@ class PipelineSchedule(BaseSchedule): input_tensors = [] output_tensors = [] return_tensors = [] - + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None # Used for tensor meta information communication - ft_shape = None + ft_shape = self.tensor_shape bt_shape = None - fs_checker = True + fs_checker = self.tensor_shape is None # Run warmup forward passes. for i in range(num_warmup_microbatches): @@ -251,7 +227,8 @@ class PipelineSchedule(BaseSchedule): input_tensor = recv_forward(ft_shape, dtype=self.dtype) output_tensor = self.forward_step( engine, input_tensor, return_tensors, - return_loss=return_loss + return_output_label=return_output_label, + accum_loss=accum_loss ) if not gpc.is_last_rank(ParallelMode.PIPELINE): bt_shape = output_tensor.shape @@ -276,7 +253,8 @@ class PipelineSchedule(BaseSchedule): output_tensor = self.forward_step( engine, input_tensor, return_tensors, - return_loss=return_loss + return_output_label=return_output_label, + accum_loss=accum_loss ) if forward_only: send_forward(output_tensor) @@ -327,24 +305,37 @@ class PipelineSchedule(BaseSchedule): send_backward(input_tensor_grad) if len(return_tensors) > 0: - if return_loss: - output, label, loss = tuple(map(list, zip(*return_tensors))) - return (torch.cat(output, dim=0), - torch.cat(label, dim=0), - sum(loss)) - else: - return tuple((torch.cat(return_tensors, dim=0), None, None)) + output, label = tuple(map(list, zip(*return_tensors))) + return (torch.cat(output, dim=0), + torch.cat(label, dim=0), + accum_loss) else: - return tuple((None, None, None)) + return tuple((None, None, accum_loss)) class InterleavedPipelineSchedule(PipelineSchedule): - def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True): + def __init__(self, + num_microbatches, + num_model_chunks, + batch_data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None): + """A helper schedule class for pipeline parallelism running environment. + It uses interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. 
+ + :param num_microbatches: The number of microbatches + :type num_microbatches: int + :param num_model_chunks: The number of model chunks + :type num_model_chunks: int + :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch` + :type batch_data_process_func: Callable + """ assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ 'num_microbatches must be an integer multiple of pipeline parallel world size' - super().__init__(num_microbatches, sync_data=sync_data) + super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape) gpc.set_virtual_pipeline_parallel_size(num_model_chunks) gpc.set_virtual_pipeline_parallel_rank(0) + self.num_model_chunks = num_model_chunks def pre_processing(self, engine): if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): @@ -355,32 +346,46 @@ class InterleavedPipelineSchedule(PipelineSchedule): if isinstance(engine.model[0], NaiveAMPModel): self.dtype = torch.half - def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True): + for model in engine.model: + if isinstance(model, NaiveAMPModel): + model = model.model + sig = inspect.signature(model.forward) + for p in sig.parameters.values(): + assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + + def load_batch(self, data_iter): + super().load_batch(data_iter) + # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + + def load_micro_batch(self, model_chunk_id): + data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id]) + label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id]) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return self._move_to_device(data), self._move_to_device(label) + + def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_tensor is used. Returns output tensor. This is a helper function and can be ignored by users. 
""" - - if input_tensor is None: - input_tensor, label = self.load_micro_batch() - input_tensor = squeeze(input_tensor) - output_tensor = model(input_tensor) + data, label = self.load_micro_batch(model_chunk_id) + output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data) output_tensor = squeeze(output_tensor) if gpc.is_pipeline_last_stage(): - if return_loss: - input_tensor, label = self.load_micro_batch() - loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches - return_tensors.append( - tuple((output_tensor, label[0], loss_reduced))) + if return_output_label: + return_tensors.append(tuple(output_tensor, label)) + if accum_loss is not None: + loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches + accum_loss.add_(loss_reduced.detach()) return loss_reduced else: - return_tensors.append(output_tensor) return output_tensor else: return output_tensor - def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True): + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Run interleaved 1F1B schedule (model split into model chunks), with communication between pipeline stages as needed. @@ -394,11 +399,15 @@ class InterleavedPipelineSchedule(PipelineSchedule): return_tensors = [] if not forward_only: output_tensor_grads = [[] for _ in range(len(model))] + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None # Used for tensor meta information communication - input_tensor_shapes = [None for _ in range(len(model))] + input_tensor_shapes = [self.tensor_shape for _ in range(len(model))] output_tensor_shapes = [None for _ in range(len(model))] - send_tensor_shape_flags = [True for _ in range(len(model))] + send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -450,8 +459,8 @@ class InterleavedPipelineSchedule(PipelineSchedule): len(output_tensors[model_chunk_id]): input_tensors[model_chunk_id].append(None) input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = self.forward_step( - engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss) + output_tensor = self.forward_step(engine, model_chunk_id, input_tensor, + return_tensors, return_output_label=return_output_label, accum_loss=accum_loss) output_tensors[model_chunk_id].append(output_tensor) # if forward-only, no need to save tensors for a backward pass @@ -633,12 +642,9 @@ class InterleavedPipelineSchedule(PipelineSchedule): dtype=self.dtype)) if len(return_tensors) > 0: - if return_loss: - output, label, loss = tuple(map(list, zip(*return_tensors))) - return (torch.cat(output, dim=0), - torch.cat(label, dim=0), - sum(loss)) - else: - return tuple((torch.cat(return_tensors, dim=0), None, None)) + output, label = tuple(map(list, zip(*return_tensors))) + return (torch.cat(output, dim=0), + torch.cat(label, dim=0), + accum_loss) else: - return tuple((None, None, None)) + return tuple((None, None, accum_loss)) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 519094998..6a767338b 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -338,6 +338,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]], "Data parallel training is 
detected when using pipeline parallel, DataParallelGradientHandler is automatically " "added even though not specified in the configuration", ranks=[0]) + # add pipeline parallel gradient handler, if pipeline shared module is detected + for param in model.parameters(): + if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if gradient_handler_cfg is None: + gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + else: + gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + if verbose: + logger.info( + "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " + "added even though not specified in the configuration", + ranks=[0]) + break else: if not isinstance(gradient_handler_cfg, list): raise ConfigException( diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py index 8b996c860..0eb8e39e2 100644 --- a/colossalai/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -11,8 +11,8 @@ _parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': def split_batch(input_) -> Tensor: tensor_parallel_mode = get_tensor_parallel_mode() if tensor_parallel_mode in _parallel_split_batch: - if isinstance(input_, (tuple, list)): - return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_)) + if isinstance(input_, dict): + return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()} else: return _parallel_split_batch[tensor_parallel_mode](input_) else: diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/nn/layer/wrapper/__init__.py index a19f65dcc..01f746f65 100644 --- a/colossalai/nn/layer/wrapper/__init__.py +++ b/colossalai/nn/layer/wrapper/__init__.py @@ -1,3 +1,4 @@ from .lambda_wrapper import LambdaWrapper +from .pipeline_wrapper import PipelineSharedModuleWrapper -__all__ = ['LambdaWrapper'] +__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper'] diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py new file mode 100644 index 000000000..5bd4a83e3 --- /dev/null +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -0,0 +1,40 @@ +import torch.nn as nn +import torch.distributed as dist +from typing import List, Tuple, Union +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class PipelineSharedModuleWrapper: + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: + assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' + self.pipeline_ranks = pipeline_ranks + self.group = None + self.ranks_in_group = None + self._init_group() + + def _init_group(self): + world_size = gpc.get_world_size(ParallelMode.GLOBAL) + dp_size = gpc.get_world_size(ParallelMode.DATA) + pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + num_dp_groups = world_size // dp_size + num_pp_stages = num_dp_groups // pp_size + for i in range(dp_size): + for j in range(num_pp_stages): + pipeline_ranks = list( + range(i * num_dp_groups + j, + (i + 1) * num_dp_groups, + num_pp_stages)) + sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] + group = dist.new_group(sub_ranks) + if rank in sub_ranks: + self.group = group + self.ranks_in_group = sub_ranks + + def register_module(self, module: nn.Module): + assert self.ranks_in_group is not None, f'Rank 
{gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + src = self.ranks_in_group[self.pipeline_ranks[0]] + for p in module.parameters(): + setattr(p, 'pipeline_shared_module_pg', self.group) + dist.broadcast(p, src, group=self.group) diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index 5abd016cc..e3257a20e 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -155,7 +155,8 @@ class Trainer: def _train_epoch(self, train_dataloader: DataLoader, epoch: int = None, - display_progress: bool = False): + display_progress: bool = False, + return_output_label: bool = True): # set training state self._engine.train() data_iter = iter(train_dataloader) @@ -175,7 +176,7 @@ class Trainer: # run 1 training step self.engine.zero_grad() logits, label, loss = self.schedule.forward_backward_step( - self.engine, data_iter, forward_only=False, return_loss=True) + self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label) self.engine.step() self._call_timer(action='stop', item='Train-step', keep_in_history=True) self._call_hooks('after_train_iter', output=(logits, label, loss)) @@ -197,7 +198,8 @@ class Trainer: def _eval(self, test_dataloader: DataLoader, epoch: int = None, - display_progress: bool = False): + display_progress: bool = False, + return_output_label: bool = True): # switch engine status self._engine.eval() @@ -220,7 +222,7 @@ class Trainer: self._call_hooks('before_test_iter') self._call_timer(action='start', item='Test-step') logits, label, loss = self.schedule.forward_backward_step( - self.engine, data_iter, forward_only=True, return_loss=True) + self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label) self._call_timer(action='stop', item='Test-step', keep_in_history=True) self._call_hooks('after_test_iter', output=(logits, label, loss)) @@ -246,6 +248,7 @@ class Trainer: test_interval: int = 1, hooks: List[BaseHook] = None, display_progress: bool = False, + return_output_label: bool = True, ): """Trains the model to fit training data. @@ -256,6 +259,8 @@ class Trainer: :param test_interval: Interval of testing :param hooks_cfg: A list of hook configuration :param display_progress: If True, the training progress will be printed + :param return_output_label: If True, the output of model and the label will be returned + :type return_output_label: bool :type train_dataloader: DataLoader :type epochs: int :type max_steps: int @@ -307,7 +312,8 @@ class Trainer: self._train_epoch( train_dataloader=train_dataloader, epoch=epoch, - display_progress=display_progress + display_progress=display_progress, + return_output_label=return_output_label ) # start eval @@ -315,6 +321,7 @@ class Trainer: self._eval(test_dataloader=test_dataloader, display_progress=display_progress, epoch=epoch, + return_output_label=return_output_label ) self._cur_epoch += 1 @@ -331,13 +338,16 @@ class Trainer: def evaluate(self, test_dataloader: DataLoader, hooks: List[BaseHook] = None, - display_progress: bool = False): + display_progress: bool = False, + return_output_label: bool = True): """Evaluates the model with testing data. 
:param test_dataloader: DataLoader in testing :param display_progress: If True, the evaluation progress will be printed + :param return_output_label: If True, the output of model and the label will be returned :type test_dataloader: DataLoader :type display_progress: bool, optional + :type return_output_label: bool """ # set display display_progress = self._should_display_progress(display_progress) @@ -360,6 +370,7 @@ class Trainer: # eval self._eval(test_dataloader=test_dataloader, display_progress=display_progress, + return_output_label=return_output_label ) def predict(self, data: Union[Tensor, List[Tensor]]): @@ -383,4 +394,4 @@ class Trainer: data_iter = iter(simple_dataloader) output, _, _ = self.schedule.forward_backward_step( self.engine, data_iter, forward_only=True, return_loss=False) - return output \ No newline at end of file + return output diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 6e3318172..3d64a7b6f 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -155,22 +155,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if norm_type == inf: total_norm = max(p.grad.data.abs().max() for p in params) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - ops = [] # Take max across all model-parallel GPUs. - if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1: - ops.append(dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group( - ParallelMode.TENSOR), - async_op=True)) - if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - ops.append(dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group( - ParallelMode.PIPELINE), - async_op=True)) - for req in ops: - req.wait() + if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.MODEL), + async_op=False) total_norm = total_norm_cuda[0].item() else: tensor_parallel_grads = [] diff --git a/colossalai/utils/gradient_accumulation/_gradient_accumulation.py b/colossalai/utils/gradient_accumulation/_gradient_accumulation.py index 8c159c628..e5f3a5796 100644 --- a/colossalai/utils/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/utils/gradient_accumulation/_gradient_accumulation.py @@ -65,6 +65,7 @@ class GradAccumOptimizer(ColossalaiOptimizer): self.optim.backward(scaled_loss) def backward_by_grad(self, tensor: Tensor, grad: Tensor): + self.accumulate_step += 1 no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size if no_sync: @@ -81,7 +82,7 @@ class GradAccumDataloader(): be update only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle. Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader, (e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches. - + :param dataloader: your dataloader object :type dataloader: Iterable :param accumulate_size: the number of steps to accumulate gradients diff --git a/docs/add_your_parallel.md b/docs/add_your_parallel.md index 01b8afb67..6a8fe1ed7 100644 --- a/docs/add_your_parallel.md +++ b/docs/add_your_parallel.md @@ -26,8 +26,6 @@ follow the steps below to create a new distributed initialization. 
GLOBAL = 'global' DATA = 'data' PIPELINE = 'pipe' - PIPELINE_PREV = 'pipe_prev' - PIPELINE_NEXT = 'pipe_next' ... NEW_MODE = 'new_mode' # define your mode here diff --git a/docs/add_your_parallel_zh.md b/docs/add_your_parallel_zh.md index 5be00c3c7..b4625e465 100644 --- a/docs/add_your_parallel_zh.md +++ b/docs/add_your_parallel_zh.md @@ -18,8 +18,6 @@ class ParallelMode(Enum): GLOBAL = 'global' DATA = 'data' PIPELINE = 'pipe' - PIPELINE_PREV = 'pipe_prev' - PIPELINE_NEXT = 'pipe_next' ... NEW_MODE = 'new_mode' # define your mode here diff --git a/tests/test_context/test_2d_init.py b/tests/test_context/test_2d_init.py index 22826bf38..117b6e0d6 100644 --- a/tests/test_context/test_2d_init.py +++ b/tests/test_context/test_2d_init.py @@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank): assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1 +def check_model_parallel_rank(rank): + for i in range(8): + if rank in [i, i+8]: + assert gpc.get_local_rank(ParallelMode.MODEL) == i + + def check_tensor_parallel_rank(rank): if rank in [0, 4, 8, 12]: assert gpc.get_local_rank(ParallelMode.TENSOR) == 0 @@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host): check_data_parallel_rank(rank) check_2d_parallel_rank(rank) check_pipeline_parallel_rank(rank) + check_model_parallel_rank(rank) gpc.destroy() torch.cuda.empty_cache() diff --git a/tests/test_context/test_2p5d_init.py b/tests/test_context/test_2p5d_init.py index 3668c701e..ef6789710 100644 --- a/tests/test_context/test_2p5d_init.py +++ b/tests/test_context/test_2p5d_init.py @@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank): assert ppr == 1 +def check_model_parallel_rank(rank): + for i in range(16): + if rank in [i, i+16]: + assert gpc.get_local_rank(ParallelMode.MODEL) == i + + def check_tensor_parallel_rank(rank): tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) @@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host): check_pipeline_parallel_rank(rank) check_tensor_parallel_rank(rank) check_2p5d_parallel_rank(rank) + check_model_parallel_rank(rank) gpc.destroy() torch.cuda.empty_cache() diff --git a/tests/test_context/test_3d_init.py b/tests/test_context/test_3d_init.py index c9395f868..12f0f1ea5 100644 --- a/tests/test_context/test_3d_init.py +++ b/tests/test_context/test_3d_init.py @@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank): assert ppr == 1 +def check_model_parallel_rank(rank): + for i in range(16): + if rank in [i, i+16]: + assert gpc.get_local_rank(ParallelMode.MODEL) == i + + def check_tensor_parallel_rank(rank): tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) @@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host): check_3d_parallel_rank(rank) check_data_parallel_rank(rank) check_pipeline_parallel_rank(rank) + check_model_parallel_rank(rank) gpc.destroy() torch.cuda.empty_cache() diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index a472bf0ee..27d1a5e21 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -23,7 +23,7 @@ BATCH_SIZE = 16 NUM_EPOCHS = 60 WARMUP_EPOCHS = 5 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), - fp16=dict(mode=AMP_TYPE.TORCH), + fp16=dict(mode=AMP_TYPE.NAIVE), gradient_accumulation=2) diff --git a/tests/test_trainer/test_pipeline/test_p2p.py 
b/tests/test_trainer/test_pipeline/test_p2p.py index 283f49fa0..5258b42a5 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger): rank, check_equal(grad, output_grad))) -def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger): - dtype = torch.float32 - device = get_current_device() - tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - # recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - tensor = torch.randn(tensor_shape, dtype=dtype, device=device) - dist.all_reduce(tensor) - grad = torch.randn(grad_shape, dtype=dtype, device=device) - dist.all_reduce(grad) - if rank % 2 == 0: - need_meta = True - need_meta = send_tensor_meta(tensor, need_meta) - logger.info('Rank {} shape sent (need meta: {}).'.format( - rank, need_meta)) - req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True) - req.wait() - out = tensor.clone() - logger.info('Rank {} test op: tensor sent.'.format(rank)) - else: - recv_tensor_shape = recv_tensor_meta(None) - logger.info('Rank {} shape received. Correct shape: {}'.format( - rank, tensor_shape == recv_tensor_shape)) - out = torch.empty(recv_tensor_shape, dtype=dtype, device=device) - req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True) - req.wait() - logger.info('Rank {} test op: received tensor ({})'.format( - rank, out.shape)) - - logger.info('Rank {} test op. Correct tensor: {}'.format( - rank, check_equal(tensor, out))) - - -def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger): +def check_comm(size, rank, prev_rank, next_rank, logger): dtype = torch.float32 device = get_current_device() tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) @@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger): dist.all_reduce(tensor) grad = torch.randn(grad_shape, dtype=dtype, device=device) dist.all_reduce(grad) - check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger) check_forward(tensor, rank, logger) check_backward(grad, rank, logger) check_forward_backward(tensor, grad, rank, logger) @@ -135,18 +101,13 @@ def run_check(rank, world_size, port): logger = get_dist_logger() rank = gpc.get_global_rank() prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV) - up_group = gpc.get_group(ParallelMode.PIPELINE_PREV) next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT) - down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT) logger.info( - 'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format( - rank, prev_rank, up_ranks, next_rank, down_ranks)) + 'Rank {0}: prev rank {1}, next rank {2}'.format( + rank, prev_rank, next_rank)) logger.info('Distributed environment is initialzied.') - check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group, - logger) + check_comm(world_size, rank, prev_rank, next_rank, logger) gpc.destroy() torch.cuda.empty_cache()
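To close, a minimal usage sketch of the shared-module API added by this patch, assuming colossalai has already been launched with a pipeline of at least two stages and that the first stage's embedding and the last stage's output head share weights; the module sizes and names are illustrative, while PipelineSharedModuleWrapper, register_module and the automatic PipelineSharedModuleGradientHandler registration come from the changes above:

import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper

pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

# Construct the wrapper on every rank (process groups are created collectively),
# sharing parameters between the first and the last pipeline stage.
wrapper = PipelineSharedModuleWrapper([0, pp_size - 1])

if pp_rank == 0:
    embedding = nn.Embedding(1000, 512)               # illustrative sizes
    wrapper.register_module(embedding)                # tags params with pipeline_shared_module_pg and broadcasts them
elif pp_rank == pp_size - 1:
    head = nn.Linear(512, 1000, bias=False)           # weight shape matches the embedding table
    wrapper.register_module(head)

# colossalai.initialize() detects the tagged parameters and automatically appends
# PipelineSharedModuleGradientHandler, which bucketizes and all-reduces their
# gradients across the registered stages after each backward pass.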