mirror of https://github.com/hpcaitech/ColossalAI
Update layer integration documentation (#108)

Update the documentation of layer integration
Update _log_hook.py
Update _operation.py

parent 3a61d785b5
commit 4a3d3446b0
@@ -9,6 +9,14 @@ from ..utils import get_tensor_parallel_mode


 class Dropout(nn.Module):
+    """
+    Dropout layer of colossalai
+
+    :param p: dropout rate, defaults to 0.5
+    :type p: float, optional
+    :param inplace: If set to ``True``, will do this operation in-place, defaults to ``False``
+    :type inplace: bool, optional
+    """
     def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
         super().__init__()
         self.tensor_parallel = get_tensor_parallel_mode()

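The docstring above mirrors ``torch.nn.Dropout``'s arguments; a minimal torch-only sketch of what ``p`` and ``inplace`` control (the colossalai wrapper additionally dispatches on the tensor parallel mode):

    import torch

    drop = torch.nn.Dropout(p=0.5, inplace=False)
    x = torch.randn(4, 8)
    y = drop(x)                      # roughly half the entries zeroed, rest scaled by 1/(1-p)
    drop.eval()
    assert torch.equal(drop(x), x)   # dropout is a no-op in eval mode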
@@ -24,6 +24,20 @@ _parallel_patchembedding = {


 class Embedding(nn.Module):
+    """
+    Embedding for colossalai
+
+    :param num_embeddings: number of embeddings
+    :type num_embeddings: int
+    :param embedding_dim: dimension of embedding
+    :type embedding_dim: int
+    :param padding_idx: index of padding, defaults to None
+    :type padding_idx: int, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to normal initializer
+    :type weight_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,

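A torch-only sketch of the documented arguments (the colossalai wrapper shards this across the tensor parallel group); ``padding_idx`` marks a row that stays zero and gets no gradient:

    import torch

    emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
    tokens = torch.tensor([[1, 2, 0, 3]])
    out = emb(tokens)                      # shape (1, 4, 4)
    assert torch.all(out[0, 2] == 0)       # the padding_idx row is all zeros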
@@ -63,6 +77,28 @@ class Embedding(nn.Module):


 class PatchEmbedding(nn.Module):
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
+    :type img_size: int
+    :param patch_size: patch size
+    :type patch_size: int
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param flatten: whether to flatten output tensor, defaults to True
+    :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  img_size: int,
                  patch_size: int,

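A sketch of the patch-embedding computation the docstring describes, assuming the usual Conv2d-with-stride formulation; shapes here are illustrative, not taken from the source:

    import torch

    img_size, patch_size, in_chans, embed_size = 224, 16, 3, 96
    proj = torch.nn.Conv2d(in_chans, embed_size, kernel_size=patch_size, stride=patch_size)
    x = torch.randn(2, in_chans, img_size, img_size)
    patches = proj(x)                                  # (2, 96, 14, 14)
    tokens = patches.flatten(2).transpose(1, 2)        # (2, 196, 96) when flatten=True
    assert tokens.shape == (2, (img_size // patch_size) ** 2, embed_size)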
@@ -25,6 +25,22 @@ _parallel_classifier = {


 class Linear(nn.Module):
+    """
+    Linear layer of colossalai
+
+    :param in_features: size of each input sample
+    :type in_features: int
+    :param out_features: size of each output sample
+    :type out_features: int
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  in_features: int,
                  out_features: int,

@@ -64,6 +80,22 @@ class Linear(nn.Module):


 class Classifier(nn.Module):
+    """
+    Classifier layer of colossalai
+
+    :param in_features: size of each input sample
+    :type in_features: int
+    :param num_classes: number of total classes for the dataset
+    :type num_classes: int
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(
         self,
         in_features: int,

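Functionally, the Classifier head documented above is a linear map from ``in_features`` to ``num_classes``; a torch-only sketch (the colossalai version picks a parallel implementation behind this interface):

    import torch

    in_features, num_classes = 768, 1000
    head = torch.nn.Linear(in_features, num_classes, bias=True)
    logits = head(torch.randn(8, in_features))
    assert logits.shape == (8, num_classes)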
@@ -15,6 +15,19 @@ _parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm


 class LayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    :param normalized_shape: input shape from an expected input
+        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :type normalized_shape: int
+    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
+    :type eps: float, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    """
     def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
         super().__init__()
         tensor_parallel = get_tensor_parallel_mode()

@@ -7,6 +7,18 @@ except:


 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
+    r"""
+    Layernorm
+
+    :param input: input matrix
+    :param weight: weight matrix
+    :param bias: bias matrix
+    :param normalized_shape: input shape from an expected input
+        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :param eps: a value added to the denominator for numerical stability
+    """

     @staticmethod
     def forward(ctx, input, weight, bias, normalized_shape, eps):

@@ -76,7 +76,12 @@ def _gather(input_, parallel_mode, dim=-1):


 class _ReduceGrad(torch.autograd.Function):
-    """Pass the input to the model parallel region."""
+    """
+    Pass the input to the model parallel region.
+
+    :param input_: input matrix
+    :param parallel_mode: parallel mode
+    """
     @staticmethod
     def symbolic(graph, input_):
         return input_

@@ -92,7 +97,12 @@ class _ReduceGrad(torch.autograd.Function):


 class _ReduceInput(torch.autograd.Function):
-    """All-reduce the input from the model parallel region."""
+    """
+    All-reduce the input from the model parallel region.
+
+    :param input_: input matrix
+    :param parallel_mode: parallel mode
+    """
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

@@ -107,7 +117,13 @@ class _ReduceInput(torch.autograd.Function):


 class _SplitForwardGatherBackward(torch.autograd.Function):
-    """Split the input and keep only the corresponding chuck to the rank."""
+    """
+    Split the input and keep only the corresponding chunk for the rank.
+
+    :param input_: input matrix
+    :param parallel_mode: parallel mode
+    :param dim: dimension
+    """
    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

@@ -124,7 +140,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function):


 class _GatherForwardSplitBackward(torch.autograd.Function):
-    """Gather the input from model parallel region and concatinate."""
+    """
+    Gather the input from the model parallel region and concatenate.
+
+    :param input_: input matrix
+    :param parallel_mode: parallel mode
+    :param dim: dimension
+    """
    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

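A single-process sketch of the split/gather semantics these autograd functions document: split keeps only this rank's chunk along ``dim``, gather concatenates the chunks back together (the real implementation does this with collective communication; ``world_size`` and ``rank`` below are illustrative):

    import torch

    world_size, rank, dim = 4, 1, -1
    x = torch.arange(16.).reshape(2, 8)
    local = torch.chunk(x, world_size, dim=dim)[rank]                    # "split forward"
    restored = torch.cat(torch.chunk(x, world_size, dim=dim), dim=dim)   # "gather"
    assert torch.equal(restored, x)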
@@ -26,6 +26,24 @@ from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_g


 @LAYERS.register_module
 class Linear1D(torch.nn.Module):
+    """
+    Linear layer for 1D parallelism
+
+    :param in_features: size of each input sample
+    :type in_features: int
+    :param out_features: size of each output sample
+    :type out_features: int
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
+    :type skip_bias_add: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  in_features: int,
                  out_features: int,

@@ -70,8 +88,24 @@ class Linear1D(torch.nn.Module):


 @LAYERS.register_module
 class Classifier1D(ParallelLayer):
-    """RowLinear with given weight"""
+    """RowLinear with given weight
+    Classifier of 1D parallelism
+
+    :param in_features: size of input features
+    :type in_features: int
+    :param num_classes: number of classes in the dataset
+    :type num_classes: int
+    :param weight: weight of the classifier, defaults to True
+    :type weight: torch.nn.Parameter, optional
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  in_features: int,
                  num_classes: int,

@@ -144,7 +178,7 @@ class Linear1D_Col(ParallelLayer):
     :type in_features: int
     :param output_size: second dimension of matrix A.
     :type output_size: int
-    :param bias: If true, add bias, defaults to True
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
     :type bias: bool, optional
     :param dtype: The dtype of parameters, defaults to None
     :type dtype: torch.dtype, optional

@@ -228,7 +262,7 @@ class Linear1D_Row(ParallelLayer):
     :type in_features: int
     :param out_features: size of each output sample
     :type out_features: int
-    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
     :type bias: bool, optional
     :param dtype: The dtype of parameters, defaults to None
     :type dtype: torch.dtype, optional

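A numerical sketch of why the column-parallel and row-parallel Linear layers documented above compose: splitting the first weight by output columns and the second by input rows, then summing the partial products, reproduces the serial result (ranks and sizes below are illustrative):

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 6)
    w1, w2 = torch.randn(8, 6), torch.randn(8, 6)            # nn.Linear-style (out, in) then (in, out) halves
    # column-parallel: each "rank" holds half of w1's output features
    y_parts = [x @ w1[i * 4:(i + 1) * 4].t() for i in range(2)]
    y = torch.cat(y_parts, dim=-1)                            # gather along the feature dim
    # row-parallel: each rank holds the matching half of w2's input rows
    z = sum(y_parts[i] @ w2[i * 4:(i + 1) * 4] for i in range(2))   # "all-reduce" of partial sums
    assert torch.allclose(y @ w2, z, atol=1e-5)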
@@ -303,7 +337,16 @@ class Linear1D_Row(ParallelLayer):


 @LAYERS.register_module
 class MixedFusedLayerNorm1D(torch.nn.Module):
-    """ Experimental
+    r"""
+    Layer Normalization for 1D parallelism
+
+    :param normalized_shape: input shape from an expected input
+        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :type normalized_shape: int
+    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
+    :type eps: float, optional
     """

     def __init__(self, normalized_shape, eps=1e-5):

@@ -327,6 +370,20 @@ class MixedFusedLayerNorm1D(torch.nn.Module):


 @LAYERS.register_module
 class Embedding1D(ParallelLayer):
+    """
+    Embedding for 1D parallelism
+
+    :param num_embeddings: number of embeddings
+    :type num_embeddings: int
+    :param embedding_dim: dimension of embedding
+    :type embedding_dim: int
+    :param padding_idx: index of padding, defaults to None
+    :type padding_idx: int, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to normal initializer
+    :type weight_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,

@@ -377,6 +434,14 @@ class Embedding1D(ParallelLayer):


 @LAYERS.register_module
 class Dropout1D(ParallelLayer):
+    """
+    Dropout layer of 1D parallelism
+
+    :param p: dropout rate, defaults to 0.5
+    :type p: float, optional
+    :param inplace: If set to ``True``, will do this operation in-place, defaults to ``False``
+    :type inplace: bool, optional
+    """
     def __init__(self, p: float = 0.5, inplace: bool = False):
         super().__init__()
         self.parallel_input = get_parallel_input()

@@ -20,7 +20,8 @@ def matmul_2d(
     row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
     col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
 ):
-    """Matrix multiplication for 2D parallelism
+    """
+    Matrix multiplication for 2D parallelism
     :param a: matrix :math:`A`
     :type a: torch.tensor
     :param b: matrix :math:`B`

@@ -56,7 +57,35 @@ def matmul_2d(


 class classifier_2d(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB`
+    """
+    Classifier
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param bias: matrix of bias
+    :type bias: torch.tensor, optional
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -130,7 +159,33 @@ class classifier_2d(torch.autograd.Function):


 class Matmul_AB_2D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB`
+    """
+    Matrix multiplication for :math:`C = AB`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

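A toy, single-process check of the block decomposition behind this SUMMA-style matmul: with A and B cut into a grid of blocks (``summa_dim = 2`` here), each output block is a sum of block products, C[i][j] = sum_k A[i][k] @ B[k][j]:

    import torch

    torch.manual_seed(0)
    summa_dim = 2
    A, B = torch.randn(4, 6), torch.randn(6, 8)
    Ab = [list(t.chunk(summa_dim, dim=1)) for t in A.chunk(summa_dim, dim=0)]
    Bb = [list(t.chunk(summa_dim, dim=1)) for t in B.chunk(summa_dim, dim=0)]
    C = torch.cat([torch.cat([sum(Ab[i][k] @ Bb[k][j] for k in range(summa_dim))
                              for j in range(summa_dim)], dim=1)
                   for i in range(summa_dim)], dim=0)
    assert torch.allclose(C, A @ B, atol=1e-5)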
@@ -238,7 +293,33 @@ class Matmul_AB_2D(torch.autograd.Function):


 class Matmul_ABT_2D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB^T`
+    """
+    Matrix multiplication for :math:`C = AB^T`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -352,7 +433,33 @@ class Matmul_ABT_2D(torch.autograd.Function):


 class Matmul_ATB_2D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = A^TB`
+    """
+    Matrix multiplication for :math:`C = A^TB`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -466,7 +573,33 @@ class Matmul_ATB_2D(torch.autograd.Function):


 class add_bias_2d(torch.autograd.Function):
-    """Matrix add bias: :math:`C = A + b`
+    """
+    Matrix add bias: :math:`C = A + b`
+
+    :param input_: matrix :math:`A`
+    :type input_: torch.tensor
+    :param bias: matrix :math:`b`
+    :type bias: torch.tensor
+    :param output_size_per_partition: size of output per partition
+    :type output_size_per_partition: int
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
+    :type skip_bias_add: bool
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -519,9 +652,30 @@ class add_bias_2d(torch.autograd.Function):


 class layernorm_2d(torch.autograd.Function):
+    """
+    Layernorm
+
+    :param input_: input matrix
+    :type input_: torch.tensor
+    :param E_x: mean
+    :type E_x: torch.tensor
+    :param Var_x: variance
+    :type Var_x: torch.tensor
+    :param hidden_size: hidden size
+    :type hidden_size: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
+    def forward(ctx: Any,
+                input_: Tensor,
+                E_x: Tensor,
+                Var_x: Tensor,
+                hidden_size: int,
+                row_parallel_mode: ParallelMode,
                 col_parallel_mode: ParallelMode) -> Tensor:
         input_ = input_ - E_x
         # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)

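A single-process sketch of the forward computation documented above: with ``E_x`` the row mean and ``Var_x = 1 / sqrt(Var[x] + eps)``, layernorm is ``(x - E_x) * Var_x``; in the 2D layer the statistics are first reduced across the row of ranks, which is omitted here:

    import torch

    x, eps = torch.randn(4, 16), 1e-5
    E_x = x.mean(dim=-1, keepdim=True)
    Var_x = 1.0 / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    out = (x - E_x) * Var_x
    ref = torch.nn.functional.layer_norm(x, (16,), eps=eps)
    assert torch.allclose(out, ref, atol=1e-5)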
@@ -556,6 +710,18 @@ class layernorm_2d(torch.autograd.Function):


 class all_gather_weight_2d(torch.autograd.Function):
+    """
+    all gather the weight of 2D parallelism
+
+    :param inputs: input matrix
+    :type inputs: torch.tensor
+    :param dim: dimension of all-gather
+    :type dim: int
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:

@@ -574,6 +740,14 @@ class all_gather_weight_2d(torch.autograd.Function):


 class SplitFirst(torch.autograd.Function):
+    """
+    :param inputs: input matrix
+    :type inputs: torch.tensor
+    :param summa_dim: dimension of SUMMA for 2D parallelism
+    :type summa_dim: int
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:

@@ -604,7 +778,14 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:


 class reduce_by_batch_2d(torch.autograd.Function):
-    """All-reduce the input from the model parallel region."""
+    """
+    All-reduce the input from the model parallel region.
+
+    :param input_: input matrix
+    :type input_: torch.tensor
+    :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, defaults to False
+    :type reduce_mean: bool, optional
+    """
     @staticmethod
     def symbolic(graph, input_, reduce_mean: bool = False):
         output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)

@@ -21,7 +21,8 @@ from ._utils import assert_summa_initialization, get_summa_dim_from_env


 @LAYERS.register_module
 class Linear2D(ParallelLayer):
-    """ Linear layer for 2D parallelism
+    """
+    Linear layer for 2D parallelism

     :param in_features: size of each input sample
     :type in_features: int

@@ -33,6 +34,10 @@ class Linear2D(ParallelLayer):
     :type dtype: torch.dtype, optional
     :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
     :type skip_bias_add: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
     """
     def __init__(self,
                  in_features: int,

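A sketch of the ``skip_bias_add`` contract described above, assuming the layer then hands back the un-biased matmul result together with the bias so a later fused kernel (e.g. bias + activation) can apply it:

    import torch

    x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
    out_with_bias = torch.nn.functional.linear(x, w, b)
    out_skipped, bias = torch.nn.functional.linear(x, w), b   # roughly what skip_bias_add=True returns
    assert torch.allclose(out_skipped + bias, out_with_bias, atol=1e-6)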
@@ -113,7 +118,8 @@ class Linear2D(ParallelLayer):


 @LAYERS.register_module
 class LayerNorm2D(ParallelLayer):
-    r"""Layer Normalization for 2D parallelism
+    r"""
+    Layer Normalization for 2D parallelism

     :param normalized_shape: input shape from an expected input
         of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`

@@ -184,18 +190,27 @@ class LayerNorm2D(ParallelLayer):


 @LAYERS.register_module
 class PatchEmbedding2D(ParallelLayer):
-    """ 2D Image to Patch Embedding
-    :param img_size: iamge size
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
     :type img_size: int
     :param patch_size: patch size
     :type patch_size: int
-    :param embed_dim: dimension of embedding
-    :type embed_dim: int
-    :param in_chans: number of channels of input image, defaults to 3
-    :type in_chans: int, optional
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
     :param flatten: whether to flatten output tensor, defaults to True
     :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
     """
     def __init__(self,
                  img_size: int,

@@ -275,6 +290,20 @@ class PatchEmbedding2D(ParallelLayer):


 @LAYERS.register_module
 class Embedding2D(ParallelLayer):
+    """
+    Embedding for 2D parallelism
+
+    :param num_embeddings: number of embeddings
+    :type num_embeddings: int
+    :param embedding_dim: dimension of embedding
+    :type embedding_dim: int
+    :param padding_idx: index of padding, defaults to None
+    :type padding_idx: int, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to normal initializer
+    :type weight_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,

@@ -325,6 +354,24 @@ class Embedding2D(ParallelLayer):


 @LAYERS.register_module
 class Classifier2D(ParallelLayer):
+    """
+    Classifier for 2D parallelism
+
+    :param in_features: size of each input sample
+    :type in_features: int
+    :param num_classes: number of classes
+    :type num_classes: int
+    :param weight: weight of the classifier, defaults to True
+    :type weight: torch.nn.Parameter, optional
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  in_features: int,
                  num_classes: int,

@@ -28,7 +28,35 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:


 class classifier_2p5d(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB`
+    """
+    Classifier
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param bias: matrix of bias
+    :type bias: torch.tensor, optional
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -101,7 +129,35 @@ class classifier_2p5d(torch.autograd.Function):


 class Matmul_AB_2p5D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB`
+    """
+    Matrix multiplication for :math:`C = AB`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param dep_rank: the rank of depth
+    :type dep_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -202,7 +258,35 @@ class Matmul_AB_2p5D(torch.autograd.Function):


 class Matmul_ABT_2p5D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = AB^T`
+    """
+    Matrix multiplication for :math:`C = AB^T`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param dep_rank: the rank of depth
+    :type dep_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -308,7 +392,35 @@ class Matmul_ABT_2p5D(torch.autograd.Function):


 class Matmul_ATB_2p5D(torch.autograd.Function):
-    """Matrix multiplication for :math:`C = A^TB`
+    """
+    Matrix multiplication for :math:`C = A^TB`
+
+    :param a: matrix :math:`A`
+    :type a: torch.tensor
+    :param b: matrix :math:`B`
+    :type b: torch.tensor
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param out_shape: shape of output tensor
+    :type out_shape: tuple
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param dep_rank: the rank of depth
+    :type dep_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -411,7 +523,35 @@ class Matmul_ATB_2p5D(torch.autograd.Function):


 class Add_Bias_2p5D(torch.autograd.Function):
-    """Matrix add bias: :math:`C = A + b`
+    """
+    Matrix add bias: :math:`C = A + b`
+
+    :param input: matrix :math:`A`
+    :type input: torch.tensor
+    :param bias: matrix :math:`b`
+    :type bias: torch.tensor
+    :param output_size_per_partition: output size in each partition
+    :type output_size_per_partition: int
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param row_rank: the rank of row
+    :type row_rank: int
+    :param col_rank: the rank of column
+    :type col_rank: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
+    :type skip_bias_add: bool
+    :param data_parallel_rank: data parallel rank
+    :type data_parallel_rank: int
+    :param pipeline_parallel_rank: pipeline parallel rank
+    :type pipeline_parallel_rank: int
+    :param pipeline_parallel_size: pipeline parallel size
+    :type pipeline_parallel_size: int
+    :param tensor_parallel_size: tensor parallel size
+    :type tensor_parallel_size: int
     """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)

@@ -482,6 +622,20 @@ class Add_Bias_2p5D(torch.autograd.Function):


 class layernorm_2p5d(torch.autograd.Function):
+    """
+    Layernorm
+
+    :param input: input matrix
+    :type input: torch.tensor
+    :param E_x: mean
+    :type E_x: torch.tensor
+    :param Var_x: variance
+    :type Var_x: torch.tensor
+    :param hidden_size: hidden size
+    :type hidden_size: int
+    :param row_parallel_mode: row parallel mode
+    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,

@@ -518,6 +672,18 @@ class layernorm_2p5d(torch.autograd.Function):


 class all_gather_weight_2p5d(torch.autograd.Function):
+    """
+    all gather the weight of 2.5D parallelism
+
+    :param inputs: input matrix
+    :type inputs: torch.tensor
+    :param dim: dimension of all-gather
+    :type dim: int
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:

@@ -536,6 +702,14 @@ class all_gather_weight_2p5d(torch.autograd.Function):


 class SplitFirst(torch.autograd.Function):
+    """
+    :param inputs: input matrix
+    :type inputs: torch.tensor
+    :param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
+    :type tesseract_dim: int
+    :param col_parallel_mode: column parallel mode
+    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
+    """
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:

@@ -566,7 +740,14 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:


 class reduce_by_batch_2p5d(torch.autograd.Function):
-    """All-reduce the input from the model parallel region."""
+    """
+    All-reduce the input from the model parallel region.
+
+    :param input_: input matrix
+    :type input_: torch.tensor
+    :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, defaults to False
+    :type reduce_mean: bool, optional
+    """
     @staticmethod
     def symbolic(graph, input_, reduce_mean: bool = False):
         output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)

@@ -21,7 +21,8 @@ from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from


 @LAYERS.register_module
 class Linear2p5D(ParallelLayer):
-    """Linear layer for 2.5D parallelism
+    """
+    Linear layer for 2.5D parallelism

     :param in_features: size of each input sample
     :type in_features: int

@@ -31,6 +32,10 @@ class Linear2p5D(ParallelLayer):
     :type bias: bool, optional
     :param dtype: The dtype of parameters, defaults to None
     :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
     """
     def __init__(self,
                  in_features: int,

@@ -125,7 +130,8 @@ class Linear2p5D(ParallelLayer):


 @LAYERS.register_module
 class LayerNorm2p5D(ParallelLayer):
-    r"""Layer Normalization for 2.5D parallelism
+    r"""
+    Layer Normalization for 2.5D parallelism

     :param normalized_shape: input shape from an expected input
         of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`

@@ -196,17 +202,27 @@ class LayerNorm2p5D(ParallelLayer):


 @LAYERS.register_module
 class PatchEmbedding2p5D(ParallelLayer):
-    """ 2D Image to Patch Embedding
-    :param img_size: iamge size
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
     :type img_size: int
     :param patch_size: patch size
     :type patch_size: int
-    :param embed_dim: dimension of embedding
-    :type embed_dim: int
-    :param in_chans: number of channels of input image, defaults to 3
-    :type in_chans: int, optional
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
     :param flatten: whether to flatten output tensor, defaults to True
     :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
     """
     def __init__(self,
                  img_size: int,

@@ -286,6 +302,20 @@ class PatchEmbedding2p5D(ParallelLayer):


 @LAYERS.register_module
 class Embedding2p5D(ParallelLayer):
+    """
+    Embedding for 2.5D parallelism
+
+    :param num_embeddings: number of embeddings
+    :type num_embeddings: int
+    :param embedding_dim: dimension of embedding
+    :type embedding_dim: int
+    :param padding_idx: index of padding, defaults to None
+    :type padding_idx: int, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to normal initializer
+    :type weight_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,

@@ -336,6 +366,24 @@ class Embedding2p5D(ParallelLayer):


 @LAYERS.register_module
 class Classifier2p5D(ParallelLayer):
+    """
+    Classifier for 2.5D parallelism
+
+    :param in_features: size of each input sample
+    :type in_features: int
+    :param num_classes: number of classes
+    :type num_classes: int
+    :param weight: weight of the classifier, defaults to True
+    :type weight: torch.nn.Parameter, optional
+    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
+    :type bias: bool, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    """
     def __init__(self,
                  in_features: int,
                  num_classes: int,

@@ -12,6 +12,28 @@ from torch.cuda.amp import custom_bwd, custom_fwd


class linear_3d(torch.autograd.Function):
"""
Linear layer for 3D parallelism

:param input_: matrix of input
:type input_: torch.Tensor
:param weight: matrix of weight
:type weight: torch.Tensor
:param bias: matrix of bias
:type bias: torch.Tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx,
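For orientation, a hedged sketch of how this autograd function could be invoked; the argument order simply follows the documented parameters, and the parallel modes come from the environment helpers imported elsewhere in this diff:

# assumption: a 3D parallel context is initialized and x, w, b are correctly sharded tensors
input_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)   # WEIGHT_GROUP_3D constant assumed
output_mode = get_last_group(input_mode, weight_mode)        # helper signature assumed
out = linear_3d.apply(x, w, b, input_mode, weight_mode, output_mode, 0, -1, 0)  # input_dim, weight_dim, output_dim as documented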
@@ -74,6 +96,22 @@ class linear_3d(torch.autograd.Function):


class classifier_3d(torch.autograd.Function):
"""
Classifier for 3D parallelism

:param input_: matrix of input
:type input_: torch.Tensor
:param weight: matrix of weight
:type weight: torch.Tensor
:param bias: matrix of bias
:type bias: torch.Tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
@@ -129,6 +167,29 @@ class classifier_3d(torch.autograd.Function):


class layernorm_3d(torch.autograd.Function):
r"""
Layernorm for 3D parallelism

:param input_: input matrix
:type input_: torch.Tensor
:param weight: matrix of weight
:type weight: torch.Tensor
:param bias: matrix of bias
:type bias: torch.Tensor
:param normalized_shape: input shape from an expected input
    of size :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`.
    If a single integer is used, it is treated as a singleton list, and this module will
    normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
@@ -189,6 +250,18 @@ def split_tensor_3d(input_: Tensor,


class reduce_by_batch_3d(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.

:param input_: input matrix
:type input_: torch.Tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), defaults to ``False``
:type reduce_mean: bool, optional
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx,
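A hedged sketch of the intended use, e.g. aggregating a per-process statistic across the batch-splitting groups; the argument order is assumed to follow the documented parameters and the parallel modes come from the previous sketch:

# assumption: logits/targets hold the local shard of a batch in a 3D parallel run
correct = (logits.argmax(dim=-1) == targets).sum().float()
correct = reduce_by_batch_3d.apply(correct, input_mode, weight_mode, True)  # reduce_mean=True averages over both groups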
@@ -215,6 +288,18 @@ class reduce_by_batch_3d(torch.autograd.Function):


class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
"""
Broadcast weight from the diagonal.

:param input_: input matrix
:type input_: torch.Tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
@@ -24,6 +24,19 @@ from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env


@LAYERS.register_module
class LayerNorm3D(ParallelLayer):
r"""
Layer Normalization for 3D parallelism

:param normalized_shape: input shape from an expected input
    of size :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`.
    If a single integer is used, it is treated as a singleton list, and this module will
    normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-12
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
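A minimal construction sketch for the layer above (import path assumed; a 3D parallel context must already be initialized):

from colossalai.nn.layer.parallel_3d import LayerNorm3D  # assumed module path

norm = LayerNorm3D(normalized_shape=1024, eps=1e-12)
y = norm(x)  # x: hidden states with last dimension 1024, sharded by the 3D context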
@@ -55,6 +68,22 @@ class LayerNorm3D(ParallelLayer):


@LAYERS.register_module
class Linear3D(ParallelLayer):
"""
Linear layer for 3D parallelism

:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
out_features: int,
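Constructor sketch, including an explicit initializer; the ``init`` helper module is an assumption based on the defaults named above:

from colossalai.nn import init  # assumed helper module for initializers
from colossalai.nn.layer.parallel_3d import Linear3D  # assumed module path

proj = Linear3D(in_features=1024, out_features=4096, bias=True,
                weight_initializer=init.kaiming_uniform_())  # mirrors the documented default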
@@ -113,6 +142,24 @@ class Linear3D(ParallelLayer):


@LAYERS.register_module
class Classifier3D(ParallelLayer):
"""
Classifier for 3D parallelism

:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
@@ -173,6 +220,28 @@ class Classifier3D(ParallelLayer):


@LAYERS.register_module
class PatchEmbedding3D(ParallelLayer):
"""
2D Image to Patch Embedding for 3D parallelism

:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
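A ViT-style construction sketch for the patch embedding documented above (import path assumed):

from colossalai.nn.layer.parallel_3d import PatchEmbedding3D  # assumed module path

# 224x224 RGB images split into 16x16 patches -> 196 patch tokens of width 768
patch_embed = PatchEmbedding3D(img_size=224, patch_size=16, in_chans=3, embed_size=768)
tokens = patch_embed(images)  # images: (batch, 3, 224, 224)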
@@ -256,6 +325,20 @@ class PatchEmbedding3D(ParallelLayer):


@LAYERS.register_module
class Embedding3D(ParallelLayer):
"""
Embedding for 3D parallelism

:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
@@ -32,7 +32,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):


class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
def __init__(self, drop_prob=None):
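A short sketch of the standard way stochastic depth is used inside a residual block; the surrounding Block class is illustrative, not part of this commit:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int, drop_prob: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.drop_path = DropPath(drop_prob)  # DropPath as defined above

    def forward(self, x):
        # during training, the residual branch is randomly zeroed per sample
        return x + self.drop_path(self.mlp(x))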
@@ -97,7 +98,27 @@ class WrappedDropPath(nn.Module):


@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
"""
2D Image to Patch Embedding

:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
@@ -148,6 +169,24 @@ class VanillaPatchEmbedding(nn.Module):


@LAYERS.register_module
class VanillaClassifier(nn.Module):
"""
Vanilla classifier layer

:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
@@ -7,7 +7,8 @@ from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
"""
Cross entropy loss for 2D parallelism

:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
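Usage sketch for the parallel loss family documented in this and the following hunks (export path assumed):

from colossalai.nn import CrossEntropyLoss2D  # assumed export; the 2.5D/3D variants are used the same way

criterion = CrossEntropyLoss2D(reduction=True)  # average the loss, per the docstring
loss = criterion(logits, labels)                # logits from a 2D-parallel model, labels as usual
loss.backward()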
@@ -7,7 +7,9 @@ from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
"""
Cross entropy loss for 2.5D parallelism

:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
@@ -7,14 +7,11 @@ from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
"""
Cross entropy loss for 3D parallelism

:param depth: depth for 3D parallelism
:type depth: int
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
@@ -6,6 +6,12 @@ from ._utils import calc_acc


class Accuracy2D(nn.Module):
"""
Accuracy for 2D parallelism

:param logits: predicted logits
:param targets: true labels
"""
def __init__(self):
super().__init__()
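Evaluation sketch for the accuracy module above; it runs inside an initialized 2D parallel context, and the exact return value (a correct-prediction count) is an assumption:

import torch

acc = Accuracy2D()
with torch.no_grad():
    correct = acc(logits, targets)  # the documented (logits, targets) pair for the local batch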
@@ -6,6 +6,12 @@ from ._utils import calc_acc


class Accuracy2p5D(nn.Module):
"""
Accuracy for 2.5D parallelism

:param logits: predicted logits
:param targets: true labels
"""
def __init__(self):
super().__init__()
@@ -8,6 +8,12 @@ from ._utils import calc_acc


class Accuracy3D(nn.Module):
"""
Accuracy for 3D parallelism

:param logits: predicted logits
:param targets: true labels
"""
def __init__(self):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -10,10 +10,9 @@ class BaseHook(ABC):

"""This class allows users to add desired actions at specific time points
during training or evaluation.

:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
"""

def __init__(self, priority: int) -> None:
@@ -43,11 +42,11 @@ class BaseHook(ABC):

"""Actions after running a training iteration.

:param output: Output of the model
:type output: torch.Tensor
:param label: Labels of the input data
:type label: torch.Tensor
:param loss: Loss between the output and input data
:type loss: torch.Tensor
"""
pass
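A small sketch of a user-defined hook built on this interface; the import path is assumed and the method signature follows the documented arguments:

from colossalai.trainer.hooks import BaseHook  # assumed import path

class PrintLossHook(BaseHook):
    def __init__(self, priority: int = 10):
        super().__init__(priority)

    def after_train_iter(self, output, label, loss):
        # called once per training iteration
        print(f'train loss: {loss.item():.4f}')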
@@ -90,10 +89,10 @@ class BaseHook(ABC):

"""Actions after running a testing iteration.

:param output: Output of the model
:type output: Tensor
:param label: Labels of the input data
:type label: Tensor
:param loss: Loss between the output and input data
:type loss: Tensor
"""
pass
@@ -16,14 +16,15 @@ from ._lr_scheduler_hook import LRSchedulerHook

class SaveCheckpointHook(BaseHook):
"""Saves the model by interval in training process.

:param interval: Saving interval, defaults to 1
:type interval: int, optional
:param checkpoint_dir: Directory of saving checkpoint, defaults to None
:type checkpoint_dir: str, optional
:param suffix: Saving suffix of the file, defaults to ''
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""

def __init__(self,
@@ -71,16 +72,19 @@ class SaveCheckpointHook(BaseHook):

class LoadCheckpointHook(BaseHook):
"""Loads the model before training process.

:param checkpoint_dir: Directory of saving checkpoint, defaults to None
:type checkpoint_dir: str, optional
:param epoch: Epoch number to be set, defaults to -1
:type epoch: int, optional
:param finetune: Whether to allow loading only part of the model, defaults to False
:type finetune: bool, optional
:param strict: Whether to strictly enforce that the loaded parameters have the same shapes, defaults to False
:type strict: bool, optional
:param suffix: Suffix of the checkpoint file, defaults to ''
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""

def __init__(self,
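Usage sketch for the two checkpoint hooks; the trainer.fit API is assumed from the tutorials that accompany this release:

from colossalai.trainer import hooks  # assumed import path

hook_list = [
    hooks.LoadCheckpointHook(checkpoint_dir='./ckpt', epoch=-1, strict=False),
    hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt', suffix='_latest'),
]
trainer.fit(train_dataloader=train_loader, epochs=num_epochs, hooks=hook_list)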
@@ -25,6 +25,15 @@ def _format_number(val, prec=5):


class LogByEpochHook(BaseHook):
"""Hook to log by epoch.

:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
logger,
interval: int = 1,
@@ -39,6 +48,12 @@ class LogByEpochHook(BaseHook):


@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
"""Hook to log metric by step.

:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self, priority: int = 10):
super().__init__(priority)

@@ -59,12 +74,13 @@ class LogMetricByStepHook(BaseHook):

class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.

:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, either 'train' or 'test'
"""

def __init__(self,
@@ -102,12 +118,17 @@ class LogMetricByEpochHook(LogByEpochHook):

class TensorboardHook(BaseHook):
"""Specialized Hook to record the metric to Tensorboard.

:param log_dir: Directory of log
:type log_dir: str
:param ranks: Ranks of processors that will write the log
:type ranks: typing.List
:param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, either 'train' or 'test'
:type mode: str
"""

def __init__(self,
@@ -184,14 +205,20 @@ class TensorboardHook(BaseHook):

class LogTimingByEpochHook(LogByEpochHook):
"""Specialized Hook to write timing record to log.

:param timer: Timer for the hook
:type timer: colossalai.utils.MultiTimer
:param logger: Logger for the log
:type logger: colossalai.logging.DistributedLogger
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param log_eval: Whether to write logs during evaluation, defaults to True
:type log_eval: bool, optional
:param ignore_num_train_steps: Number of training steps to ignore, defaults to 0
:type ignore_num_train_steps: int, optional
:param mode: Mode of metrics, either 'train' or 'test'
:param trainer: Trainer attached with current hook
"""
def __init__(self,
timer: MultiTimer,
@@ -249,13 +276,13 @@ class LogTimingByEpochHook(LogByEpochHook):

class LogMemoryByEpochHook(LogByEpochHook):
"""Specialized Hook to write memory usage record to log.

:param logger: Logger for the log
:type logger: colossalai.logging.DistributedLogger
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param log_eval: Whether to write logs during evaluation, defaults to True
:type log_eval: bool, optional
"""
def __init__(self,
@@ -263,7 +290,8 @@ class LogMemoryByEpochHook(LogByEpochHook):

interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False,  # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
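Combined sketch of the logging hooks documented above (construction helpers and import paths assumed):

from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer
from colossalai.trainer import hooks  # assumed import path

logger = get_dist_logger()
timer = MultiTimer()

log_hooks = [
    hooks.LogMetricByEpochHook(logger=logger, interval=1),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    hooks.LogTimingByEpochHook(timer=timer, logger=logger, interval=1, log_eval=True),
    hooks.LogMemoryByEpochHook(logger=logger, interval=1),
]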
@@ -8,14 +8,14 @@ from ._metric_hook import LearningRateMetric, MetricHook

class LRSchedulerHook(MetricHook):
"""Build LR scheduler

:param lr_scheduler: LR scheduler
:param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch
:type by_epoch: bool
:param store_lr_in_state: If `True`, store the learning rate in each state, defaults to `True`
:type store_lr_in_state: bool, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(
self,
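Sketch of wiring an LR scheduler through this hook; the scheduler class and its signature are assumptions based on the colossalai lr_scheduler module:

from colossalai.nn.lr_scheduler import CosineAnnealingLR  # assumed class/path
from colossalai.trainer import hooks

lr_scheduler = CosineAnnealingLR(optimizer, total_steps=num_epochs)  # signature assumed
lr_hook = hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)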
@@ -133,6 +133,8 @@ class LearningRateMetric(Metric):

:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param initial_lr: Initial learning rate, defaults to 0.0
:type initial_lr: float, optional
"""

def __init__(self, epoch_only: bool, initial_lr: float = 0.):
@@ -161,6 +163,8 @@ class AccuracyMetric(Metric):

:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param accuracy_func: Accuracy function for the classification task
:type accuracy_func: typing.Callable
"""

def __init__(self, epoch_only: bool, accuracy_func: Callable):
@@ -182,7 +186,8 @@ class AccuracyMetric(Metric):

and labels. It expects the output to contain logits and labels.

:param logits: The logits output of the model
:param targets: Real labels of the dataset
:param batch_size: Batch size of the task
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
@@ -216,10 +221,10 @@ class MetricHook(BaseHook):

update their states. Others are used to display and
record the metric.

:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""

def __init__(
@@ -238,10 +243,10 @@ class MetricHook(BaseHook):

class LossHook(MetricHook):
"""Specialized hook class for :class:`Loss`.

:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""

def __init__(self, priority: int = 0):
@@ -279,10 +284,12 @@ class LossHook(MetricHook):

class AccuracyHook(MetricHook):
"""Specialized hook class for :class:`Accuracy`.

:param accuracy_func: Accuracy function for the classification task
:type accuracy_func: typing.Callable
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""

def __init__(self, accuracy_func: Callable, priority: int = 0):
@@ -308,6 +315,13 @@ class AccuracyHook(MetricHook):


class ThroughputMetric(Metric):
"""Metric for :class:`Throughput`.

:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param num_samples: Number of samples processed
:param time: Time spent processing the samples
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
@@ -345,6 +359,13 @@ class ThroughputMetric(Metric):


@HOOKS.register_module
class ThroughputHook(MetricHook):
"""Specialized hook class for :class:`Throughput`.

:param priority: Priority of the throughput hook, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
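Bringing the metric hooks together in one run; the Trainer construction and fit signature are assumed from the release tutorials, and accuracy_fn stands for one of the Accuracy modules documented above:

from colossalai.trainer import Trainer, hooks

trainer = Trainer(engine=engine, logger=logger)  # engine from colossalai.initialize(...)
hook_list = [
    hooks.LossHook(),
    hooks.AccuracyHook(accuracy_func=accuracy_fn),
    hooks.ThroughputHook(),
    hooks.LogMetricByEpochHook(logger=logger),
]
trainer.fit(train_dataloader=train_loader, test_dataloader=test_loader,
            epochs=num_epochs, hooks=hook_list, display_progress=True)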