mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Raise messages for indivisible batch sizes with tensor parallelism (#622)
parent e0f875a8e2
commit 828e465622
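This commit renames the batch-splitting helpers (`split_tensor_2d` → `split_batch_2d`, `split_tensor_2p5d` → `split_batch_2p5d`) and makes each splitter assert that the batch dimension is divisible by the tensor-parallel world size. For context, here is a minimal standalone sketch (plain PyTorch, not part of the diff) of the failure mode the new assertions guard against: `torch.chunk` does not raise on an indivisible size, it silently returns uneven chunks, so ranks would see different local batch sizes.

```python
import torch

batch = torch.arange(10)            # batch size 10
chunks = torch.chunk(batch, 4)      # a hypothetical 4-way parallel split
print([c.size(0) for c in chunks])  # [3, 3, 3, 1] -- uneven, but no error
```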
```diff
@@ -1,12 +1,12 @@
 import torch.nn as nn
 from torch import Tensor
 
-from ..parallel_2d._operation import split_tensor_2d
-from ..parallel_2p5d._operation import split_tensor_2p5d
+from ..parallel_2d._operation import split_batch_2d
+from ..parallel_2p5d._operation import split_batch_2p5d
 from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode
 
-_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
+_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
 
 
 def partition_batch(input_) -> Tensor:
```
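The dispatch table above maps the active tensor-parallel mode to its mode-specific splitter. A hedged sketch of how `partition_batch` presumably uses it — the function body is elided in the diff, so this is an assumption, not the repository's actual code:

```python
# Hypothetical sketch of the elided partition_batch body; the real
# implementation may differ (e.g. it may also handle dict inputs).
def partition_batch(input_):
    mode = get_tensor_parallel_mode()                # e.g. '2d', '2.5d', '3d'
    if mode in _parallel_split_batch:
        return _parallel_split_batch[mode](input_)   # mode-specific batch split
    return input_                                    # other modes: no batch split
```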
```diff
@@ -1,8 +1,8 @@
-from ._operation import reduce_by_batch_2d, split_tensor_2d
+from ._operation import reduce_by_batch_2d, split_batch_2d
 from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
                      VocabParallelEmbedding2D)
 
 __all__ = [
-    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
+    'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
     'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D'
 ]
```
```diff
@@ -720,7 +720,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
     return _AllGatherTensor2D.apply(tensor, dim, parallel_mode)
 
 
-def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
     """Splits 2D tensor in specified dimension across cols.
 
     Args:
```
```diff
@@ -730,6 +730,11 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
     Returns:
         :class:`torch.tensor`: The tensor has been split.
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
+
     if input_.size(dim) <= 1:
         return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
```
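With the added check, an indivisible batch now fails fast with a readable message instead of being silently split unevenly. A standalone reproduction of the new assertion (hypothetical world size of 4, no colossalai runtime needed):

```python
import torch

input_ = torch.randn(10, 8)   # batch size 10
world_size = 4                # hypothetical PARALLEL_2D_COL world size

try:
    dim_size = input_.size(0)
    assert dim_size % world_size == 0, \
        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
except AssertionError as e:
    print(e)  # The batch size (10) is not a multiple of 2D size (4).
```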
```diff
@@ -784,6 +789,11 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
     The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
     in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
+
     return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
 
 
```
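The same divisibility requirement applies to reduce-scatter: every rank must receive an equally sized shard of the reduced tensor. A pure-PyTorch analogue (simulated ranks, no process group) of the reduce-then-scatter shape constraint:

```python
import torch

world_size = 4
rank_tensors = [torch.randn(8, 3) for _ in range(world_size)]  # 8 % 4 == 0

reduced = torch.stack(rank_tensors).sum(dim=0)     # the "reduce" step
shards = torch.chunk(reduced, world_size, dim=0)   # the "scatter" step
print([s.shape for s in shards])                   # four equal (2, 3) shards
```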
```diff
@@ -19,7 +19,7 @@ from torch.nn import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_tensor_2d)
+                         reduce_scatter_tensor_2d, split_batch_2d)
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
 
 
```
```diff
@@ -547,7 +547,7 @@ class PatchEmbedding2D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2d(input_)
+        input_ = split_batch_2d(input_)
 
         B, C, H, W = input_.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
```
```diff
@@ -692,7 +692,7 @@ class Embedding2D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2d(input_)
+        input_ = split_batch_2d(input_)
 
         weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
```
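In `Embedding2D.forward`, the batch is split while the column-sharded embedding table is gathered back to full width before the lookup. A pure-PyTorch analogue of that gather-then-lookup step (simulated shards standing in for the collective):

```python
import torch
import torch.nn.functional as F

full_weight = torch.randn(100, 16)             # vocab 100, embed dim 16
shards = torch.chunk(full_weight, 4, dim=-1)   # column-parallel weight shards
gathered = torch.cat(shards, dim=-1)           # stands in for all_gather_tensor_2d

tokens = torch.randint(0, 100, (2, 5))
assert torch.equal(F.embedding(tokens, gathered),
                   F.embedding(tokens, full_weight))  # lookup matches full table
```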
```diff
@@ -1,8 +1,8 @@
-from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
+from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
 from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
                      VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
 
 __all__ = [
-    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
+    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
     'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
 ]
```
```diff
@@ -755,7 +755,7 @@ class SplitFirst(torch.autograd.Function):
         return grad, None, None
 
 
-def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     """Splits 2P5D tensor in specified dimension across cols.
 
     Args:
```
```diff
@@ -765,6 +765,11 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     Returns:
         :class:`torch.tensor`: The tensor has been split.
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
+
     if input_.size(dim) <= 1:
         return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
```
```diff
@@ -819,6 +824,11 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode)
     The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
     in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
+
     return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)
 
 
```
```diff
@@ -20,7 +20,7 @@ from torch.nn import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
+                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
 from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
 
 
```
```diff
@@ -568,7 +568,7 @@ class PatchEmbedding2p5D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2p5d(input_, 0)
+        input_ = split_batch_2p5d(input_, 0)
 
         B, C, H, W = input_.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
```
```diff
@@ -713,7 +713,7 @@ class Embedding2p5D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2p5d(input_, 0)
+        input_ = split_batch_2p5d(input_, 0)
 
         weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL)
 
```
```diff
@@ -276,6 +276,11 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
     in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
+        f'cannot split tensor evenly'
     if tensor.size(dim) <= 1:
         return tensor
     output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
```
```diff
@@ -302,13 +307,20 @@ def split_batch_3d(input_: Tensor,
     The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
     in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
-    if input_.size(dim) <= 1:
-        return input_
+    dim_size = input_.size(dim)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
-    output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
+    weight_world_size = gpc.get_world_size(weight_parallel_mode)
+    input_world_size = gpc.get_world_size(input_parallel_mode)
+
+    assert dim_size % (input_world_size*weight_world_size) == 0, \
+        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
+
+    if input_.size(dim) <= 1:
+        return input_
+    output = torch.chunk(input_, weight_world_size,
                          dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
-    output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
+    output = torch.chunk(output, input_world_size,
                          dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
     return output
 
```
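`split_batch_3d` chunks the batch twice, once per weight group and once per input group, so the batch must be divisible by `input_world_size * weight_world_size` (the square of the 3D depth). A simulation over hypothetical ranks with depth 2:

```python
import torch

weight_world_size, input_world_size = 2, 2   # 3D depth 2 -> 4-way batch split
batch = torch.arange(8)                      # 8 % (2 * 2) == 0

for w_rank in range(weight_world_size):
    part = torch.chunk(batch, weight_world_size, dim=0)[w_rank]
    for i_rank in range(input_world_size):
        shard = torch.chunk(part, input_world_size, dim=0)[i_rank]
        print((w_rank, i_rank), shard.tolist())  # each rank pair gets 2 samples
```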
```diff
@@ -394,6 +406,11 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
     The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
     in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).'
+
     return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
 
 
```
```diff
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
```
```diff
@@ -48,7 +48,7 @@ class CrossEntropyLoss2D(_Loss):
         Returns:
             float: the loss between logits and targets.
         """
-        targets = split_tensor_2d(targets)
+        targets = split_batch_2d(targets)
         loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
             loss = loss.mean()
```
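Why the loss path needs equally sized shards: `CrossEntropyLoss2D` computes a per-shard mean that `reduce_by_batch_2d` later combines across ranks (assuming its mean-reduction path), and a mean of per-shard means only equals the full-batch mean when all shards are the same size. A quick check with simulated shards:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
full = F.cross_entropy(logits, targets)  # full-batch mean loss

# 4 equal shards (8 % 4 == 0): averaging per-shard means recovers `full`
shard_means = [F.cross_entropy(l, t)
               for l, t in zip(torch.chunk(logits, 4), torch.chunk(targets, 4))]
print(torch.isclose(full, torch.stack(shard_means).mean()))  # tensor(True)
```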
```diff
@@ -145,7 +145,7 @@ class VocabParallelCrossEntropyLoss2D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2d(targets)
+        targets = split_batch_2d(targets)
         loss = _VocabParallelCrossEntropy2D.apply(
             logits,
             targets,
```
```diff
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
```
```diff
@@ -44,7 +44,7 @@ class CrossEntropyLoss2p5D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2p5d(targets)
+        targets = split_batch_2p5d(targets)
         loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
             loss = loss.mean()
```
```diff
@@ -138,7 +138,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2p5d(targets)
+        targets = split_batch_2p5d(targets)
         loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets)
         if self.reduction_mean:
             loss = loss.mean()
```
```diff
@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from torch import nn
 
 from ._utils import calc_acc
```
```diff
@@ -22,7 +22,7 @@ class Accuracy2D(nn.Module):
             float: the accuracy of prediction.
         """
         with torch.no_grad():
-            targets = split_tensor_2d(targets)
+            targets = split_batch_2d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2d(correct)
         return correct
```
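Accuracy survives sharding by summation: per-shard correct counts summed across ranks (what `reduce_by_batch_2d` effectively does here) equal the full-batch count, again provided targets are split exactly like the logits. A simulated check:

```python
import torch

logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
full_correct = (logits.argmax(dim=-1) == targets).sum()

# sum of per-shard correct counts over 4 equal shards matches the full count
shard_correct = sum((l.argmax(dim=-1) == t).sum()
                    for l, t in zip(torch.chunk(logits, 4), torch.chunk(targets, 4)))
print(bool(full_correct == shard_correct))  # True
```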
```diff
@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from torch import nn
 
 from ._utils import calc_acc
```
```diff
@@ -22,7 +22,7 @@ class Accuracy2p5D(nn.Module):
             float: the accuracy of prediction.
         """
         with torch.no_grad():
-            targets = split_tensor_2p5d(targets)
+            targets = split_batch_2p5d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2p5d(correct)
         return correct
```