mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Raise messages for indivisible batch sizes with tensor parallelism (#622)
parent e0f875a8e2
commit 828e465622
@@ -1,12 +1,12 @@
 import torch.nn as nn
 from torch import Tensor
 
-from ..parallel_2d._operation import split_tensor_2d
-from ..parallel_2p5d._operation import split_tensor_2p5d
+from ..parallel_2d._operation import split_batch_2d
+from ..parallel_2p5d._operation import split_batch_2p5d
 from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode
 
-_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
+_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
 
 
 def partition_batch(input_) -> Tensor:
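The dictionary above maps the tensor-parallel mode string to the matching batch-split helper. A minimal sketch of that dispatch pattern, with hypothetical no-op stand-ins for the real split functions (which need an initialized ColossalAI parallel context); the real partition_batch reads the mode via get_tensor_parallel_mode(), here it is a parameter purely for illustration:

    import torch

    def _identity_split(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # stand-in for split_batch_2d / split_batch_2p5d / split_batch_3d
        return x

    # tensor-parallel mode string -> batch-split function
    _parallel_split_batch = {'2d': _identity_split, '2.5d': _identity_split, '3d': _identity_split}

    def partition_batch(input_: torch.Tensor, mode: str = '2d') -> torch.Tensor:
        # modes without a batch-split rule (e.g. 1D or no tensor parallelism) pass through
        if mode in _parallel_split_batch:
            return _parallel_split_batch[mode](input_)
        return input_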
@@ -1,8 +1,8 @@
-from ._operation import reduce_by_batch_2d, split_tensor_2d
+from ._operation import reduce_by_batch_2d, split_batch_2d
 from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
                      VocabParallelEmbedding2D)
 
 __all__ = [
-    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
+    'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
     'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D'
 ]
@@ -720,7 +720,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
     return _AllGatherTensor2D.apply(tensor, dim, parallel_mode)
 
 
-def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
     """Splits 2D tensor in specified dimension across cols.
 
     Args:
@@ -730,6 +730,11 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
     Returns:
         :class:`torch.tensor`: The tensor has been split.
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
+
     if input_.size(dim) <= 1:
         return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
@@ -784,6 +789,11 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
+
     return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
 
 
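The new guard in split_batch_2d asserts divisibility before chunking, so an indivisible batch now fails fast with a clear message instead of producing uneven shards. A standalone, single-process sketch of the same pattern, with plain integers standing in for gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) and the local rank:

    import torch

    def split_batch_2d_sketch(input_: torch.Tensor, world_size: int, rank: int, dim: int = 0) -> torch.Tensor:
        # world_size / rank stand in for the 2D column process-group queries
        dim_size = input_.size(dim)
        assert dim_size % world_size == 0, \
            f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
        if dim_size <= 1:
            return input_
        return torch.chunk(input_, world_size, dim=dim)[rank].contiguous()

    x = torch.randn(8, 16)
    print(split_batch_2d_sketch(x, world_size=4, rank=0).shape)  # torch.Size([2, 16])
    # A batch of 6 on 4 column ranks now raises AssertionError with the message above:
    # split_batch_2d_sketch(torch.randn(6, 16), world_size=4, rank=0)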
@@ -19,7 +19,7 @@ from torch.nn import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_tensor_2d)
+                         reduce_scatter_tensor_2d, split_batch_2d)
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
 
 
@@ -547,7 +547,7 @@ class PatchEmbedding2D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2d(input_)
+        input_ = split_batch_2d(input_)
 
         B, C, H, W = input_.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
@@ -692,7 +692,7 @@ class Embedding2D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2d(input_)
+        input_ = split_batch_2d(input_)
 
         weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -1,8 +1,8 @@
-from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
+from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
 from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
                      VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
 
 __all__ = [
-    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
+    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
     'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
 ]
@@ -755,7 +755,7 @@ class SplitFirst(torch.autograd.Function):
         return grad, None, None
 
 
-def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     """Splits 2P5D tensor in specified dimension across cols.
 
     Args:
@@ -765,6 +765,11 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     Returns:
         :class:`torch.tensor`: The tensor has been split.
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
+
     if input_.size(dim) <= 1:
         return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
@@ -819,6 +824,11 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = input_.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
+
     return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)
 
 
@@ -20,7 +20,7 @@ from torch.nn import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
+                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
 from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
 
 
@@ -568,7 +568,7 @@ class PatchEmbedding2p5D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2p5d(input_, 0)
+        input_ = split_batch_2p5d(input_, 0)
 
         B, C, H, W = input_.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
@@ -713,7 +713,7 @@ class Embedding2p5D(ParallelLayer):
         destination.update(local_state)
 
     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_tensor_2p5d(input_, 0)
+        input_ = split_batch_2p5d(input_, 0)
 
         weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL)
 
@@ -276,6 +276,11 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
+        f'cannot split tensor evenly'
     if tensor.size(dim) <= 1:
         return tensor
     output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
@@ -302,13 +307,20 @@ def split_batch_3d(input_: Tensor,
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
-    if input_.size(dim) <= 1:
-        return input_
+    dim_size = input_.size(dim)
+
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
-    output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
+    weight_world_size = gpc.get_world_size(weight_parallel_mode)
+    input_world_size = gpc.get_world_size(input_parallel_mode)
+
+    assert dim_size % (input_world_size*weight_world_size) == 0, \
+        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
+
+    if input_.size(dim) <= 1:
+        return input_
+    output = torch.chunk(input_, weight_world_size,
                          dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
-    output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
+    output = torch.chunk(output, input_world_size,
                          dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
     return output
 
@@ -394,6 +406,11 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
+    dim_size = tensor.size(dim)
+    world_size = gpc.get_world_size(parallel_mode)
+    assert dim_size % world_size == 0, \
+        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).'
+
     return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
 
 
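split_batch_3d chunks the batch twice, first across the weight group and then across the input group, so the batch size must divide input_world_size * weight_world_size; the new assertion surfaces that requirement up front. A hedged, single-process sketch with plain integers standing in for the ColossalAI group sizes and local ranks:

    import torch

    def split_batch_3d_sketch(input_: torch.Tensor,
                              weight_world_size: int, weight_rank: int,
                              input_world_size: int, input_rank: int,
                              dim: int = 0) -> torch.Tensor:
        # ranks/sizes stand in for gpc.get_local_rank / gpc.get_world_size on the 3D groups
        dim_size = input_.size(dim)
        assert dim_size % (input_world_size * weight_world_size) == 0, \
            f'The batch size ({dim_size}) is not a multiple of square of 3D depth ' \
            f'({input_world_size * weight_world_size}).'
        if dim_size <= 1:
            return input_
        # first split across the weight group, then across the input group
        output = torch.chunk(input_, weight_world_size, dim=dim)[weight_rank].contiguous()
        output = torch.chunk(output, input_world_size, dim=dim)[input_rank].contiguous()
        return output

    x = torch.randn(8, 4)
    print(split_batch_3d_sketch(x, 2, 0, 2, 1).shape)  # torch.Size([2, 4])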
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
@@ -48,7 +48,7 @@ class CrossEntropyLoss2D(_Loss):
         Returns:
             float: the loss between logits and targets.
         """
-        targets = split_tensor_2d(targets)
+        targets = split_batch_2d(targets)
         loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
             loss = loss.mean()
@@ -145,7 +145,7 @@ class VocabParallelCrossEntropyLoss2D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2d(targets)
+        targets = split_batch_2d(targets)
         loss = _VocabParallelCrossEntropy2D.apply(
             logits,
             targets,
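The loss classes switch to split_batch_2d so that targets are sharded along the batch dimension exactly like the logits each rank already holds; otherwise the per-sample losses would be misaligned. A small single-process illustration of that alignment, where torch.chunk stands in for the 2D column split:

    import torch
    import torch.nn.functional as F

    world_size, rank = 2, 1                     # stand-ins for the 2D column group
    logits_full = torch.randn(8, 10)            # full logits: batch 8, 10 classes
    targets_full = torch.randint(0, 10, (8,))   # full targets

    # each rank holds a batch shard of the logits; targets must be sharded the same way
    logits_local = torch.chunk(logits_full, world_size, dim=0)[rank]
    targets_local = torch.chunk(targets_full, world_size, dim=0)[rank]

    loss_local = F.cross_entropy(logits_local, targets_local, reduction='none')
    print(loss_local.shape)  # torch.Size([4]) -- per-sample losses for this rank's shard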
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
@@ -44,7 +44,7 @@ class CrossEntropyLoss2p5D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2p5d(targets)
+        targets = split_batch_2p5d(targets)
         loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
             loss = loss.mean()
@@ -138,7 +138,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        targets = split_tensor_2p5d(targets)
+        targets = split_batch_2p5d(targets)
         loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets)
         if self.reduction_mean:
             loss = loss.mean()
@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from torch import nn
 
 from ._utils import calc_acc
@@ -22,7 +22,7 @@ class Accuracy2D(nn.Module):
             float: the accuracy of prediction.
         """
         with torch.no_grad():
-            targets = split_tensor_2d(targets)
+            targets = split_batch_2d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2d(correct)
         return correct
@@ -1,5 +1,5 @@
 import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from torch import nn
 
 from ._utils import calc_acc
@@ -22,7 +22,7 @@ class Accuracy2p5D(nn.Module):
             float: the accuracy of prediction.
         """
         with torch.no_grad():
-            targets = split_tensor_2p5d(targets)
+            targets = split_batch_2p5d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2p5d(correct)
         return correct