Optimize pipeline schedule (#94)

* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Author: ver217, committed via GitHub on 2021-12-30 15:56:46 +08:00
parent e5b9f9a08d
commit 96780e6ee4
29 changed files with 423 additions and 290 deletions
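
The diffs below rework the pipeline schedules so that each stage slices its own micro-batches locally and can be told the communicated tensor shape up front. As a rough illustration of the resulting API (not part of the commit; the sizes and the import path are assumptions), a schedule with a fixed activation shape might be built like this:

# Hedged sketch of the reworked schedule API introduced in this commit.
from colossalai.engine.schedule import PipelineSchedule  # assumed import path

MICRO_BATCHES = 4
BATCH_SIZE, SEQ_LEN, HIDDEN = 16, 128, 1024  # placeholder sizes

# Passing tensor_shape lets stages skip the runtime tensor-meta exchange (#93);
# leaving it as None keeps the behaviour of communicating shapes on the fly.
schedule = PipelineSchedule(num_microbatches=MICRO_BATCHES,
                            tensor_shape=(BATCH_SIZE // MICRO_BATCHES, SEQ_LEN, HIDDEN))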

View File

@@ -359,12 +359,7 @@ class FP16Optimizer(Optimizer):
         # Update across all model parallel instances.
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.TENSOR))
-        if is_using_pp():
-            torch.distributed.all_reduce(self.found_inf,
-                                         op=torch.distributed.ReduceOp.MAX,
-                                         group=gpc.get_group(ParallelMode.PIPELINE))
+                                     group=gpc.get_group(ParallelMode.MODEL))
 
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)

View File

@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
 from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
 class _MultiDeviceReplicator(object):
@@ -247,10 +248,14 @@ class GradScaler(object):
                                  device),
                              per_device_inv_scale.get(device))
         # For tensor parallel paramters it should be all-reduced over tensor parallel process group
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for tensor in per_device_found_inf._per_device_tensors.values():
-                dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
-                                group=gpc.get_group(ParallelMode.TENSOR))
+        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
+            vals = [val for val in per_device_found_inf._per_device_tensors.values()]
+            coalesced = _flatten_dense_tensors(vals)
+            dist.all_reduce(coalesced,
+                            op=dist.ReduceOp.MAX,
+                            group=gpc.get_group(ParallelMode.MODEL))
+            for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
+                buf.copy_(synced)
         return per_device_found_inf._per_device_tensors
 
     def unscale_(self, optimizer):

View File

@@ -112,7 +112,7 @@ def _binary_search(weights, num):
     return intervals
 
 
-def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
+def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     assert num_items % num_chunks == 0, \
         "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
@@ -134,11 +134,11 @@ def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     return parts
 
 
-def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
+def partition_balanced(weights, pipeline_parallel_size, num_chunks):
     num_total = pipeline_parallel_size * num_chunks
     num_items = len(weights)
     if num_items <= num_total:
-        return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
+        return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
 
     intervals = _binary_search(weights, num_total)
@@ -151,7 +151,7 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts
 
 
-def _count_layer_params(layers):
+def count_layer_params(layers):
     """Count the number of parameters in each layer
     """
     param_counts = [0] * len(layers)
@@ -201,11 +201,11 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
     # Make a partition
     if method == 'layer':
         num_layers = len(layers)
-        parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
+        parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
     elif method == 'parameter':
-        param_counts = _count_layer_params(layers)
+        param_counts = count_layer_params(layers)
         # print_rank_0(param_counts)
-        parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
+        parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
     else:
         raise ValueError("Method should be a pre-set string in [layer, parameter]")
@@ -250,7 +250,7 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     """
     pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
+    partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
         module_list.append(nn.Sequential(*layers[start:end]))

View File

@@ -14,7 +14,8 @@ INITIALIZER_MAPPING = {
     '2d': 'Initializer_2D',
     '2.5d': 'Initializer_2p5D',
     '3d': 'Initializer_3D',
-    'sequence': 'Initializer_Sequence'
+    'sequence': 'Initializer_Sequence',
+    'model': 'Initializer_Model'
 }
 
 # 1D parallel

View File

@@ -394,6 +394,9 @@ class ParallelContext:
             # LSG: init data parallel process group for compatibility with other parallel module such as zero
             pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
 
+            # LSG: init model parallel process group for compatibility with amp and clip grad
+            pg_init.append(dict(type=INITIALIZER_MAPPING['model']))
+
             if self.pipeline_parallel_size > 1:
                 pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
 
             pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))

View File

@@ -14,10 +14,12 @@ class ParallelMode(Enum):
     # common parallel
     DATA = 'data'
 
+    # model parallel - containing tensor and pipeline parallel groups
+    # this is added to facilitate amp and grad clipping in hybrid parallel
+    MODEL = 'model'
+
     # pipeline parallel
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'
 
     # containing all ranks in tensor parallel
     TENSOR = 'tensor'

View File

@@ -6,10 +6,11 @@ from .initializer_data import Initializer_Data
 from .initializer_pipeline import Initializer_Pipeline
 from .initializer_sequence import Initializer_Sequence
 from .initializer_tensor import Initializer_Tensor
+from .initializer_model import Initializer_Model
 from .process_group_initializer import ProcessGroupInitializer
 
 __all__ = [
     'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
     'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
-    'Initializer_1D', 'ProcessGroupInitializer'
+    'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
 ]

View File

@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch.distributed as dist
+
+from colossalai.context import Config
+from colossalai.registry import DIST_GROUP_INITIALIZER
+from .process_group_initializer import ProcessGroupInitializer
+from ..parallel_mode import ParallelMode
+
+
+@DIST_GROUP_INITIALIZER.register_module
+class Initializer_Model(ProcessGroupInitializer):
+    '''A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel groups).
+    '''
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
+        self.num_group = self.world_size // self.model_parallel_size
+
+    def init_dist_group(self):
+        '''Initialize model parallel groups, and assign local_ranks and groups to each gpu.
+
+        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        :rtype: tuple
+        '''
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        group_world_size = None
+        mode = ParallelMode.MODEL
+
+        for i in range(self.num_group):
+            ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
+            group = dist.new_group(ranks)
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, ranks_in_group, mode
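
For orientation (not from the commit), the rank arithmetic above groups every block of model_parallel_size consecutive global ranks into one MODEL group; a tiny worked example:

# Worked example of the grouping done in init_dist_group above:
# tensor_parallel_size = 2, pipeline_parallel_size = 2 on 8 GPUs.
model_parallel_size = 2 * 2
world_size = 8
groups = [[i * model_parallel_size + j for j in range(model_parallel_size)]
          for i in range(world_size // model_parallel_size)]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]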

View File

@@ -36,28 +36,4 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                           process_group, ranks_in_group,
                           ParallelMode.PIPELINE)))
 
-                for k in range(pipe_group_size):
-                    first = pipe_ranks[k]
-                    second = pipe_ranks[(k + 1) % pipe_group_size]
-                    ranks = [first, second]
-                    group = dist.new_group(ranks)
-                    if self.rank == first:
-                        local_rank = 0
-                        group_world_size = 2
-                        process_group = group
-                        ranks_in_group = ranks
-                        dist_settings.append(
-                            tuple((local_rank, group_world_size,
-                                   process_group, ranks_in_group,
-                                   ParallelMode.PIPELINE_NEXT)))
-                    elif self.rank == second:
-                        local_rank = 1
-                        group_world_size = 2
-                        process_group = group
-                        ranks_in_group = ranks
-                        dist_settings.append(
-                            tuple((local_rank, group_world_size,
-                                   process_group, ranks_in_group,
-                                   ParallelMode.PIPELINE_PREV)))
-
         return dist_settings

View File

@@ -2,15 +2,12 @@
 # -*- encoding: utf-8 -*-
 
-import torch
 from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
 
-from colossalai.builder import build_gradient_handler
 from colossalai.logging import get_dist_logger
-from colossalai.utils import is_using_ddp, is_using_pp
 from torch import Tensor

View File

@@ -1,5 +1,7 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
+from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
 
-__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']
+__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
+           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']

View File

@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+
+import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from colossalai.core import global_context as gpc
+from colossalai.registry import GRADIENT_HANDLER
+from ._base_gradient_handler import BaseGradientHandler
+from collections import defaultdict
+
+
+@GRADIENT_HANDLER.register_module
+class PipelineSharedModuleGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in sub parallel groups.
+    A all-reduce collective communication will be operated in
+    :func:`handle_gradient` among all sub pipeline parallel groups.
+    For better performance, it bucketizes the gradients of all parameters that are
+    the same type to improve the efficiency of communication.
+    """
+
+    def handle_gradient(self):
+        """A method running a all-reduce operation in sub pipeline parallel groups.
+        """
+        if gpc.pipeline_parallel_size > 1:
+            # bucketize and all-reduce
+            buckets = defaultdict(lambda: defaultdict(list))
+            # Pack the buckets.
+            for param in self._model.parameters():
+                group = getattr(param, 'pipeline_shared_module_pg', None)
+                if param.requires_grad and param.grad is not None and group is not None:
+                    tp = param.data.type()
+                    buckets[group][tp].append(param)
+
+            # For each bucket, all-reduce and copy all-reduced grads.
+            for group, group_buckets in buckets.items():
+                for tp, bucket in group_buckets.items():
+                    grads = [param.grad.data for param in bucket]
+                    coalesced = _flatten_dense_tensors(grads)
+                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
+                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                        buf.copy_(synced)
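
The handler is normally added automatically by initialize() when shared parameters are detected (see the initialize diff further below), but it can also be requested explicitly in the training config. A minimal config sketch, assumed rather than taken from the commit:

# Hedged config sketch (assumption, not from the commit): requesting the new
# handler explicitly instead of relying on auto-detection in initialize().
gradient_handler = [dict(type='PipelineSharedModuleGradientHandler')]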

View File

@@ -5,8 +5,7 @@ from abc import ABC, abstractmethod
 
 import torch
 
-from torch import Tensor
-from typing import Iterable, Union, List, Callable
+from typing import Iterable, Callable
 
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
@@ -32,18 +31,17 @@ class BaseSchedule(ABC):
         return element
 
     def _move_to_device(self, data):
-        if isinstance(data, (tuple, list)):
-            data = tuple([self._move_tensor(d) for d in data])
-        elif torch.is_tensor(data):
-            data = data.to(get_current_device()).detach()
+        if isinstance(data, dict):
+            data = {k: self._move_tensor(v) for k, v in data.items()}
+        else:
+            data = self._move_tensor(data)
         return data
 
-    def _to_list(self, data):
-        if torch.is_tensor(data):
-            return [data]
-        return data
+    @staticmethod
+    def _check_sanity(data, tag):
+        assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
 
-    def load_batch(self, data_iter):
+    def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
         already in the same GPU as where the model's.
@@ -58,13 +56,17 @@ class BaseSchedule(ABC):
             data, label = self.batch_data_process_func(batch_data)
         else:
             data, label = batch_data
+        self._check_sanity(data, 'data')
+        self._check_sanity(label, 'label')
-        if isinstance(label, (tuple, list)):
-            self.batch_size = label[0].size(0)
+        if isinstance(data, torch.Tensor):
+            self.batch_size = data.size(0)
         else:
-            self.batch_size = label.size(0)
-        data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
-        return self._move_to_device(data), self._move_to_device(label)
+            self.batch_size = next(iter(data.values())).size(0)
+        data, label = split_batch(data), split_batch(label)
+        if to_gpu:
+            return self._move_to_device(data), self._move_to_device(label)
+        return data, label
 
     def pre_processing(self, engine: Engine):
         """To perform actions before running the schedule.
@@ -76,7 +78,8 @@ class BaseSchedule(ABC):
                               engine: Engine,
                               data_iter: Iterable,
                               forward_only: bool,
-                              return_loss: bool = True
+                              return_loss: bool = True,
+                              return_output_label: bool = True
                               ):
         """The process function over a batch of dataset for training or evaluation.
@@ -85,5 +88,24 @@ class BaseSchedule(ABC):
         :param labels: ground truth
         :param forward_only: If True, the process won't include backward
         :param return_loss: If False, the loss won't be returned
+        :param return_output_label: If False, the output and label won't be returned
         """
         pass
+
+    @staticmethod
+    def _call_engine(engine, inputs):
+        if isinstance(inputs, torch.Tensor):
+            return engine(inputs)
+        else:
+            return engine(**inputs)
+
+    @staticmethod
+    def _call_engine_criterion(engine, outputs, labels):
+        assert isinstance(outputs, (torch.Tensor, list, tuple)
+                          ), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs, )
+        if isinstance(labels, torch.Tensor):
+            return engine.criterion(*outputs, labels)
+        else:
+            return engine.criterion(*outputs, **labels)
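
To make the new dict-based batch handling above concrete, here is a hedged illustration (placeholder model and key names, not part of the commit) of the convention _call_engine relies on: tensor batches are forwarded positionally, while dict batches are unpacked as keyword arguments matching the model's forward signature.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 2)

    def forward(self, input_ids, attention_mask):
        # placeholder computation; a real model would embed input_ids, etc.
        return self.proj(attention_mask)

model = ToyModel()
batch = {'input_ids': torch.zeros(16, 128, dtype=torch.long),
         'attention_mask': torch.ones(16, 128)}
# BaseSchedule._call_engine(engine, batch) dispatches a dict as engine(**batch)
# and a plain tensor as engine(tensor); this call mirrors the dict branch.
output = model(**batch)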

View File

@@ -5,9 +5,7 @@ from typing import Iterable
 
 import torch
-import torch.nn as nn
 
 from colossalai.engine import Engine
-from torch.optim import Optimizer
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
@@ -27,17 +25,20 @@ class NonPipelineSchedule(BaseSchedule):
                               engine: Engine,
                               data_iter: Iterable,
                               forward_only: bool = False,
-                              return_loss: bool = True):
+                              return_loss: bool = True,
+                              return_output_label: bool = True):
         """The process function that loads loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
 
         :param engine: Model for training and inference
         :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
         :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
         :param return_loss: Loss will be returned if True
+        :param return_output_label: Output and label will be returned if True
         :type engine: Iterator
         :type data_iter: Iterator
         :type forward_only: bool, optional
         :type return_loss: bool, optional
+        :type return_output_label: bool, optional
 
         :return: (output, label, loss)
         :rtype: Tuple[:class:`torch.Tensor`]
@@ -48,16 +49,20 @@ class NonPipelineSchedule(BaseSchedule):
 
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
-            output = engine(*data)
-            if not isinstance(output, (tuple, list)):
-                output = (output,)
+            output = self._call_engine(engine, data)
             if return_loss:
-                loss = engine.criterion(*output, *label)
+                loss = self._call_engine_criterion(engine, output, label)
 
         if not forward_only:
             engine.backward(loss)
 
+        if return_output_label:
             if return_loss:
                 return output, label, loss
             else:
-                return output, None, None
+                return output, label, None
+        else:
+            if return_loss:
+                return None, None, loss
+            else:
+                return None, None, None

View File

@@ -1,19 +1,19 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Union
+from typing import List, Tuple, Union, Callable
+import inspect
 
 import torch.cuda
-import torch.distributed as dist
 from torch import Tensor
 from colossalai.communication import *
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 
 from ._base_schedule import BaseSchedule
@@ -30,102 +30,79 @@ class PipelineSchedule(BaseSchedule):
     :class:`NonPipelineSchedule`.
 
     :param num_microbatches: The number of microbatches
-    :param amp_type: The type of automatic mixed precision
-    :param amp_config: The configuration of automatic mixed procision
-    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
     :type num_microbatches: int
-    :type amp_type: AMP_TYPE
-    :type amp_config: dict
-    :type sync_data: bool
+    :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
+    :type batch_data_process_func: Callable
     """
 
     def __init__(self,
                  num_microbatches,
-                 sync_data: bool = True):
-        super().__init__()
+                 batch_data_process_func: Callable = None,
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+        super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
-        self.sync_data = sync_data
         self.dtype = torch.float
-
-    def _move_to_device(self, data):
-        if isinstance(data, (
-                tuple,
-                list,
-        )):
-            assert len(data) == 1, "Data tuple's length in pipeline should be 1"
-            data = data[0]
-        assert torch.is_tensor(data), "Data in pipeline should be tensor"
-        data = data.to(get_current_device()).detach()
-        return data
-
-    def _sync_data(self):
-        reqs = []
-        if gpc.is_first_rank(ParallelMode.PIPELINE):
-            src_rank = gpc.get_global_rank()
-            reqs.append(dist.broadcast(
-                tensor=self.batch_data,
-                src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
-                async_op=True
-            ))
-            reqs.append(dist.broadcast(
-                tensor=self.batch_label,
-                src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
-                async_op=True
-            ))
-        if gpc.is_last_rank(ParallelMode.PIPELINE):
-            src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-            reqs.append(dist.broadcast(
-                tensor=self.batch_data,
-                src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
-                async_op=True
-            ))
-            reqs.append(dist.broadcast(
-                tensor=self.batch_label,
-                src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
-                async_op=True
-            ))
-        for req in reqs:
-            req.wait()
-
-    # Pipeline schedule just puts data in memory
+        self.tensor_shape = tensor_shape
+
     def load_batch(self, data_iter):
-        if data_iter is None:
-            raise RuntimeError('Dataloader is not defined.')
-        self.batch_pos = 0
-        data, label = next(data_iter)
-        self.batch_data, self.batch_label = \
-            self._move_to_device(data), self._move_to_device(label)
-        batch_size = self.batch_data.shape[0]
-        assert batch_size % self.num_microbatches == 0, \
+        # Pipeline schedule just puts data in memory
+        self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
+        self.microbatch_offset = 0
+        assert self.batch_size % self.num_microbatches == 0, \
             "Batch size should divided by the number of microbatches"
-        self.microbatch_size = batch_size // self.num_microbatches
-        if self.sync_data:
-            self._sync_data()
+        self.microbatch_size = self.batch_size // self.num_microbatches
 
-    def _get_data_slice(self, tensor):
-        return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
+    def _get_data_slice(self, data, offset):
+        if isinstance(data, torch.Tensor):
+            return data[offset: offset + self.microbatch_size]
+        else:
+            return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
 
     def load_micro_batch(self):
-        data = self._get_data_slice(self.batch_data)
-        label = self._get_data_slice(self.batch_label)
-        self.batch_pos += self.microbatch_size
-        return (data,), (label,)
+        data = self._get_data_slice(self.batch_data, self.microbatch_offset)
+        label = self._get_data_slice(self.batch_label, self.microbatch_offset)
+        self.microbatch_offset += self.microbatch_size
+        return self._move_to_device(data), self._move_to_device(label)
 
     def pre_processing(self, engine):
         if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )
-
-        if isinstance(engine.model, NaiveAMPModel):
+        model = engine.model
+        if isinstance(model, NaiveAMPModel):
             self.dtype = torch.half
+            model = model.model
+        sig = inspect.signature(model.forward)
+        for p in sig.parameters.values():
+            assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
 
-    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
+    @staticmethod
+    def _call_engine(model, input_tensor, batch_data):
+        if isinstance(model, NaiveAMPModel):
+            sig = inspect.signature(model.model.forward)
+        else:
+            sig = inspect.signature(model.forward)
+        if isinstance(batch_data, torch.Tensor):
+            if input_tensor is None:
+                return model(batch_data)
+            elif len(sig.parameters) > 1:
+                return model(input_tensor, batch_data)
+            else:
+                return model(input_tensor)
+        else:
+            filter_batch = True
+            for p in sig.parameters.values():
+                if p.kind == inspect.Parameter.VAR_KEYWORD:
+                    filter_batch = False
+            if filter_batch:
+                batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
+            if input_tensor is None:
+                return model(**batch_data)
+            else:
+                return model(input_tensor, **batch_data)
+
+    def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -140,26 +117,19 @@ class PipelineSchedule(BaseSchedule):
         :return: output or the loss value of the current pipeline stage
         :rtype: :class:`torch.Tensor`
         """
-        if input_tensor is None:
-            input_tensor, label = self.load_micro_batch()
-            input_tensor = squeeze(input_tensor)
-        output_tensor = engine(input_tensor)
+        data, label = self.load_micro_batch()
+        output_tensor = self._call_engine(engine.model, input_tensor, data)
         output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
-            if return_loss:
-                input_tensor, label = self.load_micro_batch()
-                loss_reduced = engine.criterion(output_tensor, *label) \
-                    / self.num_microbatches
-
-                return_tensors.append(
-                    tuple((output_tensor, label[0], loss_reduced)))
+            if return_output_label:
+                return_tensors.append(tuple((output_tensor, label)))
+            if accum_loss is not None:
+                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
+                accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
-                return_tensors.append(output_tensor)
                 return output_tensor
         else:
             return output_tensor
@@ -203,7 +173,8 @@ class PipelineSchedule(BaseSchedule):
                               engine,
                               data_iter,
                               forward_only=False,
-                              return_loss=True):
+                              return_loss=True,
+                              return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -215,6 +186,8 @@ class PipelineSchedule(BaseSchedule):
         :type forward_only: bool
         :param return_loss: whether returns the loss value. Default is true.
         :type return_loss: bool
+        :param return_output_label: If False, the output and label won't be returned
+        :type return_output_label: bool
 
         :return: (output, label, loss)
         :rtype: Tuple[:class:`torch.Tensor`]
@@ -238,11 +211,14 @@ class PipelineSchedule(BaseSchedule):
         input_tensors = []
         output_tensors = []
         return_tensors = []
+        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
+            accum_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_loss = None
         # Used for tensor meta information communication
-        ft_shape = None
+        ft_shape = self.tensor_shape
         bt_shape = None
-        fs_checker = True
+        fs_checker = self.tensor_shape is None
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
@@ -251,7 +227,8 @@ class PipelineSchedule(BaseSchedule):
             input_tensor = recv_forward(ft_shape, dtype=self.dtype)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
-                return_loss=return_loss
+                return_output_label=return_output_label,
+                accum_loss=accum_loss
             )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
@@ -276,7 +253,8 @@ class PipelineSchedule(BaseSchedule):
 
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
-                return_loss=return_loss
+                return_output_label=return_output_label,
+                accum_loss=accum_loss
             )
             if forward_only:
                 send_forward(output_tensor)
@@ -327,24 +305,37 @@ class PipelineSchedule(BaseSchedule):
                 send_backward(input_tensor_grad)
 
         if len(return_tensors) > 0:
-            if return_loss:
-                output, label, loss = tuple(map(list, zip(*return_tensors)))
-                return (torch.cat(output, dim=0),
-                        torch.cat(label, dim=0),
-                        sum(loss))
-            else:
-                return tuple((torch.cat(return_tensors, dim=0), None, None))
+            output, label = tuple(map(list, zip(*return_tensors)))
+            return (torch.cat(output, dim=0),
+                    torch.cat(label, dim=0),
+                    accum_loss)
         else:
-            return tuple((None, None, None))
+            return tuple((None, None, accum_loss))
 
 
 class InterleavedPipelineSchedule(PipelineSchedule):
-    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
+    def __init__(self,
+                 num_microbatches,
+                 num_model_chunks,
+                 batch_data_process_func: Callable = None,
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+        """A helper schedule class for pipeline parallelism running environment.
+        It uses interleaved 1F1B strategy. Other properties are similar as
+        :class:`NonPipelineSchedule`.
+
+        :param num_microbatches: The number of microbatches
+        :type num_microbatches: int
+        :param num_model_chunks: The number of model chunks
+        :type num_model_chunks: int
+        :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
+        :type batch_data_process_func: Callable
+        """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, sync_data=sync_data)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
+        self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
         if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
@@ -355,32 +346,46 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
+        for model in engine.model:
+            if isinstance(model, NaiveAMPModel):
+                model = model.model
+            sig = inspect.signature(model.forward)
+            for p in sig.parameters.values():
+                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
+
+    def load_batch(self, data_iter):
+        super().load_batch(data_iter)
+        # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset
+        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
+
+    def load_micro_batch(self, model_chunk_id):
+        data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
+        label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
+        self.microbatch_offset[model_chunk_id] += self.microbatch_size
+        return self._move_to_device(data), self._move_to_device(label)
 
-    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
+    def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
         """
-        if input_tensor is None:
-            input_tensor, label = self.load_micro_batch()
-            input_tensor = squeeze(input_tensor)
-        output_tensor = model(input_tensor)
+        data, label = self.load_micro_batch(model_chunk_id)
+        output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
         output_tensor = squeeze(output_tensor)
 
         if gpc.is_pipeline_last_stage():
-            if return_loss:
-                input_tensor, label = self.load_micro_batch()
-                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
-                return_tensors.append(
-                    tuple((output_tensor, label[0], loss_reduced)))
+            if return_output_label:
+                return_tensors.append(tuple(output_tensor, label))
+            if accum_loss is not None:
+                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
+                accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
-                return_tensors.append(output_tensor)
                 return output_tensor
         else:
             return output_tensor
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.
@@ -394,11 +399,15 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         return_tensors = []
         if not forward_only:
             output_tensor_grads = [[] for _ in range(len(model))]
+        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
+            accum_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_loss = None
 
         # Used for tensor meta information communication
-        input_tensor_shapes = [None for _ in range(len(model))]
+        input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
         output_tensor_shapes = [None for _ in range(len(model))]
-        send_tensor_shape_flags = [True for _ in range(len(model))]
+        send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]
 
         pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
         pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -450,8 +459,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
             input_tensor = input_tensors[model_chunk_id][-1]
-            output_tensor = self.forward_step(
-                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
+            output_tensor = self.forward_step(engine, model_chunk_id, input_tensor,
+                                              return_tensors, return_output_label=return_output_label, accum_loss=accum_loss)
             output_tensors[model_chunk_id].append(output_tensor)
 
             # if forward-only, no need to save tensors for a backward pass
@@ -633,12 +642,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                                   dtype=self.dtype))
 
         if len(return_tensors) > 0:
-            if return_loss:
-                output, label, loss = tuple(map(list, zip(*return_tensors)))
-                return (torch.cat(output, dim=0),
-                        torch.cat(label, dim=0),
-                        sum(loss))
-            else:
-                return tuple((torch.cat(return_tensors, dim=0), None, None))
+            output, label = tuple(map(list, zip(*return_tensors)))
+            return (torch.cat(output, dim=0),
+                    torch.cat(label, dim=0),
+                    accum_loss)
         else:
-            return tuple((None, None, None))
+            return tuple((None, None, accum_loss))

View File

@@ -338,6 +338,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                         "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
                         "added even though not specified in the configuration",
                         ranks=[0])
+        # add pipeline parallel gradient handler, if pipeline shared module is detected
+        for param in model.parameters():
+            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
+                if gradient_handler_cfg is None:
+                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
+                else:
+                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
+                if verbose:
+                    logger.info(
+                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
+                        "added even though not specified in the configuration",
+                        ranks=[0])
+                break
     else:
         if not isinstance(gradient_handler_cfg, list):
             raise ConfigException(

View File

@@ -11,8 +11,8 @@ _parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d':
 def split_batch(input_) -> Tensor:
     tensor_parallel_mode = get_tensor_parallel_mode()
     if tensor_parallel_mode in _parallel_split_batch:
-        if isinstance(input_, (tuple, list)):
-            return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
+        if isinstance(input_, dict):
+            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
         else:
             return _parallel_split_batch[tensor_parallel_mode](input_)
     else:

View File

@@ -1,3 +1,4 @@
 from .lambda_wrapper import LambdaWrapper
+from .pipeline_wrapper import PipelineSharedModuleWrapper
 
-__all__ = ['LambdaWrapper']
+__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']

View File

@@ -0,0 +1,40 @@
+import torch.nn as nn
+import torch.distributed as dist
+
+from typing import List, Tuple, Union
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+
+
+class PipelineSharedModuleWrapper:
+    def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
+        assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
+        self.pipeline_ranks = pipeline_ranks
+        self.group = None
+        self.ranks_in_group = None
+        self._init_group()
+
+    def _init_group(self):
+        world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+        dp_size = gpc.get_world_size(ParallelMode.DATA)
+        pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
+        rank = gpc.get_global_rank()
+        num_dp_groups = world_size // dp_size
+        num_pp_stages = num_dp_groups // pp_size
+        for i in range(dp_size):
+            for j in range(num_pp_stages):
+                pipeline_ranks = list(
+                    range(i * num_dp_groups + j,
+                          (i + 1) * num_dp_groups,
+                          num_pp_stages))
+                sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
+                group = dist.new_group(sub_ranks)
+                if rank in sub_ranks:
+                    self.group = group
+                    self.ranks_in_group = sub_ranks
+
+    def register_module(self, module: nn.Module):
+        assert self.ranks_in_group is not None, \
+            f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        src = self.ranks_in_group[self.pipeline_ranks[0]]
+        for p in module.parameters():
+            setattr(p, 'pipeline_shared_module_pg', self.group)
+            dist.broadcast(p, src, group=self.group)
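
A hedged usage sketch of the wrapper (not part of the commit; the import path, module, and stage indices are assumptions): tying a module that exists on both the first and last pipeline stages, so that the gradient handler above keeps the copies in sync.

import torch.nn as nn
from colossalai.builder import PipelineSharedModuleWrapper  # assumed import path

# Share an embedding between pipeline stages 0 and 3 of a 4-stage pipeline
# (a common weight-tying pattern). 50304 / 1024 are placeholder sizes.
wrapper = PipelineSharedModuleWrapper(pipeline_ranks=[0, 3])
embed = nn.Embedding(50304, 1024)
# register_module() tags every parameter with `pipeline_shared_module_pg` and
# broadcasts the weights from the first listed stage, so the new
# PipelineSharedModuleGradientHandler will all-reduce their gradients.
wrapper.register_module(embed)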

View File

@@ -155,7 +155,8 @@ class Trainer:
     def _train_epoch(self,
                      train_dataloader: DataLoader,
                      epoch: int = None,
-                     display_progress: bool = False):
+                     display_progress: bool = False,
+                     return_output_label: bool = True):
         # set training state
         self._engine.train()
         data_iter = iter(train_dataloader)
@@ -175,7 +176,7 @@ class Trainer:
             # run 1 training step
             self.engine.zero_grad()
             logits, label, loss = self.schedule.forward_backward_step(
-                self.engine, data_iter, forward_only=False, return_loss=True)
+                self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label)
             self.engine.step()
             self._call_timer(action='stop', item='Train-step', keep_in_history=True)
             self._call_hooks('after_train_iter', output=(logits, label, loss))
@@ -197,7 +198,8 @@ class Trainer:
     def _eval(self,
               test_dataloader: DataLoader,
               epoch: int = None,
-              display_progress: bool = False):
+              display_progress: bool = False,
+              return_output_label: bool = True):
         # switch engine status
         self._engine.eval()
@@ -220,7 +222,7 @@ class Trainer:
                 self._call_hooks('before_test_iter')
                 self._call_timer(action='start', item='Test-step')
                 logits, label, loss = self.schedule.forward_backward_step(
-                    self.engine, data_iter, forward_only=True, return_loss=True)
+                    self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label)
                 self._call_timer(action='stop', item='Test-step', keep_in_history=True)
                 self._call_hooks('after_test_iter',
                                  output=(logits, label, loss))
@@ -246,6 +248,7 @@ class Trainer:
             test_interval: int = 1,
             hooks: List[BaseHook] = None,
             display_progress: bool = False,
+            return_output_label: bool = True,
             ):
         """Trains the model to fit training data.
@@ -256,6 +259,8 @@ class Trainer:
         :param test_interval: Interval of testing
         :param hooks_cfg: A list of hook configuration
         :param display_progress: If True, the training progress will be printed
+        :param return_output_label: If True, the output of model and the label will be returned
+        :type return_output_label: bool
         :type train_dataloader: DataLoader
         :type epochs: int
         :type max_steps: int
@@ -307,7 +312,8 @@ class Trainer:
             self._train_epoch(
                 train_dataloader=train_dataloader,
                 epoch=epoch,
-                display_progress=display_progress
+                display_progress=display_progress,
+                return_output_label=return_output_label
             )
 
             # start eval
@@ -315,6 +321,7 @@ class Trainer:
                 self._eval(test_dataloader=test_dataloader,
                            display_progress=display_progress,
                            epoch=epoch,
+                           return_output_label=return_output_label
                            )
 
         self._cur_epoch += 1
@@ -331,13 +338,16 @@ class Trainer:
     def evaluate(self,
                  test_dataloader: DataLoader,
                  hooks: List[BaseHook] = None,
-                 display_progress: bool = False):
+                 display_progress: bool = False,
+                 return_output_label: bool = True):
         """Evaluates the model with testing data.
 
         :param test_dataloader: DataLoader in testing
         :param display_progress: If True, the evaluation progress will be printed
+        :param return_output_label: If True, the output of model and the label will be returned
         :type test_dataloader: DataLoader
         :type display_progress: bool, optional
+        :type return_output_label: bool
         """
         # set display
         display_progress = self._should_display_progress(display_progress)
@@ -360,6 +370,7 @@ class Trainer:
         # eval
         self._eval(test_dataloader=test_dataloader,
                    display_progress=display_progress,
+                   return_output_label=return_output_label
                    )
 
     def predict(self, data: Union[Tensor, List[Tensor]]):

View File

@@ -155,22 +155,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if norm_type == inf:
         total_norm = max(p.grad.data.abs().max() for p in params)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        ops = []
         # Take max across all model-parallel GPUs.
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            ops.append(dist.all_reduce(total_norm_cuda,
-                                       op=dist.ReduceOp.MAX,
-                                       group=gpc.get_group(
-                                           ParallelMode.TENSOR),
-                                       async_op=True))
-        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            ops.append(dist.all_reduce(total_norm_cuda,
-                                       op=dist.ReduceOp.MAX,
-                                       group=gpc.get_group(
-                                           ParallelMode.PIPELINE),
-                                       async_op=True))
-        for req in ops:
-            req.wait()
+        if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
+            dist.all_reduce(total_norm_cuda,
+                            op=dist.ReduceOp.MAX,
+                            group=gpc.get_group(ParallelMode.MODEL),
+                            async_op=False)
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []

View File

@@ -65,6 +65,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
             self.optim.backward(scaled_loss)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        self.accumulate_step += 1
         no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
 
         if no_sync:

View File

@@ -26,8 +26,6 @@ follow the steps below to create a new distributed initialization.
     GLOBAL = 'global'
     DATA = 'data'
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'
     ...
     NEW_MODE = 'new_mode'  # define your mode here

View File

@@ -18,8 +18,6 @@ class ParallelMode(Enum):
     GLOBAL = 'global'
     DATA = 'data'
     PIPELINE = 'pipe'
-    PIPELINE_PREV = 'pipe_prev'
-    PIPELINE_NEXT = 'pipe_next'
     ...
     NEW_MODE = 'new_mode'  # define your mode here

View File

@@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank):
         assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
 
 
+def check_model_parallel_rank(rank):
+    for i in range(8):
+        if rank in [i, i+8]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     if rank in [0, 4, 8, 12]:
         assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
@@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host):
     check_data_parallel_rank(rank)
     check_2d_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()

View File

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1
 
 
+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host):
     check_pipeline_parallel_rank(rank)
     check_tensor_parallel_rank(rank)
     check_2p5d_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()

View File

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1
 
 
+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host):
     check_3d_parallel_rank(rank)
     check_data_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()

View File

@@ -23,7 +23,7 @@ BATCH_SIZE = 16
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
 
 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.TORCH),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
               gradient_accumulation=2)

View File

@@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
         rank, check_equal(grad, output_grad)))
 
 
-def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
-    dtype = torch.float32
-    device = get_current_device()
-    tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    # recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
-    dist.all_reduce(tensor)
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-    dist.all_reduce(grad)
-    if rank % 2 == 0:
-        need_meta = True
-        need_meta = send_tensor_meta(tensor, need_meta)
-        logger.info('Rank {} shape sent (need meta: {}).'.format(
-            rank, need_meta))
-        req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
-        req.wait()
-        out = tensor.clone()
-        logger.info('Rank {} test op: tensor sent.'.format(rank))
-    else:
-        recv_tensor_shape = recv_tensor_meta(None)
-        logger.info('Rank {} shape received. Correct shape: {}'.format(
-            rank, tensor_shape == recv_tensor_shape))
-        out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
-        req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
-        req.wait()
-        logger.info('Rank {} test op: received tensor ({})'.format(
-            rank, out.shape))
-        logger.info('Rank {} test op. Correct tensor: {}'.format(
-            rank, check_equal(tensor, out)))
-
-
-def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
+def check_comm(size, rank, prev_rank, next_rank, logger):
     dtype = torch.float32
     device = get_current_device()
     tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
@@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
     dist.all_reduce(tensor)
     grad = torch.randn(grad_shape, dtype=dtype, device=device)
     dist.all_reduce(grad)
-    check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
     check_forward(tensor, rank, logger)
     check_backward(grad, rank, logger)
     check_forward_backward(tensor, grad, rank, logger)
@@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
     logger = get_dist_logger()
     rank = gpc.get_global_rank()
     prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-    up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
-    up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
     next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-    down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
-    down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
     logger.info(
-        'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
-            rank, prev_rank, up_ranks, next_rank, down_ranks))
+        'Rank {0}: prev rank {1}, next rank {2}'.format(
            rank, prev_rank, next_rank))
     logger.info('Distributed environment is initialzied.')
 
-    check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
-               logger)
+    check_comm(world_size, rank, prev_rank, next_rank, logger)
     gpc.destroy()
     torch.cuda.empty_cache()