[model checkpoint] updated saving/loading for 1d layers (#594)

pull/625/head
アマデウス 2022-04-01 16:51:52 +08:00 committed by GitHub
parent 7636d518e1
commit c50bfb807b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 371 additions and 30 deletions

View File

@ -1,7 +1,7 @@
from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
VocabParallelClassifier1D, VocabParallelEmbedding1D)
from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
__all__ = [
'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
'VocabParallelEmbedding1D'
'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
]

View File

@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import math
from collections import OrderedDict
from typing import Callable, Tuple
import torch
@ -10,20 +11,25 @@ from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
@LAYERS.register_module
class Linear1D(torch.nn.Module):
class Linear1D(ColossalaiModule):
r"""Linear layer for 1D parallelism.
Args:
@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
parallel_input = get_parallel_input()
if not parallel_input:
self.layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
self.layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
super().__init__(layer)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
@LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
def forward(self, input_: Tensor) -> Tensor:
return self.layer(input_)
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
@LAYERS.register_module
@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
else:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output
@LAYERS.register_module
class PatchEmbedding1D(ColossalaiModule):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: torch.dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
embed = VanillaPatchEmbedding(img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer)
super().__init__(embed)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
for key in param_keys:
param = state_dict.pop(key, None)
if param is not None:
local_state[key] = param
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)