import math

import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Parameter, init as init

from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer


@LAYERS.register_module
class Linear2D(ParallelLayer):
    """Linear layer for 2D parallelism

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
        which is preserved for kernel fusion, defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add

        # parallel settings
        assert_summa_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # partitioning dimension
        self.input_size_per_partition = divide(
            self.in_features, self.summa_dim)
        self.hidden_size_per_partition = divide(
            self.out_features, self.summa_dim)

        # create weight, shape: [k/q, h/q]
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.input_size_per_partition,
            self.hidden_size_per_partition,
            **factory_kwargs))

        # create bias, shape: [h/q]
        if bias:
            self.bias = Parameter(torch.empty(
                self.hidden_size_per_partition,
                **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        # initialize parameters
        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)

    def reset_parameters(self) -> None:
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'

        # init weight
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -bound, bound)

        # init bias
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        output = Matmul_AB_2D.apply(
            x, self.weight,
            self.summa_dim, out_shape,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)

        if self.bias is not None:
            if self.skip_bias_add:
                bias = Add_Bias_2D.apply(
                    None, self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank, self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    True,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size)
                return output, bias
            else:
                output = Add_Bias_2D.apply(
                    output, self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank, self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    False,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size)
                return output
        else:
            return output


@LAYERS.register_module
class LayerNorm2D(ParallelLayer):
    r"""Layer Normalization for 2D parallelism

    :param normalized_shape: input shape from an expected input of size.
        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
        \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None
                 ):
        super().__init__()

        # layer norm config
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps

        # parallel setting
        assert_summa_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.summa_dim)

        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}

        if self.row_rank == 0:
            self.gamma = Parameter(torch.ones(
                self.partitioned_partition,
                **factory_kwargs))
            self.beta = Parameter(torch.zeros(
                self.partitioned_partition,
                **factory_kwargs))
        else:
            self.gamma = Parameter(torch.tensor(
                1.0, requires_grad=True, **factory_kwargs))
            self.beta = Parameter(torch.tensor(
                1.0, requires_grad=True, **factory_kwargs))

        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        set_tensor_parallel_attribute(self.gamma)
        set_tensor_parallel_attribute(self.beta)

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

        output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
                                     ParallelMode.PARALLEL_2D_ROW,
                                     ParallelMode.PARALLEL_2D_COL)
        bias = Add_Bias_2D.apply(
            None, self.beta, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)
        scale = Add_Bias_2D.apply(
            None, self.gamma, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)
        output = torch.addcmul(bias, scale, output)
        return output
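

# --------------------------------------------------------------------------- #
# Usage sketch (illustration only, not part of the library API): a rough idea
# of how Linear2D and LayerNorm2D compose once the 2D (SUMMA) tensor-parallel
# context has been initialized, e.g. by launching q * q processes with a
# config along the lines of ``parallel = dict(tensor=dict(mode='2d', size=4))``.
# The launch call is omitted because it depends on the ColossalAI version; the
# shapes below assume summa_dim q = 2 and a hidden size of 1024.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    norm = LayerNorm2D(normalized_shape=1024)
    fc = Linear2D(in_features=1024, out_features=4096)

    # each rank holds a [b/q, s, h/q] shard of the activations
    x = torch.randn(8, 64, 1024 // 2, device=get_current_device())
    x = norm(x)   # shard stays [b/q, s, 1024 // q]
    x = fc(x)     # shard becomes [b/q, s, 4096 // q]

    # with skip_bias_add=True the layer returns (output, bias) so the bias add
    # can be fused into a downstream kernel
    fc_fused = Linear2D(in_features=4096, out_features=1024, skip_bias_add=True)
    out, bias = fc_fused(x)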