import math from collections import OrderedDict from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F from colossalai.communication import all_reduce, broadcast from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.registry import LAYERS from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict) from colossalai.utils.cuda import get_current_device from torch import Tensor from torch.nn import Parameter from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d, linear_3d, reduce_scatter_tensor_3d, split_tensor_3d) from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group @LAYERS.register_module class LayerNorm3D(ParallelLayer): r"""Layer Normalization for 3D parallelism. Args: normalized_shape (int): input shape from an expected input of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12. bias (bool, optional): Whether to add a bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. """ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.depth = get_depth_from_env() self.normalized_shape = normalized_shape self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) if bias: self.bias = Parameter( torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) else: self.bias = None self.variance_epsilon = eps self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) def reset_parameters(self) -> None: init.ones_()(self.weight) if self.bias is not None: init.zeros_()(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight.transpose(0, 1) # bias bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True, }, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = broadcast_state_dict(local_state, self.input_parallel_mode) # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module class Linear3D(ParallelLayer): r"""Linear layer for 3D parallelism. Args: in_features (int): size of each input sample. out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.out_features = out_features self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) self.out_features_per_partition = divide(out_features, self.depth**2) self.bias_features_per_partition = divide(out_features, self.depth) self.weight = Parameter( torch.empty(self.in_features_per_partition, self.out_features_per_partition, device=get_current_device(), dtype=dtype)) if bias: self.bias = Parameter( torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) else: self.bias = None self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() swap_in_out_group() def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight.transpose(0, 1) # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in weight groups local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) # gather in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: local_state[weight_key] = local_state[weight_key].transpose(0, 1) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module class Classifier3D(ParallelLayer): r"""Classifier for 3D parallelism. Args: in_features (int): size of each input sample. num_classes (int): number of classes. weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.num_classes = num_classes self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) else: self.bias = None self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: if self.has_weight: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.num_classes weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = broadcast_state_dict(local_state, self.input_parallel_mode) # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module class VocabParallelClassifier3D(ParallelLayer): r"""Vocab parallel classifier layer for 3D parallelism. Args: in_features (int): size of each input sample. num_classes (int): number of classes. weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() self.in_features = in_features self.num_classes = num_classes self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) self.out_features_per_partition = divide(num_classes, self.depth**2) self.bias_features_per_partition = divide(num_classes, self.depth) if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( torch.empty(self.out_features_per_partition, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) self.has_weight = True if bias: self.bias = Parameter( torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) else: self.bias = None self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() swap_in_out_group() env.vocab_parallel = True def _set_tensor_parallel_attributes(self) -> None: if self.has_weight: set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.num_classes if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias if self.bias is not None: bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in weight groups local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) # gather in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={ weight_key: 0, bias_key: 0 }, partition_states={ weight_key: True, bias_key: True }, keep_vars=keep_vars, ) # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: -1, bias_key: 0 }, partition_states={ weight_key: True, bias_key: False }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module class PatchEmbedding3D(ParallelLayer): r"""2D Image to Patch Embedding. Args: img_size (int): image size. patch_size (int): patch size. in_chans (int): number of channels of input image. embed_size (int): size of embedding. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. flatten (bool, optional): whether to flatten output tensor, defaults to True. weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. bias_initializer (:class:`typing.Callable`, optional): The initializer of bias, defaults to xavier uniform initializer. position_embed_initializer (:class:`typing.Callable`, optional): The initializer of position embedding, defaults to zeros initializer. More details about ``initializer`` please refer to `init `_. """ def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_size: int, flatten: bool = True, dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.patch_size = to_2tuple(patch_size) grid_size = to_2tuple(img_size // patch_size) num_patches = grid_size[0] * grid_size[1] self.embed_size = embed_size embed_size_per_partition = divide(embed_size, self.depth) self.flatten = flatten self.weight = nn.Parameter( torch.empty((embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter( torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) set_tensor_parallel_attribute_by_partition(self.bias, self.depth) set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) def _sync_grad_hook(self, grad) -> Tensor: grad = all_reduce(grad.clone(), self.input_parallel_mode) grad = all_reduce(grad, self.weight_parallel_mode) return grad def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) fan_out = self.embed_size weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) broadcast(self.weight, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) self.weight.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook) self.cls_token.register_hook(self._sync_grad_hook) self.pos_embed.register_hook(self._sync_grad_hook) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' pos_embed_key = prefix + 'pos_embed' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # bias bias = state_dict.pop(bias_key, None) if bias is not None: local_state[bias_key] = bias # cls token cls_token = state_dict.pop(cls_token_key, None) if cls_token is not None: local_state[cls_token_key] = cls_token # pos embed pos_embed = state_dict.pop(pos_embed_key, None) if pos_embed is not None: local_state[pos_embed_key] = pos_embed # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = broadcast_state_dict(local_state, self.input_parallel_mode) # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' pos_embed_key = prefix + 'pos_embed' local_state = OrderedDict({ weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed }) # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={ weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1 }, partition_states={ weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True }, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) output = output + self.pos_embed return output @LAYERS.register_module class Embedding3D(ParallelLayer): r"""Embedding for 3D parallelism. Args: num_embeddings (int): number of embeddings. embedding_dim (int): dimension of embedding. padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”, defaults to None. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. More details about ``args`` and ``kwargs`` could be found in `Embedding `_. More details about initializer please refer to `init `_ """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.num_embeddings = num_embeddings self.embed_dim = embedding_dim embed_dim_per_partition = divide(embedding_dim, self.depth) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.weight = nn.Parameter( torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = broadcast_state_dict(local_state, self.input_parallel_mode) # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @LAYERS.register_module class VocabParallelEmbedding3D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. Args: num_embeddings (int): number of embeddings. embedding_dim (int): dimension of embedding. padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”, defaults to None. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. weight_initializer (:class:`typing.Callable`, optional): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default False. sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. More details about ``args`` and ``kwargs`` could be found in `Embedding `_. More details about initializer please refer to `init `_. """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) self.embed_dim_per_partition = divide(self.embed_dim, self.depth) vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth self.weight = Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() env.vocab_parallel = True def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None and \ self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight # partition in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={weight_key: -1}, partition_states={weight_key: True}, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) # gather in weight groups local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, keep_vars=keep_vars, ) # gather in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, dims={weight_key: 0}, partition_states={weight_key: True}, keep_vars=keep_vars, ) # gather in output groups if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, dims={weight_key: -1}, partition_states={weight_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode) output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output_parallel[input_mask, :] = 0. output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) return output