mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] updated saving/loading for 1d layers (#594)
parent
7636d518e1
commit
c50bfb807b
|
@ -1,7 +1,7 @@
|
||||||
from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
|
from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
|
||||||
VocabParallelClassifier1D, VocabParallelEmbedding1D)
|
PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
|
'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
|
||||||
'VocabParallelEmbedding1D'
|
'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,20 +11,25 @@ from colossalai.communication import broadcast
|
||||||
from colossalai.context import ParallelMode, seed
|
from colossalai.context import ParallelMode, seed
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.global_variables import tensor_parallel_env as env
|
from colossalai.global_variables import tensor_parallel_env as env
|
||||||
|
from colossalai.kernel import LayerNorm
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.registry import LAYERS
|
from colossalai.registry import LAYERS
|
||||||
|
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
|
||||||
|
partition_tensor_parallel_state_dict)
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
from ..vanilla import VanillaPatchEmbedding
|
||||||
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
|
from ..colossalai_layer._utils import ColossalaiModule
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
||||||
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
||||||
split_forward_gather_backward)
|
split_forward_gather_backward)
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
class Linear1D(torch.nn.Module):
|
class Linear1D(ColossalaiModule):
|
||||||
r"""Linear layer for 1D parallelism.
|
r"""Linear layer for 1D parallelism.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -52,10 +58,9 @@ class Linear1D(torch.nn.Module):
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||||
super().__init__()
|
|
||||||
parallel_input = get_parallel_input()
|
parallel_input = get_parallel_input()
|
||||||
if not parallel_input:
|
if not parallel_input:
|
||||||
self.layer = Linear1D_Col(in_features,
|
layer = Linear1D_Col(in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
@ -64,7 +69,7 @@ class Linear1D(torch.nn.Module):
|
||||||
weight_initializer=weight_initializer,
|
weight_initializer=weight_initializer,
|
||||||
bias_initializer=bias_initializer)
|
bias_initializer=bias_initializer)
|
||||||
else:
|
else:
|
||||||
self.layer = Linear1D_Row(in_features,
|
layer = Linear1D_Row(in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
@ -72,17 +77,50 @@ class Linear1D(torch.nn.Module):
|
||||||
skip_bias_add=skip_bias_add,
|
skip_bias_add=skip_bias_add,
|
||||||
weight_initializer=weight_initializer,
|
weight_initializer=weight_initializer,
|
||||||
bias_initializer=bias_initializer)
|
bias_initializer=bias_initializer)
|
||||||
|
super().__init__(layer)
|
||||||
|
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.layer.weight
|
|
||||||
|
|
||||||
@property
|
@LAYERS.register_module
|
||||||
def bias(self):
|
class LayerNorm1D(ColossalaiModule):
|
||||||
return self.layer.bias
|
r"""
|
||||||
|
Layer Normalization for colossalai
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
:param normalized_shape: input shape from an expected input
|
||||||
return self.layer(input_)
|
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||||
|
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||||
|
If a single integer is used, it is treated as a singleton list, and this module will
|
||||||
|
normalize over the last dimension which is expected to be of that specific size.
|
||||||
|
:type normalized_shape: int
|
||||||
|
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
|
||||||
|
:type eps: float, optional
|
||||||
|
:param dtype: The dtype of parameters, defaults to None
|
||||||
|
:type dtype: torch.dtype, optional
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
|
||||||
|
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||||
|
super().__init__(norm)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
@LAYERS.register_module
|
||||||
|
@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
|
||||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
if self.has_weight:
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict()
|
||||||
|
if self.has_weight:
|
||||||
|
local_state[weight_key] = self.weight
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
# Set up backprop all-reduce.
|
# Set up backprop all-reduce.
|
||||||
if self.parallel_input:
|
if self.parallel_input:
|
||||||
|
@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
if self.has_weight:
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict()
|
||||||
|
if self.has_weight:
|
||||||
|
local_state[weight_key] = self.weight
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||||
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||||
|
@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||||
|
@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
|
||||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
# Set up backprop all-reduce.
|
# Set up backprop all-reduce.
|
||||||
if self.parallel_input:
|
if self.parallel_input:
|
||||||
|
@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={weight_key: -1},
|
||||||
|
partition_states={weight_key: True})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={weight_key: -1},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
|
|
||||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||||
|
@ -598,6 +851,31 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
|
||||||
|
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True})
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||||
|
ParallelMode.PARALLEL_1D,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
# Build the mask.
|
# Build the mask.
|
||||||
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
||||||
|
@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
|
||||||
else:
|
else:
|
||||||
output = F.dropout(input_, self.p, self.training, self.inplace)
|
output = F.dropout(input_, self.p, self.training, self.inplace)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@LAYERS.register_module
|
||||||
|
class PatchEmbedding1D(ColossalaiModule):
|
||||||
|
"""
|
||||||
|
2D Image to Patch Embedding
|
||||||
|
|
||||||
|
:param img_size: image size
|
||||||
|
:type img_size: int
|
||||||
|
:param patch_size: patch size
|
||||||
|
:type patch_size: int
|
||||||
|
:param in_chans: number of channels of input image
|
||||||
|
:type in_chans: int
|
||||||
|
:param embed_size: size of embedding
|
||||||
|
:type embed_size: int
|
||||||
|
:param dtype: The dtype of parameters, defaults to None
|
||||||
|
:type dtype: torch.dtype, optional
|
||||||
|
:param flatten: whether to flatten output tensor, defaults to True
|
||||||
|
:type flatten: bool, optional
|
||||||
|
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
|
||||||
|
:type weight_initializer: typing.Callable, optional
|
||||||
|
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
|
||||||
|
:type bias_initializer: typing.Callable, optional
|
||||||
|
:param position_embed_initializer: The intializer of position embedding, defaults to zero
|
||||||
|
:type position_embed_initializer: typing.Callable, optional
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
img_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
in_chans: int,
|
||||||
|
embed_size: int,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
flatten: bool = True,
|
||||||
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
|
position_embed_initializer: Callable = init.zeros_()):
|
||||||
|
embed = VanillaPatchEmbedding(img_size,
|
||||||
|
patch_size,
|
||||||
|
in_chans,
|
||||||
|
embed_size,
|
||||||
|
dtype=dtype,
|
||||||
|
flatten=flatten,
|
||||||
|
weight_initializer=weight_initializer,
|
||||||
|
bias_initializer=bias_initializer,
|
||||||
|
position_embed_initializer=position_embed_initializer)
|
||||||
|
super().__init__(embed)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
for key in param_keys:
|
||||||
|
param = state_dict.pop(key, None)
|
||||||
|
if param is not None:
|
||||||
|
local_state[key] = param
|
||||||
|
|
||||||
|
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||||
|
|
Loading…
Reference in New Issue