@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-

import math
from collections import OrderedDict
from typing import Callable, Tuple

import torch

@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
                                            partition_tensor_parallel_state_dict)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding

from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
                     split_forward_gather_backward)


@LAYERS.register_module
class Linear1D(torch.nn.Module):
class Linear1D(ColossalaiModule):
    r"""Linear layer for 1D parallelism.

    Args:

@@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()
        parallel_input = get_parallel_input()
        if not parallel_input:
            self.layer = Linear1D_Col(in_features,
                                      out_features,
                                      bias=bias,
                                      dtype=dtype,
                                      gather_output=gather_output,
                                      skip_bias_add=skip_bias_add,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer)
            layer = Linear1D_Col(in_features,
                                 out_features,
                                 bias=bias,
                                 dtype=dtype,
                                 gather_output=gather_output,
                                 skip_bias_add=skip_bias_add,
                                 weight_initializer=weight_initializer,
                                 bias_initializer=bias_initializer)
        else:
            self.layer = Linear1D_Row(in_features,
                                      out_features,
                                      bias=bias,
                                      dtype=dtype,
                                      parallel_input=parallel_input,
                                      skip_bias_add=skip_bias_add,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer)
            layer = Linear1D_Row(in_features,
                                 out_features,
                                 bias=bias,
                                 dtype=dtype,
                                 parallel_input=parallel_input,
                                 skip_bias_add=skip_bias_add,
                                 weight_initializer=weight_initializer,
                                 bias_initializer=bias_initializer)
        super().__init__(layer)

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias
@LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
    r"""
    Layer Normalization for colossalai

    :param normalized_shape: input shape from an expected input
        of size :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
        \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def forward(self, input_: Tensor) -> Tensor:
        return self.layer(input_)
    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
        super().__init__(norm)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias

        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            super()._save_to_state_dict(destination, prefix, keep_vars)
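
# A minimal illustrative sketch of the pattern used above: LayerNorm parameters are
# replicated across the 1D tensor-parallel group, so the hooks load the full tensors
# on tensor-parallel rank 0 only and broadcast them to the other ranks (and save from
# rank 0 only). Stand-alone analogue, assuming an already-initialized
# torch.distributed process group; `broadcast_replicated_state` is a hypothetical helper.
import torch.distributed as dist


def broadcast_replicated_state(local_state: dict, src: int = 0) -> dict:
    """Hand rank ``src``'s (possibly empty) state dict to every rank in the group."""
    obj = [local_state if dist.get_rank() == src else None]
    dist.broadcast_object_list(obj, src=src)
    return obj[0]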

@LAYERS.register_module

@@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: -1,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: False
                                                           })
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: -1,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: False
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
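
# A minimal illustrative sketch of what the dims/partition_states above mean for
# Classifier1D: the full weight of shape (num_classes, in_features) is chunked along
# its last dimension so each 1D tensor-parallel rank keeps only its slice, while the
# bias is replicated. Plain-torch analogue (hypothetical helper, not ColossalAI API):
import torch


def partition_classifier_state(full_weight: torch.Tensor, full_bias: torch.Tensor,
                               rank: int, world_size: int):
    local_weight = full_weight.chunk(world_size, dim=-1)[rank]  # split in_features
    local_bias = full_bias                                      # replicated, never split
    return local_weight, local_bias


# Example: 8 classes, in_features=16, 4 tensor-parallel ranks -> each rank holds (8, 4).
w, b = torch.randn(8, 16), torch.randn(8)
local_w, local_b = partition_classifier_state(w, b, rank=1, world_size=4)
assert local_w.shape == (8, 4) and local_b.shape == (8,)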

@@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: 0,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: True
                                                           })
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: 0,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: True
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(

@@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: 0,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: True
                                                           })
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: 0,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: True
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(

@@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: -1,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: False
                                                           })
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: -1,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: False
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
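
# A minimal illustrative sketch of the two split layouts encoded by the checkpoint
# hooks: Linear1D_Col keeps weight of shape (out_features // world_size, in_features),
# so weight and bias are both partitioned along dim 0; Linear1D_Row keeps
# (out_features, in_features // world_size), so weight is partitioned along dim -1 and
# the bias is replicated (it is added once after the row-parallel reduction).
# Plain-torch analogue (hypothetical helper, not ColossalAI API):
import torch


def shard_linear(weight: torch.Tensor, bias: torch.Tensor,
                 rank: int, world_size: int, row_parallel: bool):
    if row_parallel:   # split the input dimension, keep the full bias
        return weight.chunk(world_size, dim=-1)[rank], bias
    # column parallel: split the output dimension, and the bias with it
    return weight.chunk(world_size, dim=0)[rank], bias.chunk(world_size, dim=0)[rank]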

@@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={weight_key: -1},
                                                           partition_states={weight_key: True})
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={weight_key: -1},
                                                        partition_states={weight_key: True},
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:

        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

@@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None and \
                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={weight_key: 0},
                                                           partition_states={weight_key: True})
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={weight_key: 0},
                                                        partition_states={weight_key: True},
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
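
# A minimal illustrative sketch of the vocab-parallel lookup that this mask feeds into
# (the rest of forward is not shown in this hunk, so the steps below are an assumption
# about the usual pattern, in plain torch): shift ids into the local shard, zero the
# rows of out-of-range tokens, and let an all-reduce across ranks fill them in.
import torch
import torch.nn.functional as F


def local_vocab_embed(ids: torch.Tensor, local_table: torch.Tensor,
                      vocab_start: int, vocab_end: int) -> torch.Tensor:
    mask = (ids < vocab_start) | (ids >= vocab_end)
    local_ids = (ids - vocab_start).clamp_(0, local_table.size(0) - 1)
    out = F.embedding(local_ids, local_table)
    out[mask] = 0.0    # this rank contributes nothing for ids owned by other shards
    return out         # an all-reduce over the tensor-parallel group completes the lookup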

@@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
        else:
            output = F.dropout(input_, self.p, self.training, self.inplace)
        return output


@LAYERS.register_module
class PatchEmbedding1D(ColossalaiModule):
    """
    2D Image to Patch Embedding

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param in_chans: number of channels of input image
    :type in_chans: int
    :param embed_size: size of embedding
    :type embed_size: int
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
    :type weight_initializer: typing.Callable, optional
    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
    :type bias_initializer: typing.Callable, optional
    :param position_embed_initializer: The initializer of position embedding, defaults to zero
    :type position_embed_initializer: typing.Callable, optional
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 dtype: torch.dtype = None,
                 flatten: bool = True,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 position_embed_initializer: Callable = init.zeros_()):
        embed = VanillaPatchEmbedding(img_size,
                                      patch_size,
                                      in_chans,
                                      embed_size,
                                      dtype=dtype,
                                      flatten=flatten,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer,
                                      position_embed_initializer=position_embed_initializer)
        super().__init__(embed)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            for key in param_keys:
                param = state_dict.pop(key, None)
                if param is not None:
                    local_state[key] = param

        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
        super()._load_from_state_dict(local_state, prefix, *args)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            super()._save_to_state_dict(destination, prefix, keep_vars)
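
# A minimal illustrative sketch: PatchEmbedding1D just wraps VanillaPatchEmbedding and
# broadcasts its four parameters (weight, bias, cls_token, pos_embed) from tensor-parallel
# rank 0, since none of them are partitioned in 1D mode. The toy module below is an
# assumption about what such a patch embedding computes, not a copy of VanillaPatchEmbedding:
import torch
from torch import nn


class TinyPatchEmbed(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_size=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_size))

    def forward(self, x):                                  # x: (B, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)        # (B, num_patches, embed_size)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        return torch.cat([cls, x], dim=1) + self.pos_embed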