diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py
index fddeedd7d..2353851df 100644
--- a/colossalai/nn/layer/parallel_1d/__init__.py
+++ b/colossalai/nn/layer/parallel_1d/__init__.py
@@ -1,7 +1,7 @@
-from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
-                     VocabParallelClassifier1D, VocabParallelEmbedding1D)
+from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
+                     PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
 
 __all__ = [
     'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
-    'VocabParallelEmbedding1D'
+    'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
 ]
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 141d988f6..2daf875f8 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import math
+from collections import OrderedDict
 from typing import Callable, Tuple
 
 import torch
@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.kernel import LayerNorm
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
+                                            partition_tensor_parallel_state_dict)
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
+from ..vanilla import VanillaPatchEmbedding
 from ..base_layer import ParallelLayer
+from ..colossalai_layer._utils import ColossalaiModule
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
 from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
                      split_forward_gather_backward)
 
 
 @LAYERS.register_module
-class Linear1D(torch.nn.Module):
+class Linear1D(ColossalaiModule):
     r"""Linear layer for 1D parallelism.
 
     Args:
@@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
                  skip_bias_add: bool = False,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        super().__init__()
         parallel_input = get_parallel_input()
         if not parallel_input:
-            self.layer = Linear1D_Col(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      gather_output=gather_output,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
+            layer = Linear1D_Col(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 gather_output=gather_output,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
         else:
-            self.layer = Linear1D_Row(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      parallel_input=parallel_input,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
+            layer = Linear1D_Row(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 parallel_input=parallel_input,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
+        super().__init__(layer)
 
-    @property
-    def weight(self):
-        return self.layer.weight
 
-    @property
-    def bias(self):
-        return self.layer.bias
 
+@LAYERS.register_module
+class LayerNorm1D(ColossalaiModule):
+    r"""
+    Layer Normalization for colossalai
+
+    :param normalized_shape: input shape from an expected input
+        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :type normalized_shape: int
+    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
+    :type eps: float, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    """
 
-    def forward(self, input_: Tensor) -> Tensor:
-        return self.layer(input_)
+    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
+        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+        super().__init__(norm)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)
 
 
 @LAYERS.register_module
@@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
             num_partition = gpc.get_world_size(ParallelMode.TENSOR)
             set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: -1,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: False
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: -1,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: False
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
@@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: 0,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: True
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: 0,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: True
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
@@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: 0,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: True
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: 0,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: True
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
@@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: -1,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: False
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: -1,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: False
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
@@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: -1},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: -1},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args,
                                       **self.embed_kwargs)
@@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
 
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
-            self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: 0},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: 0},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
@@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
         else:
             output = F.dropout(input_, self.p, self.training, self.inplace)
         return output
+
+
+@LAYERS.register_module
+class PatchEmbedding1D(ColossalaiModule):
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
+    :type img_size: int
+    :param patch_size: patch size
+    :type patch_size: int
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param flatten: whether to flatten output tensor, defaults to True
+    :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
+    """
+
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 embed_size: int,
+                 dtype: torch.dtype = None,
+                 flatten: bool = True,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 position_embed_initializer: Callable = init.zeros_()):
+        embed = VanillaPatchEmbedding(img_size,
+                                      patch_size,
+                                      in_chans,
+                                      embed_size,
+                                      dtype=dtype,
+                                      flatten=flatten,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer,
+                                      position_embed_initializer=position_embed_initializer)
+        super().__init__(embed)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            for key in param_keys:
+                param = state_dict.pop(key, None)
+                if param is not None:
+                    local_state[key] = param
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)
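
Note (reviewer sketch, not part of the diff): the new `_load_from_state_dict` / `_save_to_state_dict` overrides follow two patterns. Layers whose parameters are replicated across the 1D tensor-parallel group (`LayerNorm1D`, `PatchEmbedding1D`) save only from TENSOR rank 0 and rebroadcast on load via `broadcast_state_dict`. Layers whose parameters are sharded go through `partition_tensor_parallel_state_dict` / `gather_tensor_parallel_state_dict`, where `dims` names the dimension each key is split along (weight along dim 0 for `Linear1D_Col` and `VocabParallelEmbedding1D`, along the last dim for `Linear1D_Row`, `Classifier1D` and `Embedding1D`) and `partition_states` marks which keys are actually split (a `Linear1D_Row` or `Classifier1D` bias stays replicated, so it is `False`). The snippet below is a single-process illustration of that slicing convention in plain PyTorch; the helper name `shard_for_rank` is made up for this note and performs none of the broadcast/gather communication the real checkpointing utilities do.

```python
import torch


def shard_for_rank(full_state, dims, partition_states, rank, world_size):
    """Illustrative only: slice each partitioned entry along its configured dim for one rank."""
    local = {}
    for key, tensor in full_state.items():
        if partition_states.get(key, False):
            # mirrors how a 1D-parallel rank would take its shard of a full checkpoint tensor
            local[key] = torch.chunk(tensor, world_size, dim=dims[key])[rank].clone()
        else:
            # non-partitioned entries (e.g. a Linear1D_Row bias) stay replicated as-is
            local[key] = tensor
    return local


# Linear1D_Row-style layout: weight split along the last dim, bias replicated.
full = {'weight': torch.randn(4, 8), 'bias': torch.randn(4)}
local = shard_for_rank(full,
                       dims={'weight': -1, 'bias': 0},
                       partition_states={'weight': True, 'bias': False},
                       rank=0, world_size=2)
print(local['weight'].shape, local['bias'].shape)  # torch.Size([4, 4]) torch.Size([4])
```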
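Note (reviewer sketch, not part of the diff): with these overrides, checkpointing a 1D tensor-parallel model is meant to look like checkpointing a plain `nn.Module`: `state_dict()` hands back full-size tensors, and `load_state_dict()` re-partitions or rebroadcasts them per rank. The script below is a hypothetical usage example under stated assumptions: two processes launched with `torchrun`, a `parallel=dict(tensor=dict(size=2, mode='1d'))` config accepted by `colossalai.launch_from_torch`, `state_dict()` being a collective that every rank must enter, and rank 0 writing the checkpoint while all ranks read the same full file back. Exact launch and collective details may differ in the version this PR targets.

```python
# Hypothetical usage sketch -- run with: torchrun --nproc_per_node=2 checkpoint_1d_demo.py
import colossalai
import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d import LayerNorm1D, Linear1D


def main():
    # assumed config: 1D tensor parallelism across the two launched processes
    colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(size=2, mode='1d'))))

    # Linear1D resolves to a column- or row-parallel implementation depending on whether its
    # input is already split; LayerNorm1D keeps a replicated copy of its parameters per rank.
    model = torch.nn.Sequential(Linear1D(16, 64), Linear1D(64, 16), LayerNorm1D(16))

    # Every rank calls state_dict() (the gather is assumed to be collective),
    # then a single rank writes the full-size tensors to disk.
    full_state = model.state_dict()
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        torch.save(full_state, 'model_1d.pt')
    dist.barrier()  # make sure the file exists before other ranks read it

    # load_state_dict() re-partitions (or rebroadcasts) the full tensors onto each rank.
    state = torch.load('model_1d.pt', map_location=f'cuda:{torch.cuda.current_device()}')
    model.load_state_dict(state)


if __name__ == '__main__':
    main()
```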