You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/nn/layer/parallel_2d/layers.py

259 lines
9.3 KiB

3 years ago
import math
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear2D(ParallelLayer):
    """Linear layer for 2D tensor parallelism.

    Every rank keeps one ``[in_features / q, out_features / q]`` tile of the
    weight (and, optionally, an ``[out_features / q]`` slice of the bias),
    where ``q`` is the value returned by ``get_summa_dim_from_env()``.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: when ``False`` no additive bias is learned, defaults to True
    :type bias: bool, optional
    :param dtype: dtype of the parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: when ``True`` the bias is not added to the output but
        returned next to it (kept for kernel fusion), defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add

        # Parallel topology. Note the deliberate cross-mapping kept from the
        # original code: ``row_rank`` is the local rank inside the *column*
        # group and ``col_rank`` the local rank inside the *row* group.
        assert_summa_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # Per-rank tile sizes along the input and output feature dimensions.
        self.input_size_per_partition = divide(self.in_features, self.summa_dim)
        self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)

        tensor_opts = {'device': get_current_device(), 'dtype': dtype}

        # Local weight tile, shape [k/q, h/q]; filled in by reset_parameters().
        local_weight = torch.empty(
            self.input_size_per_partition,
            self.hidden_size_per_partition,
            **tensor_opts)
        self.weight = Parameter(local_weight)

        # Local bias slice, shape [h/q], or an explicit ``None`` registration.
        if not bias:
            self.register_parameter('bias', None)
        else:
            local_bias = torch.empty(self.hidden_size_per_partition, **tensor_opts)
            self.bias = Parameter(local_bias)

        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        """Tag the weight (and the bias, when present) as tensor-parallel."""
        set_tensor_parallel_attribute(self.weight)
        if self.bias is None:
            return
        set_tensor_parallel_attribute(self.bias)

    def reset_parameters(self) -> None:
        """Fill weight and bias with uniform samples under the TENSOR seed.

        The weight bound has the form ``sqrt(3) * gain / sqrt(fan_in)`` with
        the ``'leaky_relu'`` gain for slope ``sqrt(5)``. ``fan_in`` is the full
        ``in_features``, not the per-partition size.
        """
        fan_in = self.in_features

        gain = init.calculate_gain('leaky_relu', math.sqrt(5))
        std = gain / math.sqrt(fan_in)
        weight_bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -weight_bound, weight_bound)

        if self.bias is None:
            return

        if fan_in > 0:
            bias_bound = 1 / math.sqrt(fan_in)
        else:
            bias_bound = 0
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.bias, -bias_bound, bias_bound)

    def forward(self, x: Tensor) -> Tensor:
        """Multiply ``x`` by the local weight tile and handle the bias.

        :param x: input tensor; layout noted by the original author as
            ``[m/q, n/q, k/q]``
        :type x: torch.Tensor
        :return: the output (noted as ``[m/q, n/q, h/q]``). When a bias exists
            and ``skip_bias_add`` is true, a tuple ``(output, bias)`` is
            returned instead and the bias is *not* added to ``output``.
        """
        # Trailing positional arguments shared by every autograd-function call
        # below: first the 2D grid position and groups, then the outer topology.
        grid_args = (
            self.row_rank,
            self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
        )
        topology_args = (
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
        )

        # Same leading dims as the input; last dim becomes the local h/q.
        result_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        product = Matmul_AB_2D.apply(
            x,
            self.weight,
            self.summa_dim,
            result_shape,
            *grid_args,
            *topology_args)

        if self.bias is None:
            return product

        if self.skip_bias_add:
            # No input tensor is supplied and the flag is True: the call
            # yields the bias on its own so the caller can fuse the addition.
            bias = Add_Bias_2D.apply(
                None,
                self.bias,
                self.hidden_size_per_partition,
                *grid_args,
                True,
                *topology_args)
            return product, bias

        # Flag False: the bias is applied to the matmul result.
        return Add_Bias_2D.apply(
            product,
            self.bias,
            self.hidden_size_per_partition,
            *grid_args,
            False,
            *topology_args)
@LAYERS.register_module
class LayerNorm2D(ParallelLayer):
    r"""Layer normalization for 2D tensor parallelism.

    :param normalized_shape: input shape from an expected input
        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: value added to the variance before the square root for numerical
        stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: dtype of the parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None
                 ):
        super().__init__()

        # Normalization settings.
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps

        # Parallel topology. As kept from the original code, ``row_rank`` is
        # the local rank in the *column* group and ``col_rank`` the local rank
        # in the *row* group.
        assert_summa_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # Length of the local slice of the normalized dimension.
        self.partitioned_partition = divide(normalized_shape, self.summa_dim)

        tensor_opts = {'device': get_current_device(), 'dtype': dtype}
        if self.row_rank == 0:
            # Full-size local slices: scale starts at ones, shift at zeros.
            gamma_data = torch.ones(self.partitioned_partition, **tensor_opts)
            beta_data = torch.zeros(self.partitioned_partition, **tensor_opts)
        else:
            # Every other row rank only holds 0-dim tensors, both set to 1.0.
            # NOTE(review): these look like placeholders whose real values are
            # supplied through Add_Bias_2D -- confirm against _operation.py.
            gamma_data = torch.tensor(1.0, requires_grad=True, **tensor_opts)
            beta_data = torch.tensor(1.0, requires_grad=True, **tensor_opts)

        # Registration order (gamma, then beta) is preserved.
        self.gamma = Parameter(gamma_data)
        self.beta = Parameter(beta_data)

        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        """Tag ``gamma`` and ``beta`` as tensor-parallel parameters."""
        for param in (self.gamma, self.beta):
            set_tensor_parallel_attribute(param)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over its last dimension, then scale and shift.

        Mean and inverse standard deviation are computed without gradient
        tracking from local sums that are all-reduced over the
        ``PARALLEL_2D_ROW`` group and divided by ``normalized_shape``.

        :param x: local slice of the input tensor
        :type x: torch.Tensor
        :return: ``bias + scale * normalized`` for the local slice
        :rtype: torch.Tensor
        """
        with torch.no_grad():
            row_group = gpc.get_group(ParallelMode.PARALLEL_2D_ROW)

            # Mean: local sum over the last dim (kept as size 1), reduced
            # across the row group, divided by the full normalized size.
            mean = torch.sum(x, dim=-1, keepdim=True)
            dist.all_reduce(mean, group=row_group)
            mean.div_(self.normalized_shape)

            # Second moment, obtained the same way from x * x.
            second_moment = torch.sum(x * x, dim=-1, keepdim=True)
            dist.all_reduce(second_moment, group=row_group)
            second_moment.div_(self.normalized_shape)

            # Var[x] = E[x^2] - E[x]^2, then 1 / sqrt(Var[x] + eps).
            variance = second_moment - mean * mean
            inv_std = 1.0 / torch.sqrt(variance + self.variance_epsilon)

        normalized = _LayerNorm_2D.apply(
            x, mean, inv_std, self.normalized_shape,
            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL)

        # Trailing positional arguments shared by both Add_Bias_2D calls.
        grid_args = (
            self.row_rank,
            self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
        )
        topology_args = (
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
        )

        # With no input tensor and the flag set to True, Add_Bias_2D hands the
        # parameter back on its own. Call order (beta first, then gamma) is
        # kept exactly as in the original.
        bias = Add_Bias_2D.apply(
            None, self.beta, self.partitioned_partition,
            *grid_args, True, *topology_args)
        scale = Add_Bias_2D.apply(
            None, self.gamma, self.partitioned_partition,
            *grid_args, True, *topology_args)

        # bias + scale * normalized
        return torch.addcmul(bias, scale, normalized)