|
|
@ -2,6 +2,7 @@
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import Callable, Tuple
|
|
|
|
from typing import Callable, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
@ -10,20 +11,25 @@ from colossalai.communication import broadcast
|
|
|
|
from colossalai.context import ParallelMode, seed
|
|
|
|
from colossalai.context import ParallelMode, seed
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.global_variables import tensor_parallel_env as env
|
|
|
|
from colossalai.global_variables import tensor_parallel_env as env
|
|
|
|
|
|
|
|
from colossalai.kernel import LayerNorm
|
|
|
|
from colossalai.nn import init as init
|
|
|
|
from colossalai.nn import init as init
|
|
|
|
from colossalai.registry import LAYERS
|
|
|
|
from colossalai.registry import LAYERS
|
|
|
|
|
|
|
|
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
|
|
|
|
|
|
|
|
partition_tensor_parallel_state_dict)
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from torch import Tensor
|
|
|
|
from torch import Tensor
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
|
|
from ..vanilla import VanillaPatchEmbedding
|
|
|
|
|
|
|
|
|
|
|
|
from ..base_layer import ParallelLayer
|
|
|
|
from ..base_layer import ParallelLayer
|
|
|
|
|
|
|
|
from ..colossalai_layer._utils import ColossalaiModule
|
|
|
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
|
|
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
|
|
|
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
|
|
|
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
|
|
|
split_forward_gather_backward)
|
|
|
|
split_forward_gather_backward)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@LAYERS.register_module
|
|
|
|
@LAYERS.register_module
|
|
|
|
class Linear1D(torch.nn.Module):
|
|
|
|
class Linear1D(ColossalaiModule):
|
|
|
|
r"""Linear layer for 1D parallelism.
|
|
|
|
r"""Linear layer for 1D parallelism.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
|
|
|
|
skip_bias_add: bool = False,
|
|
|
|
skip_bias_add: bool = False,
|
|
|
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
|
|
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
|
|
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
|
|
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
parallel_input = get_parallel_input()
|
|
|
|
parallel_input = get_parallel_input()
|
|
|
|
if not parallel_input:
|
|
|
|
if not parallel_input:
|
|
|
|
self.layer = Linear1D_Col(in_features,
|
|
|
|
layer = Linear1D_Col(in_features,
|
|
|
|
out_features,
|
|
|
|
out_features,
|
|
|
|
bias=bias,
|
|
|
|
bias=bias,
|
|
|
|
dtype=dtype,
|
|
|
|
dtype=dtype,
|
|
|
|
gather_output=gather_output,
|
|
|
|
gather_output=gather_output,
|
|
|
|
skip_bias_add=skip_bias_add,
|
|
|
|
skip_bias_add=skip_bias_add,
|
|
|
|
weight_initializer=weight_initializer,
|
|
|
|
weight_initializer=weight_initializer,
|
|
|
|
bias_initializer=bias_initializer)
|
|
|
|
bias_initializer=bias_initializer)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.layer = Linear1D_Row(in_features,
|
|
|
|
layer = Linear1D_Row(in_features,
|
|
|
|
out_features,
|
|
|
|
out_features,
|
|
|
|
bias=bias,
|
|
|
|
bias=bias,
|
|
|
|
dtype=dtype,
|
|
|
|
dtype=dtype,
|
|
|
|
parallel_input=parallel_input,
|
|
|
|
parallel_input=parallel_input,
|
|
|
|
skip_bias_add=skip_bias_add,
|
|
|
|
skip_bias_add=skip_bias_add,
|
|
|
|
weight_initializer=weight_initializer,
|
|
|
|
weight_initializer=weight_initializer,
|
|
|
|
bias_initializer=bias_initializer)
|
|
|
|
bias_initializer=bias_initializer)
|
|
|
|
|
|
|
|
super().__init__(layer)
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def weight(self):
|
|
|
|
|
|
|
|
return self.layer.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@LAYERS.register_module
|
|
|
|
def bias(self):
|
|
|
|
class LayerNorm1D(ColossalaiModule):
|
|
|
|
return self.layer.bias
|
|
|
|
r"""
|
|
|
|
|
|
|
|
Layer Normalization for colossalai
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param normalized_shape: input shape from an expected input
|
|
|
|
|
|
|
|
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
|
|
|
|
|
|
|
\times \ldots \times \text{normalized_shape}[-1]]`
|
|
|
|
|
|
|
|
If a single integer is used, it is treated as a singleton list, and this module will
|
|
|
|
|
|
|
|
normalize over the last dimension which is expected to be of that specific size.
|
|
|
|
|
|
|
|
:type normalized_shape: int
|
|
|
|
|
|
|
|
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
|
|
|
|
|
|
|
|
:type eps: float, optional
|
|
|
|
|
|
|
|
:param dtype: The dtype of parameters, defaults to None
|
|
|
|
|
|
|
|
:type dtype: torch.dtype, optional
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
|
|
|
|
return self.layer(input_)
|
|
|
|
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
|
|
|
|
|
|
|
super().__init__(norm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
# bias
|
|
|
|
|
|
|
|
bias = state_dict.pop(bias_key, None)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@LAYERS.register_module
|
|
|
|
@LAYERS.register_module
|
|
|
@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
|
|
|
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
|
|
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
if self.has_weight:
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
# bias
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
bias = state_dict.pop(bias_key, None)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: -1,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: False
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
if self.has_weight:
|
|
|
|
|
|
|
|
local_state[weight_key] = self.weight
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = self.bias
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: -1,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: False
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
if self.parallel_input:
|
|
|
|
if self.parallel_input:
|
|
|
@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|
|
|
if self.bias is not None:
|
|
|
|
if self.bias is not None:
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
if self.has_weight:
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
# bias
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
bias = state_dict.pop(bias_key, None)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: 0,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: True
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
if self.has_weight:
|
|
|
|
|
|
|
|
local_state[weight_key] = self.weight
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = self.bias
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: 0,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: True
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
|
|
|
|
if self.bias is not None:
|
|
|
|
if self.bias is not None:
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
# bias
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
bias = state_dict.pop(bias_key, None)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: 0,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: True
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
local_state = OrderedDict({weight_key: self.weight})
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = self.bias
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: 0,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: True
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
|
|
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
|
|
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
|
|
@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
|
|
|
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
|
|
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
|
|
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
# bias
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
bias = state_dict.pop(bias_key, None)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: -1,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: False
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
bias_key = prefix + 'bias'
|
|
|
|
|
|
|
|
local_state = OrderedDict({weight_key: self.weight})
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
local_state[bias_key] = self.bias
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={
|
|
|
|
|
|
|
|
weight_key: -1,
|
|
|
|
|
|
|
|
bias_key: 0
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
partition_states={
|
|
|
|
|
|
|
|
weight_key: True,
|
|
|
|
|
|
|
|
bias_key: False
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
# Set up backprop all-reduce.
|
|
|
|
if self.parallel_input:
|
|
|
|
if self.parallel_input:
|
|
|
@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
self.weight[self.padding_idx].fill_(0)
|
|
|
|
self.weight[self.padding_idx].fill_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={weight_key: -1},
|
|
|
|
|
|
|
|
partition_states={weight_key: True})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
local_state = OrderedDict({weight_key: self.weight})
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={weight_key: -1},
|
|
|
|
|
|
|
|
partition_states={weight_key: True},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
|
|
|
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
|
|
@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def _fill_padding_idx_with_zero(self) -> None:
|
|
|
|
def _fill_padding_idx_with_zero(self) -> None:
|
|
|
|
if self.padding_idx is not None and \
|
|
|
|
if self.padding_idx is not None and \
|
|
|
|
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
|
|
|
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
|
|
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
# weight
|
|
|
|
|
|
|
|
weight = state_dict.pop(weight_key, None)
|
|
|
|
|
|
|
|
if weight is not None:
|
|
|
|
|
|
|
|
local_state[weight_key] = weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={weight_key: 0},
|
|
|
|
|
|
|
|
partition_states={weight_key: True})
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
weight_key = prefix + 'weight'
|
|
|
|
|
|
|
|
local_state = OrderedDict({weight_key: self.weight})
|
|
|
|
|
|
|
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
|
|
|
|
|
|
|
ParallelMode.PARALLEL_1D,
|
|
|
|
|
|
|
|
dims={weight_key: 0},
|
|
|
|
|
|
|
|
partition_states={weight_key: True},
|
|
|
|
|
|
|
|
keep_vars=keep_vars)
|
|
|
|
|
|
|
|
destination.update(local_state)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
def forward(self, input_: Tensor) -> Tensor:
|
|
|
|
# Build the mask.
|
|
|
|
# Build the mask.
|
|
|
|
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
|
|
|
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
|
|
@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
output = F.dropout(input_, self.p, self.training, self.inplace)
|
|
|
|
output = F.dropout(input_, self.p, self.training, self.inplace)
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@LAYERS.register_module
|
|
|
|
|
|
|
|
class PatchEmbedding1D(ColossalaiModule):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
2D Image to Patch Embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param img_size: image size
|
|
|
|
|
|
|
|
:type img_size: int
|
|
|
|
|
|
|
|
:param patch_size: patch size
|
|
|
|
|
|
|
|
:type patch_size: int
|
|
|
|
|
|
|
|
:param in_chans: number of channels of input image
|
|
|
|
|
|
|
|
:type in_chans: int
|
|
|
|
|
|
|
|
:param embed_size: size of embedding
|
|
|
|
|
|
|
|
:type embed_size: int
|
|
|
|
|
|
|
|
:param dtype: The dtype of parameters, defaults to None
|
|
|
|
|
|
|
|
:type dtype: torch.dtype, optional
|
|
|
|
|
|
|
|
:param flatten: whether to flatten output tensor, defaults to True
|
|
|
|
|
|
|
|
:type flatten: bool, optional
|
|
|
|
|
|
|
|
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
|
|
|
|
|
|
|
|
:type weight_initializer: typing.Callable, optional
|
|
|
|
|
|
|
|
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
|
|
|
|
|
|
|
|
:type bias_initializer: typing.Callable, optional
|
|
|
|
|
|
|
|
:param position_embed_initializer: The intializer of position embedding, defaults to zero
|
|
|
|
|
|
|
|
:type position_embed_initializer: typing.Callable, optional
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
|
|
|
img_size: int,
|
|
|
|
|
|
|
|
patch_size: int,
|
|
|
|
|
|
|
|
in_chans: int,
|
|
|
|
|
|
|
|
embed_size: int,
|
|
|
|
|
|
|
|
dtype: torch.dtype = None,
|
|
|
|
|
|
|
|
flatten: bool = True,
|
|
|
|
|
|
|
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
|
|
|
|
|
|
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
|
|
|
|
|
|
|
position_embed_initializer: Callable = init.zeros_()):
|
|
|
|
|
|
|
|
embed = VanillaPatchEmbedding(img_size,
|
|
|
|
|
|
|
|
patch_size,
|
|
|
|
|
|
|
|
in_chans,
|
|
|
|
|
|
|
|
embed_size,
|
|
|
|
|
|
|
|
dtype=dtype,
|
|
|
|
|
|
|
|
flatten=flatten,
|
|
|
|
|
|
|
|
weight_initializer=weight_initializer,
|
|
|
|
|
|
|
|
bias_initializer=bias_initializer,
|
|
|
|
|
|
|
|
position_embed_initializer=position_embed_initializer)
|
|
|
|
|
|
|
|
super().__init__(embed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
|
|
|
|
|
|
|
local_state = OrderedDict()
|
|
|
|
|
|
|
|
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
for key in param_keys:
|
|
|
|
|
|
|
|
param = state_dict.pop(key, None)
|
|
|
|
|
|
|
|
if param is not None:
|
|
|
|
|
|
|
|
local_state[key] = param
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
|
|
|
|
|
|
|
|
super()._load_from_state_dict(local_state, prefix, *args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
|
|
|
|
|
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
|
|
|
|
|
|
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
|
|
|