[hotfix] Raise messages for indivisible batch sizes with tensor parallelism (#622)

pull/657/head
Liang Bowen 2022-04-02 16:12:04 +08:00 committed by GitHub
parent e0f875a8e2
commit 828e465622
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 66 additions and 29 deletions

View File

@ -1,12 +1,12 @@
import torch.nn as nn
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
def partition_batch(input_) -> Tensor:

View File

@ -1,8 +1,8 @@
from ._operation import reduce_by_batch_2d, split_tensor_2d
from ._operation import reduce_by_batch_2d, split_batch_2d
from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
VocabParallelEmbedding2D)
__all__ = [
'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D'
]

View File

@ -720,7 +720,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
return _AllGatherTensor2D.apply(tensor, dim, parallel_mode)
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
"""Splits 2D tensor in specified dimension across cols.
Args:
@ -730,6 +730,11 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
Returns:
:class:`torch.tensor`: The tensor has been split.
"""
dim_size = input_.size(dim)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
@ -784,6 +789,11 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)

View File

@ -19,7 +19,7 @@ from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
reduce_scatter_tensor_2d, split_tensor_2d)
reduce_scatter_tensor_2d, split_batch_2d)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
@ -547,7 +547,7 @@ class PatchEmbedding2D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2d(input_)
input_ = split_batch_2d(input_)
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
@ -692,7 +692,7 @@ class Embedding2D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2d(input_)
input_ = split_batch_2d(input_)
weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

View File

@ -1,8 +1,8 @@
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
__all__ = [
'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
]

View File

@ -755,7 +755,7 @@ class SplitFirst(torch.autograd.Function):
return grad, None, None
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
"""Splits 2P5D tensor in specified dimension across cols.
Args:
@ -765,6 +765,11 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
Returns:
:class:`torch.tensor`: The tensor has been split.
"""
dim_size = input_.size(dim)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
@ -819,6 +824,11 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
dim_size = input_.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)

View File

@ -20,7 +20,7 @@ from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
@ -568,7 +568,7 @@ class PatchEmbedding2p5D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2p5d(input_, 0)
input_ = split_batch_2p5d(input_, 0)
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
@ -713,7 +713,7 @@ class Embedding2p5D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2p5d(input_, 0)
input_ = split_batch_2p5d(input_, 0)
weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL)

View File

@ -276,6 +276,11 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
if tensor.size(dim) <= 1:
return tensor
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
@ -302,13 +307,20 @@ def split_batch_3d(input_: Tensor,
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
if input_.size(dim) <= 1:
return input_
dim_size = input_.size(dim)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode)
assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size,
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
output = torch.chunk(output, input_world_size,
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
@ -394,6 +406,11 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).'
return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)

View File

@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
@ -48,7 +48,7 @@ class CrossEntropyLoss2D(_Loss):
Returns:
float: the loss between logits and targets.
"""
targets = split_tensor_2d(targets)
targets = split_batch_2d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.mean()
@ -145,7 +145,7 @@ class VocabParallelCrossEntropyLoss2D(_Loss):
logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
"""
targets = split_tensor_2d(targets)
targets = split_batch_2d(targets)
loss = _VocabParallelCrossEntropy2D.apply(
logits,
targets,

View File

@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
@ -44,7 +44,7 @@ class CrossEntropyLoss2p5D(_Loss):
logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
"""
targets = split_tensor_2p5d(targets)
targets = split_batch_2p5d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.mean()
@ -138,7 +138,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
"""
targets = split_tensor_2p5d(targets)
targets = split_batch_2p5d(targets)
loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets)
if self.reduction_mean:
loss = loss.mean()

View File

@ -1,5 +1,5 @@
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from torch import nn
from ._utils import calc_acc
@ -22,7 +22,7 @@ class Accuracy2D(nn.Module):
float: the accuracy of prediction.
"""
with torch.no_grad():
targets = split_tensor_2d(targets)
targets = split_batch_2d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2d(correct)
return correct

View File

@ -1,5 +1,5 @@
import torch
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from torch import nn
from ._utils import calc_acc
@ -22,7 +22,7 @@ class Accuracy2p5D(nn.Module):
float: the accuracy of prediction.
"""
with torch.no_grad():
targets = split_tensor_2p5d(targets)
targets = split_batch_2p5d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2p5d(correct)
return correct