import math from collections import OrderedDict from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict from colossalai.utils.cuda import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, reduce_scatter_tensor_2d, split_batch_2d, ) from ._utils import assert_summa_initialization, get_summa_dim_from_env @LAYERS.register_module class Linear2D(ParallelLayer): r"""Linear layer for 2D parallelism Args: in_features (int): size of each input sample. out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add # parallel settings assert_summa_initialization() self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) self.summa_dim = get_summa_dim_from_env() # partitioning dimension self.input_size_per_partition = divide(self.in_features, self.summa_dim) self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) else: self.register_parameter('bias', None) # initialize parameters with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight.transpose(0, 1) # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: local_state[weight_key] = local_state[weight_key].transpose(0, 1) destination.update(local_state) def forward(self, x: Tensor) -> Tensor: # input: [m/q, n/q, k/q] # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) if self.bias is not None: if self.skip_bias_add: bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) return output, bias else: output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) return output else: return output @LAYERS.register_module class LayerNorm2D(ParallelLayer): r"""Layer Normalization for 2D parallelism. Args: normalized_shape (int): input shape from an expected input of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. bias (bool, optional): Whether to add a bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. """ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None): super().__init__() # layer norm config self.normalized_shape = normalized_shape self.variance_epsilon = eps # parallel setting assert_summa_initialization() self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) self.summa_dim = get_summa_dim_from_env() # partitioning dimension self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) else: self.bias = None self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) Var_x /= self.normalized_shape Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL) scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) if self.bias is not None: bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) output = torch.addcmul(bias, scale, output) else: output = torch.mul(scale, output) return output @LAYERS.register_module class PatchEmbedding2D(ParallelLayer): r"""2D Image to Patch Embedding. Args: img_size (int): image size. patch_size (int): patch size. in_chans (int): number of channels of input image. embed_size (int): size of embedding. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. flatten (bool, optional): whether to flatten output tensor, defaults to True. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. position_embed_initializer (:class:`typing.Callable`, optional): The initializer of position embedding, defaults to zeros initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_size: int, flatten: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) assert_summa_initialization() self.summa_dim = get_summa_dim_from_env() self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.embed_size = embed_size self.embed_size_per_partition = embed_size // (self.summa_dim**2) with seed(ParallelMode.TENSOR): self.weight = Parameter( torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = Parameter( torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.pos_embed = Parameter( torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attribute() def _set_tensor_parallel_attribute(self): set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2) def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): with seed(ParallelMode.TENSOR): fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) fan_out = self.embed_size weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' pos_embed_key = prefix + 'pos_embed' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # cls token cls_token = state_dict.pop(cls_token_key, None) if cls_token is not None: local_state[cls_token_key] = cls_token # pos embed pos_embed = state_dict.pop(pos_embed_key, None) if pos_embed is not None: local_state[pos_embed_key] = pos_embed # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' pos_embed_key = prefix + 'pos_embed' local_state = OrderedDict({ weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed }) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2d(input_) B, C, H, W = input_.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) cls_token = cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) output = output + pos_embed return output @LAYERS.register_module class Embedding2D(ParallelLayer): r"""Embedding for 2D parallelism. Args: num_embeddings (int): number of embeddings. embedding_dim (int): dimension of embedding. padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”, defaults to None. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. More details about ``args`` and ``kwargs`` could be found in `Embedding `_. More details about initializer please refer to `init `_ """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() assert_summa_initialization() self.summa_dim = get_summa_dim_from_env() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.weight = Parameter( torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={weight_key: -1}, partition_states={weight_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={weight_key: -1}, partition_states={weight_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={weight_key: -1}, partition_states={weight_key: True}, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={weight_key: -1}, partition_states={weight_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2d(input_) weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @LAYERS.register_module class VocabParallelEmbedding2D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. Args: num_embeddings (int): number of embeddings. embedding_dim (int): dimension of embedding. padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”, defaults to None. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. More details about ``args`` and ``kwargs`` could be found in `Embedding `_. More details about initializer please refer to `init `_. """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs assert_summa_initialization() self.summa_dim = get_summa_dim_from_env() self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() env.vocab_parallel = True def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None and \ self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={weight_key: -1}, partition_states={weight_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={weight_key: 0}, partition_states={weight_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={weight_key: 0}, partition_states={weight_key: True}, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={weight_key: -1}, partition_states={weight_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output_parallel[input_mask, :] = 0. output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) return output @LAYERS.register_module class Classifier2D(ParallelLayer): r"""Classifier for 2D parallelism. Args: in_features (int): size of each input sample. num_classes (int): number of classes. weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.num_classes = num_classes assert_summa_initialization() self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) self.summa_dim = get_summa_dim_from_env() # partitioning dimension self.input_size_per_partition = divide(self.in_features, self.summa_dim**2) if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) else: self.bias = None self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): if self.has_weight: set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.num_classes col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0] row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0] if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes,) return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) @LAYERS.register_module class VocabParallelClassifier2D(ParallelLayer): r"""Vocab parallel classifier layer for 2D parallelism. Args: in_features (int): size of each input sample. num_classes (int): number of classes. weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.num_classes = num_classes # parallel setting assert_summa_initialization() self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) self.summa_dim = get_summa_dim_from_env() # partitioning dimension self.input_size_per_partition = divide(in_features, self.summa_dim) self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] factory_kwargs = {'device': get_current_device(), 'dtype': dtype} if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) self.has_weight = True # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) else: self.bias = None # initialize parameters with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() env.vocab_parallel = True def _set_tensor_parallel_attributes(self): if self.has_weight: set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.num_classes if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) # gather in row groups if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: local_state[weight_key] = local_state[weight_key].transpose(0, 1) destination.update(local_state) def forward(self, x: Tensor) -> Tensor: # input: [m/q, n/q, k/q] # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.output_size_per_partition,) output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) if self.bias is not None: output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) return output