[model checkpoint] updated saving/loading for 1d layers (#594)

アマデウス 3 years ago committed by GitHub
parent 7636d518e1
commit c50bfb807b

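The overrides added in this diff hook into PyTorch's regular state_dict() / load_state_dict() flow, so a 1D tensor-parallel layer can be checkpointed as full, unpartitioned tensors. A minimal usage sketch, assuming Colossal-AI has already been launched with a 1D tensor-parallel configuration; the import path, sizes, and file name are assumptions, not taken from this diff:

import torch
import torch.distributed as dist
from colossalai.nn.layer.parallel_1d import Linear1D_Col  # assumed export, see the __init__.py hunk below

layer = Linear1D_Col(1024, 4096)

# Saving: the save hook gathers every rank's shard, so the checkpoint holds
# full-size tensors and does not depend on the tensor-parallel layout.
state = layer.state_dict()
if dist.get_rank() == 0:
    torch.save(state, 'linear1d_col.pt')

# Loading: only TENSOR-rank 0 needs the full tensors; the load hook partitions
# them and hands the right slice to each rank in the PARALLEL_1D group.
ckpt = torch.load('linear1d_col.pt', map_location='cpu') if dist.get_rank() == 0 else {}
layer.load_state_dict(ckpt)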
@@ -1,7 +1,7 @@
from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
                     PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)

__all__ = [
    'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
    'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
]

@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import math
from collections import OrderedDict
from typing import Callable, Tuple
import torch
@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
                                            partition_tensor_parallel_state_dict)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
                     split_forward_gather_backward)
@LAYERS.register_module
class Linear1D(ColossalaiModule):
r"""Linear layer for 1D parallelism.
Args:
@@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
parallel_input = get_parallel_input()
if not parallel_input:
layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
super().__init__(layer)
@LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
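Unlike the sharded layers below, LayerNorm1D is replicated across the tensor-parallel group, so loading only broadcasts rank 0's copy and saving writes from rank 0 alone. A short usage sketch under that assumption (import path, sizes, and dtype are illustrative, not from this diff):

import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d import LayerNorm1D  # assumed export, see the __init__.py hunk above

norm = LayerNorm1D(1024, eps=1e-5, dtype=torch.float16)

# state_dict() returns the full parameters on TENSOR-rank 0 and nothing on the
# other ranks, so a single rank can write the checkpoint without duplication.
ckpt = norm.state_dict()
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
    torch.save(ckpt, 'norm1d.pt')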
@LAYERS.register_module
@@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
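The dims / partition_states mappings above say which checkpoint tensors are sharded and along which axis: for Classifier1D the weight is split along its last dimension while the bias stays replicated. A plain-PyTorch illustration of that mapping (shapes and world size are hypothetical; this is not the colossalai helper itself):

import torch

# Hypothetical full checkpoint for a Classifier1D with in_features=4096, num_classes=1000
full_weight = torch.randn(1000, 4096)   # dims: -1, partition_states: True  -> sharded
full_bias = torch.randn(1000)           # dims: 0,  partition_states: False -> replicated

world_size, rank = 4, 1                 # illustrative 1D tensor-parallel group
local_weight = full_weight.chunk(world_size, dim=-1)[rank]   # (1000, 1024) shard per rank
local_bias = full_bias                                       # full copy on every rank
# Saving runs the inverse: shards are gathered along dim -1 back into (1000, 4096).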
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
@@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
@@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
@@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
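For the two linear variants the mappings differ only in the shard axis and in whether the bias is sharded: Linear1D_Col splits weight and bias along dim 0 (the output dimension), while Linear1D_Row splits the weight along dim -1 (the input dimension) and keeps the bias whole. A brief plain-PyTorch illustration (shapes and world size are assumptions):

import torch

full_weight = torch.randn(4096, 1024)   # (out_features, in_features), illustrative
full_bias = torch.randn(4096)
world_size, rank = 4, 0

# Linear1D_Col: dims = {weight: 0, bias: 0}, both tensors sharded
col_weight = full_weight.chunk(world_size, dim=0)[rank]    # (1024, 1024)
col_bias = full_bias.chunk(world_size, dim=0)[rank]        # (1024,)

# Linear1D_Row: dims = {weight: -1, bias: 0}, only the weight is sharded
row_weight = full_weight.chunk(world_size, dim=-1)[rank]   # (4096, 256)
row_bias = full_bias                                       # replicated on every rank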
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
@@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
@@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
else:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output
@LAYERS.register_module
class PatchEmbedding1D(ColossalaiModule):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: torch.dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
embed = VanillaPatchEmbedding(img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer)
super().__init__(embed)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
for key in param_keys:
param = state_dict.pop(key, None)
if param is not None:
local_state[key] = param
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
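PatchEmbedding1D is handled like LayerNorm1D: all four parameters (weight, bias, cls_token, pos_embed) are replicated, so loading broadcasts rank 0's copy and only rank 0 saves. A usage sketch, assuming a launched 1D tensor-parallel context (import path, image and embedding sizes are illustrative):

import torch
from colossalai.nn.layer.parallel_1d import PatchEmbedding1D  # assumed export, see the __init__.py hunk above
from colossalai.utils.cuda import get_current_device

patch_embed = PatchEmbedding1D(img_size=224, patch_size=16, in_chans=3, embed_size=768)
images = torch.randn(8, 3, 224, 224, device=get_current_device())
tokens = patch_embed(images)   # expected (8, 197, 768) if the cls token is prepended and flatten=True

full_ckpt = patch_embed.state_dict()   # full weight/bias/cls_token/pos_embed on TENSOR-rank 0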
