mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1074 lines
44 KiB
1074 lines
44 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import math
|
|
from collections import OrderedDict
|
|
from typing import Callable, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from colossalai.kernel import LayerNorm
|
|
from colossalai.legacy.communication import broadcast
|
|
from colossalai.legacy.context import ParallelMode, seed
|
|
from colossalai.legacy.context.parallel_context import global_context as gpc
|
|
from colossalai.legacy.global_variables import tensor_parallel_env as env
|
|
from colossalai.legacy.registry import LAYERS
|
|
from colossalai.legacy.utils.checkpointing import (
|
|
broadcast_state_dict,
|
|
gather_tensor_parallel_state_dict,
|
|
partition_tensor_parallel_state_dict,
|
|
)
|
|
from colossalai.nn import init as init
|
|
from colossalai.utils.device import get_current_device
|
|
|
|
from ..base_layer import ParallelLayer
|
|
from ..colossalai_layer._utils import ColossalaiModule
|
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
|
from ..vanilla import VanillaPatchEmbedding
|
|
from ._operation import linear_with_async_comm
|
|
from ._utils import (
|
|
gather_forward_split_backward,
|
|
get_parallel_input,
|
|
reduce_grad,
|
|
reduce_input,
|
|
set_parallel_input,
|
|
split_forward_gather_backward,
|
|
)
|
|
|
|
# Optional accelerated layer norm from apex. ``Fast_LN`` is the class when apex's
# contrib layer norm can be imported, and ``None`` otherwise, so later code can test
# ``Fast_LN is not None`` before using it.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
except ImportError:
    Fast_LN = None
else:
    Fast_LN = FastLayerNorm
|
|
|
|
|
|
@LAYERS.register_module
class Linear1D(ColossalaiModule):
    r"""Dispatching linear layer for 1D tensor parallelism.

    At construction time this wrapper reads the current parallel-input flag via
    ``get_parallel_input()`` and builds exactly one concrete layer, which is then handed
    to :class:`ColossalaiModule`:

    * :class:`Linear1D_Col` when the flag is off **and** ``gather_output`` is ``False``;
    * :class:`Linear1D_Row` in every other case (it receives the flag as ``parallel_input``).

    Note that ``gather_output`` only takes part in the dispatch decision; it is not
    forwarded to the layer that gets built.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        gather_output (bool, optional): Whether to call all-gather on output, defaults to False.
        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        parallel_input = get_parallel_input()

        # Keyword arguments that both concrete layers accept unchanged.
        shared_kwargs = {
            "bias": bias,
            "dtype": dtype,
            "skip_bias_add": skip_bias_add,
            "weight_initializer": weight_initializer,
            "bias_initializer": bias_initializer,
        }

        if parallel_input or gather_output:
            # Either the input is already flagged as parallel or the caller asked for
            # a gathered output: use the row-parallel layer.
            layer = Linear1D_Row(in_features, out_features, parallel_input=parallel_input, **shared_kwargs)
        else:
            layer = Linear1D_Col(in_features, out_features, **shared_kwargs)

        super().__init__(layer)
|
|
|
|
|
|
@LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
    r"""
    Layer Normalization for colossalai

    The constructor wraps one of three implementations, tried in this order:

    1. ``Fast_LN`` (apex contrib), when it was importable and ``normalized_shape`` is one
       of ``_fast_ln_supported_sizes``;
    2. ``apex.normalization.FusedLayerNorm``, when importing/constructing it does not
       raise ``ImportError``;
    3. the ``LayerNorm`` imported from ``colossalai.kernel``.

    Args:
        normalized_shape (int): input shape from an expected input of size.
            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
            \times \ldots \times \text{normalized_shape}[-1]]`
            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
        bias (bool, optional): Whether to add a bias, defaults to ``True``.
            NOTE(review): this constructor accepts the argument but does not forward it
            to any of the wrapped implementations.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
    """

    # Sizes for which the ``Fast_LN`` implementation is selected.
    _fast_ln_supported_sizes = [
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120,
        6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432,
        20480, 24576, 25600, 30720, 32768, 40960, 49152, 65536,
    ]

    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
        fast_ln_usable = Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes
        if fast_ln_usable:
            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
        else:
            try:
                from apex.normalization import FusedLayerNorm

                # Construction deliberately stays inside the ``try`` so that an
                # ImportError raised while building the module also triggers the fallback.
                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
            except ImportError:
                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
        super().__init__(norm)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        """Load weight/bias by popping them on tensor-local rank 0 and broadcasting.

        Only the rank whose ``ParallelMode.TENSOR`` local rank is 0 removes the
        ``<prefix>weight`` / ``<prefix>bias`` entries from ``state_dict``; the collected
        entries are passed through ``broadcast_state_dict`` over ``ParallelMode.PARALLEL_1D``
        and the result is given to the parent implementation.
        """
        collected = OrderedDict()
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            for key in (prefix + "weight", prefix + "bias"):
                tensor = state_dict.pop(key, None)
                if tensor is not None:
                    collected[key] = tensor

        collected = broadcast_state_dict(collected, ParallelMode.PARALLEL_1D)
        super()._load_from_state_dict(collected, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save through the parent implementation, but only on tensor-local rank 0."""
        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
            return
        super()._save_to_state_dict(destination, prefix, keep_vars)
|
|
|
|
|
|
@LAYERS.register_module
class Classifier1D(ParallelLayer):
    r"""RowLinear with given weight. Classifier of 1D parallelism.

    The weight is created with shape
    ``(num_classes, divide(in_features, gpc.tensor_parallel_size))`` unless an existing
    ``weight`` parameter is supplied, in which case that parameter is used as-is and is
    neither re-initialized nor included in this layer's global state dict
    (``has_weight`` is ``False``). The bias, when present, has shape ``(num_classes,)``.

    Constructing the layer also calls ``set_parallel_input(False)`` and sets
    ``env.vocab_parallel = False``.

    Args:
        in_features (int): size of each input sample.
        num_classes (int): number of classes.
        weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int,
        weight: Parameter = None,
        bias: bool = True,
        dtype: torch.dtype = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        super().__init__()
        self.in_features = in_features
        self.num_classes = num_classes
        self.parallel_input = get_parallel_input()

        # The weight matrix is split along its last (input) dimension.
        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
        if weight is None:
            # No shared weight supplied: this layer owns (and later initializes) its own.
            self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
            self.has_weight = True
        else:
            self.weight = weight
            self.has_weight = False

        self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) if bias else None

        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)
        env.vocab_parallel = False

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        """Initialize the owned weight and the bias; the bias is then broadcast.

        The bias is broadcast from the first rank of the ``PARALLEL_1D`` group after
        it has been initialized.
        """
        fan_in = self.in_features
        fan_out = self.num_classes
        if self.has_weight:
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0]
            broadcast(self.bias, src_rank, ParallelMode.PARALLEL_1D)

    def _set_tensor_parallel_attributes(self):
        """Tag the owned weight with the tensor-parallel partition count."""
        if not self.has_weight:
            return
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        """Pop weight/bias on tensor-local rank 0 and partition them across ``PARALLEL_1D``.

        The weight is partitioned along its last dimension; the bias is not partitioned.
        """
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # (key, whether this layer holds that tensor) in weight-then-bias order.
            candidates = ((weight_key, self.has_weight), (bias_key, self.bias is not None))
            for key, held in candidates:
                if not held:
                    continue
                tensor = state_dict.pop(key, None)
                if tensor is not None:
                    local_state[key] = tensor

        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        """Gather the owned weight (last dim) and the bias and merge them into ``destination``."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        gathered = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )
        destination.update(gathered)

    def forward(self, input_: Tensor) -> Tensor:
        """Apply the row-parallel classifier.

        When ``parallel_input`` is set the input's last dimension must already equal the
        local weight's last dimension; otherwise the input is split along its last
        dimension first. The partial products are combined with ``reduce_input`` and the
        bias (if any) is added afterwards.
        """
        if self.parallel_input:
            assert (
                input_.shape[-1] == self.weight.shape[-1]
            ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1]
            )
        else:
            assert (
                divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1]
            ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size
            )
            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

        partial = F.linear(input_, self.weight)
        output = reduce_input(partial, ParallelMode.PARALLEL_1D)
        if self.bias is not None:
            output = output + self.bias
        return output
|
|
|
|
|
|
@LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
    r"""ColLinear with given weight. Classifier of 1D parallelism.

    The weight is created with shape
    ``(divide(num_classes, gpc.tensor_parallel_size), in_features)`` unless an existing
    ``weight`` parameter is supplied, in which case that parameter is used as-is and is
    neither re-initialized nor included in this layer's global state dict
    (``has_weight`` is ``False``). The bias, when present, has the per-partition class count
    as its length.

    Constructing the layer also calls ``set_parallel_input(False)`` and sets
    ``env.vocab_parallel = True``.

    Args:
        in_features (int): size of each input sample.
        num_classes (int): number of classes.
        weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        gather_output (bool, optional): Whether to all-gather the output along the last dimension
            in ``forward``, defaults to False.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int,
        weight: Parameter = None,
        bias: bool = True,
        dtype: torch.dtype = None,
        gather_output: bool = False,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        super().__init__()
        self.in_features = in_features
        self.num_classes = num_classes
        self.gather_output = gather_output
        self.parallel_input = get_parallel_input()

        # The class (output) dimension is what gets split across the tensor-parallel ranks.
        self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size)

        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
        if weight is None:
            # No shared weight supplied: this layer owns (and later initializes) its own.
            self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs))
            self.has_weight = True
        else:
            self.weight = weight
            self.has_weight = False

        self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) if bias else None

        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)
        env.vocab_parallel = True

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        """Initialize the owned weight (if any) and the bias (if any)."""
        fan_in = self.in_features
        fan_out = self.num_classes
        if self.has_weight:
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def _set_tensor_parallel_attributes(self):
        """Tag the owned weight and the bias with the tensor-parallel partition count."""
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        owned = []
        if self.has_weight:
            owned.append(self.weight)
        if self.bias is not None:
            owned.append(self.bias)
        for param in owned:
            set_tensor_parallel_attribute_by_partition(param, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        """Pop weight/bias on tensor-local rank 0 and partition both along dim 0."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # (key, whether this layer holds that tensor) in weight-then-bias order.
            candidates = ((weight_key, self.has_weight), (bias_key, self.bias is not None))
            for key, held in candidates:
                if not held:
                    continue
                tensor = state_dict.pop(key, None)
                if tensor is not None:
                    local_state[key] = tensor

        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        """Gather the owned weight and the bias along dim 0 and merge them into ``destination``."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        gathered = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
            keep_vars=keep_vars,
        )
        destination.update(gathered)

    def forward(self, input_: Tensor) -> Tensor:
        """Apply the column-parallel classifier.

        The input's last dimension must equal the weight's last dimension. The input is
        passed through ``reduce_grad``, multiplied with the local weight (bias included),
        and the result is all-gathered along the last dimension only when
        ``gather_output`` is set.
        """
        assert (
            input_.shape[-1] == self.weight.shape[-1]
        ), "Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
            input_.shape, self.weight.shape, self.weight.shape[-1]
        )

        input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)

        if not self.gather_output:
            return output_parallel
        return gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
|
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    The local weight has shape ``(divide(out_features, gpc.tensor_parallel_size), in_features)``
    and the local bias, when present, has the per-partition output size as its length.
    After construction, ``set_parallel_input`` is called with ``not gather_output``.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
                    to all GPUs, otherwise, every GPU will have its output
                    which is :math:`Y_i = XA_i`, defaults to False
        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    Raises:
        ValueError: if ``skip_bias_add`` is ``True`` while ``bias`` is ``False``.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        super().__init__()

        # Remember the constructor arguments that forward() and the helpers consult.
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)

        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
        self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) if bias else None

        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()

        # Without a gather the output stays split, so the next layer sees parallel input.
        set_parallel_input(not self.gather_output)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        """Initialize the weight and, when present, the bias."""
        fan_in = self.in_features
        fan_out = self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def _set_tensor_parallel_attributes(self):
        """Tag the weight and the bias with the tensor-parallel partition count."""
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        """Pop weight/bias on tensor-local rank 0 and partition both along dim 0."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # The weight is always looked up; the bias only when this layer has one.
            keys = [weight_key]
            if self.bias is not None:
                keys.append(bias_key)
            for key in keys:
                tensor = state_dict.pop(key, None)
                if tensor is not None:
                    local_state[key] = tensor

        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        """Gather weight and bias along dim 0 and merge them into ``destination``."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"

        local_state = OrderedDict()
        local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        gathered = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
            keep_vars=keep_vars,
        )
        destination.update(gathered)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply the column-parallel linear layer.

        Returns:
            ``(output, bias)`` when ``skip_bias_add`` is set (the bias is *not* applied to
            ``output``); otherwise just ``output`` with the bias already applied.
        """
        assert (
            input_.shape[-1] == self.weight.shape[-1]
        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
            input_.shape, self.weight.shape, self.weight.shape[-1]
        )

        # The bias is withheld from the matmul when the caller wants to add it later.
        fused_bias = None if self.skip_bias_add else self.bias
        output_parallel = linear_with_async_comm(input_, self.weight, fused_bias, ParallelMode.PARALLEL_1D, True)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        return output
|
|
|
|
|
|
@LAYERS.register_module
class Linear1D_Row(ParallelLayer):
    r"""Linear layer with row parallelism

    The local weight has shape ``(out_features, divide(in_features, gpc.tensor_parallel_size))``
    and the bias, when present, has shape ``(out_features,)``. Constructing the layer
    calls ``set_parallel_input(False)``.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to True.
        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.
        stream_chunk_num (int, optional): When greater than 1, ``forward`` splits the weight
            into up to this many chunks along dim 0 and overlaps the per-chunk matmul with an
            asynchronous all-reduce. Only usable in eval mode (``forward`` raises
            ``RuntimeError`` while training). Defaults to 1.

    Raises:
        ValueError: if ``skip_bias_add`` is ``True`` while ``bias`` is ``False``.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        parallel_input: bool = True,
        skip_bias_add: bool = False,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        stream_chunk_num: int = 1,
    ):
        super().__init__()

        self.stream_chunk_num = stream_chunk_num

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

        # Parameters.
        # Initialize weight.
        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
            # Kept so that ``weight_list`` exists right after construction; forward()
            # refreshes it before every use.
            self.chunk_weight()
        if bias:
            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.bias = None
        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)

    def chunk_weight(self):
        """Store views of the weight, split along dim 0, in ``self.weight_list``.

        ``torch.chunk`` may return fewer than ``stream_chunk_num`` chunks, so callers
        must iterate over ``weight_list`` rather than index it by chunk number.
        """
        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        """Initialize the weight and the bias; the bias is then broadcast.

        The bias is broadcast from the first rank of the ``PARALLEL_1D`` group after
        it has been initialized.
        """
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)

    def _set_tensor_parallel_attributes(self):
        """Tag the weight with the tensor-parallel partition count."""
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        """Pop weight/bias on tensor-local rank 0 and partition them across ``PARALLEL_1D``.

        The weight is partitioned along its last dimension; the bias is not partitioned.
        """
        local_state = OrderedDict()
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        """Gather the weight (last dim) and the bias and merge them into ``destination``."""
        weight_key = prefix + "weight"
        bias_key = prefix + "bias"
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        """Apply the row-parallel linear layer.

        When ``parallel_input`` is set the input's last dimension must already equal the
        local weight's last dimension; otherwise the input is split along its last
        dimension first.

        Returns:
            ``(output, bias)`` when ``skip_bias_add`` is set (the bias is *not* applied to
            ``output``); otherwise just ``output`` with the bias (if any) added.

        Raises:
            RuntimeError: if ``stream_chunk_num > 1`` while the module is in training mode.
        """
        # Set up backprop all-reduce.
        if self.parallel_input:
            assert (
                input_.shape[-1] == self.weight.shape[-1]
            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1]
            )
        else:
            assert (
                divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1]
            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
                input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size
            )
            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

        if self.stream_chunk_num > 1:
            if self.training:
                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
            with torch.no_grad():
                # Re-slice the weight on every call: the views taken at construction
                # time can go stale if the parameter's storage is replaced afterwards
                # (e.g. by a dtype/device conversion of the module).
                self.chunk_weight()
                group = gpc.get_group(ParallelMode.PARALLEL_1D)
                output_parallel_list = []
                handle_list = []
                # Iterate over the chunks actually produced; torch.chunk may return
                # fewer than ``stream_chunk_num`` of them.
                for weight_chunk in self.weight_list:
                    partial_output = F.linear(input_, weight_chunk)
                    handle = torch.distributed.all_reduce(partial_output, group=group, async_op=True)
                    output_parallel_list.append(partial_output)
                    handle_list.append(handle)
                # Every asynchronous reduction must finish before the pieces are joined.
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
|
|
|
|
|
|
@LAYERS.register_module
|
|
class Embedding1D(ParallelLayer):
|
|
r"""Embedding for 1D parallelism.
|
|
|
|
Args:
|
|
num_embeddings (int): number of embeddings.
|
|
embedding_dim (int): dimension of embedding.
|
|
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
|
|
therefore, the embedding vector at padding_idx is not updated during training,
|
|
i.e. it remains as a fixed “pad”, defaults to None.
|
|
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
|
weight_initializer (:class:`typing.Callable`, optional):
|
|
he initializer of weight, defaults to normal initializer.
|
|
|
|
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
|
|
::
|
|
|
|
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
|
|
renormalized to have norm max_norm. Note: this will modify weight in-place.
|
|
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
|
|
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
|
|
of frequency of the words in the mini-batch. Default False.
|
|
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
|
|
|
|
More details about ``args`` and ``kwargs`` could be found in
|
|
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
|
|
|
|
More details about ``initializer`` please refer to
|
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_embeddings: int,
|
|
embedding_dim: int,
|
|
padding_idx: int = None,
|
|
dtype: torch.dtype = None,
|
|
weight_initializer: Callable = init.normal_(),
|
|
*args,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
|
|
self.num_embeddings = num_embeddings
|
|
self.embed_dim = embedding_dim
|
|
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
|
|
|
|
self.padding_idx = padding_idx
|
|
self.embed_args = args
|
|
self.embed_kwargs = kwargs
|
|
|
|
self.weight = Parameter(
|
|
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
|
)
|
|
|
|
self.reset_parameters(weight_initializer)
|
|
self._set_tensor_parallel_attributes()
|
|
set_parallel_input(False)
|
|
|
|
def _set_tensor_parallel_attributes(self):
|
|
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
|
|
|
|
def reset_parameters(self, weight_initializer) -> None:
|
|
with seed(ParallelMode.TENSOR):
|
|
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
|
self._fill_padding_idx_with_zero()
|
|
|
|
def _fill_padding_idx_with_zero(self) -> None:
|
|
if self.padding_idx is not None:
|
|
with torch.no_grad():
|
|
self.weight[self.padding_idx].fill_(0)
|
|
|
|
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
|
local_state = OrderedDict()
|
|
weight_key = prefix + "weight"
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
# weight
|
|
weight = state_dict.pop(weight_key, None)
|
|
if weight is not None:
|
|
local_state[weight_key] = weight
|
|
|
|
local_state = partition_tensor_parallel_state_dict(
|
|
local_state, ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True}
|
|
)
|
|
super()._load_from_global_state_dict(local_state, prefix, *args)
|
|
|
|
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
    """Write this layer's weight into ``destination`` via the tensor-parallel gather helper.

    The local weight is passed to ``gather_tensor_parallel_state_dict`` marked as
    partitioned along its last dimension, and whatever the helper returns is merged
    into ``destination``.

    Args:
        destination: mapping updated in place with the gathered entries.
        prefix: key prefix of this module.
        keep_vars: forwarded to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + "weight"

    shard_state = OrderedDict()
    shard_state[weight_key] = self.weight

    gathered_state = gather_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_1D,
        dims={weight_key: -1},
        partition_states={weight_key: True},
        keep_vars=keep_vars,
    )
    destination.update(gathered_state)
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
    """Look up the local embedding columns and gather them along the last dimension.

    Args:
        input_ (Tensor): token indices passed to ``F.embedding``.

    Returns:
        Tensor: result of ``gather_forward_split_backward`` over ``ParallelMode.PARALLEL_1D``
        applied to the local lookup with ``dim=-1``.
    """
    local_embeddings = F.embedding(
        input_,
        self.weight,
        self.padding_idx,
        *self.embed_args,
        **self.embed_kwargs,
    )
    return gather_forward_split_backward(local_embeddings, ParallelMode.PARALLEL_1D, dim=-1)
|
|
|
|
|
|
@LAYERS.register_module
class VocabParallelEmbedding1D(ParallelLayer):
    r"""Embedding parallelized in the vocabulary dimension.

    Each rank of ``ParallelMode.PARALLEL_1D`` owns a contiguous block of
    ``num_embeddings / world_size`` rows of the embedding table.

    Args:
        num_embeddings (int): number of embeddings.
        embedding_dim (int): dimension of embedding.
        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
            therefore, the embedding vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”, defaults to None.
            This is an index into the *global* vocabulary; a negative value counts from
            the end of the global vocabulary.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to normal initializer.

    The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
    ::

        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
                    renormalized to have norm max_norm. Note: this will modify weight in-place.
        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
                    of frequency of the words in the mini-batch. Default False.
        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.

    More details about ``args`` and ``kwargs`` could be found in
    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.

    More details about initializer please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = None,
        dtype: torch.dtype = None,
        weight_initializer: Callable = init.normal_(),
        *args,
        **kwargs,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs

        # This rank owns the global rows [vocab_start_index, vocab_end_index).
        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
        self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

        self.weight = Parameter(
            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
        )

        self.reset_parameters(weight_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)
        env.vocab_parallel = True

    def _set_tensor_parallel_attributes(self):
        """Tag ``self.weight`` with partition attributes for ``gpc.tensor_parallel_size`` partitions."""
        set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)

    def reset_parameters(self, weight_initializer) -> None:
        """Initialize the local weight shard, then zero the padding row if this rank owns it."""
        with seed(ParallelMode.TENSOR):
            fan_in, fan_out = self.num_embeddings, self.embed_dim
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            self._fill_padding_idx_with_zero()

    def _local_padding_idx(self):
        """Translate the global ``padding_idx`` into an index of the local weight shard.

        Returns:
            The row of ``self.weight`` holding the padding embedding, or ``None`` when
            no padding index is configured or the padding row lives on another rank
            (an index outside the vocabulary therefore matches no rank).
        """
        if self.padding_idx is None:
            return None
        global_idx = self.padding_idx
        if global_idx < 0:
            # Same convention as torch.nn.Embedding: negative indices count from the
            # end of the (global) vocabulary.
            global_idx += self.num_embeddings
        if self.vocab_start_index <= global_idx < self.vocab_end_index:
            return global_idx - self.vocab_start_index
        return None

    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the padding row when it falls inside this rank's vocabulary range."""
        local_idx = self._local_padding_idx()
        if local_idx is not None:
            with torch.no_grad():
                self.weight[local_idx].fill_(0)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        """Load the weight from a global state dict, partitioned along dimension 0.

        Only tensor-parallel local rank 0 pops the weight from ``state_dict``; the
        mapping is then passed through ``partition_tensor_parallel_state_dict`` before
        being handed to the parent loader.
        """
        local_state = OrderedDict()
        weight_key = prefix + "weight"
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        local_state = partition_tensor_parallel_state_dict(
            local_state, ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True}
        )
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        """Merge the weight, gathered along dimension 0 by the helper, into ``destination``."""
        weight_key = prefix + "weight"
        local_state = OrderedDict({weight_key: self.weight})
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_1D,
            dims={weight_key: 0},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        """Embed ``input_`` using the local vocabulary shard and reduce across ranks.

        Args:
            input_ (Tensor): token indices into the global vocabulary.

        Returns:
            Tensor: result of ``reduce_input`` over ``ParallelMode.PARALLEL_1D`` applied
            to the local lookup, in which tokens owned by other ranks contribute zeros.
        """
        # Build the mask of tokens that are NOT owned by this rank.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Shift to shard-local indices; the subtraction already yields a new tensor,
        # so the caller's input is never modified. Foreign tokens are pointed at row 0
        # only to keep the lookup in range; their output is zeroed below.
        masked_input = input_ - self.vocab_start_index
        masked_input[input_mask] = 0

        # F.embedding indexes the LOCAL shard, so it must receive the shard-local
        # padding index. Passing the global index trips F.embedding's range assertion
        # whenever padding_idx >= num_embeddings_per_partition and otherwise excludes
        # the wrong row from gradients.
        output_parallel = F.embedding(
            masked_input, self.weight, self._local_padding_idx(), *self.embed_args, **self.embed_kwargs
        )

        # Mask the output embedding.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
        return output
|
|
|
|
|
|
@LAYERS.register_module
class Dropout1D(ParallelLayer):
    """Dropout layer of 1D parallelism.

    Whether the input is treated as parallel is read once, at construction time,
    from ``get_parallel_input()``.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults to 0.5.
        inplace (bool, optional): whether to do dropout in-place, defaults to False.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.p, self.inplace = p, inplace
        self.parallel_input = get_parallel_input()

    def forward(self, input_: Tensor) -> Tensor:
        """Apply dropout; for parallel input the call runs inside the ``ParallelMode.TENSOR`` seed context."""
        if not self.parallel_input:
            return F.dropout(input_, p=self.p, training=self.training, inplace=self.inplace)

        with seed(ParallelMode.TENSOR):
            return F.dropout(input_, p=self.p, training=self.training, inplace=self.inplace)
|
|
|
|
|
|
@LAYERS.register_module
class PatchEmbedding1D(ColossalaiModule):
    """2D Image to Patch Embedding.

    Thin wrapper that builds a ``VanillaPatchEmbedding`` and customizes how its state
    dict is loaded and saved under tensor parallelism.

    Args:
        img_size (int): image size.
        patch_size (int): patch size.
        in_chans (int): number of channels of input image.
        embed_size (int): size of embedding.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        flatten (bool, optional): whether to flatten output tensor, defaults to True.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.
        position_embed_initializer (:class:`typing.Callable`, optional):
            The initializer of position embedding, defaults to zero.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_size: int,
        dtype: torch.dtype = None,
        flatten: bool = True,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        position_embed_initializer: Callable = init.zeros_(),
    ):
        # All arguments are forwarded untouched to the wrapped vanilla layer.
        vanilla_embed = VanillaPatchEmbedding(
            img_size,
            patch_size,
            in_chans,
            embed_size,
            dtype=dtype,
            flatten=flatten,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
            position_embed_initializer=position_embed_initializer,
        )
        super().__init__(vanilla_embed)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        """Load parameters taken from ``state_dict`` on local rank 0 and passed through ``broadcast_state_dict``.

        Only tensor-parallel local rank 0 pops the entries from ``state_dict``; other
        ranks pass an empty mapping to ``broadcast_state_dict``.
        """
        picked = OrderedDict()
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            for suffix in ("weight", "bias", "cls_token", "pos_embed"):
                key = prefix + suffix
                value = state_dict.pop(key, None)
                if value is not None:
                    picked[key] = value

        shared_state = broadcast_state_dict(picked, ParallelMode.PARALLEL_1D)
        super()._load_from_state_dict(shared_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the state dict on tensor-parallel local rank 0 only; other ranks write nothing."""
        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
            return
        super()._save_to_state_dict(destination, prefix, keep_vars)
|