mirror of https://github.com/hpcaitech/ColossalAI

[model checkpoint] updated saving/loading for 1d layers (#594)
parent 7636d518e1
commit c50bfb807b
@@ -1,7 +1,7 @@
-from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
-                     VocabParallelClassifier1D, VocabParallelEmbedding1D)
+from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
+                     PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
 
 __all__ = [
     'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
-    'VocabParallelEmbedding1D'
+    'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
 ]
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import math
+from collections import OrderedDict
 from typing import Callable, Tuple
 
 import torch
@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.kernel import LayerNorm
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
+                                            partition_tensor_parallel_state_dict)
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
+from ..vanilla import VanillaPatchEmbedding
 
 from ..base_layer import ParallelLayer
+from ..colossalai_layer._utils import ColossalaiModule
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
 from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
                      split_forward_gather_backward)
 
 
 @LAYERS.register_module
-class Linear1D(torch.nn.Module):
+class Linear1D(ColossalaiModule):
     r"""Linear layer for 1D parallelism.
 
     Args:
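Reviewer note on the new imports: the three helpers from colossalai.utils.checkpointing carry all of the checkpoint logic added below. Reading from their use in this diff, partition_tensor_parallel_state_dict slices a full (unsharded) checkpoint into the shard each rank should load, gather_tensor_parallel_state_dict reassembles per-rank shards into a full state dict at save time, and broadcast_state_dict copies rank 0's entries to every rank for replicated parameters. The toy, single-process sketch below is not ColossalAI's implementation; it only illustrates the dims / partition_states contract that every layer in this file passes to those helpers (all function names in the sketch are made up).

# Toy sketch of the dims / partition_states contract used by the new
# _load_from_state_dict / _save_to_state_dict methods. NOT ColossalAI's
# implementation; a single process stands in for `tp_size` tensor-parallel ranks.
from collections import OrderedDict

import torch


def partition_state_dict_sketch(full_state, dims, partition_states, rank, tp_size):
    """Return the shard of `full_state` that `rank` would load."""
    local_state = OrderedDict()
    for key, tensor in full_state.items():
        if partition_states.get(key, False):
            # split along the layer-specific dimension (e.g. -1 or 0)
            local_state[key] = torch.chunk(tensor, tp_size, dim=dims[key])[rank]
        else:
            # replicated parameter: every rank keeps the full tensor
            local_state[key] = tensor
    return local_state


def gather_state_dict_sketch(shards, dims, partition_states):
    """Reassemble per-rank shards back into a full state dict."""
    full_state = OrderedDict()
    for key in shards[0]:
        if partition_states.get(key, False):
            full_state[key] = torch.cat([s[key] for s in shards], dim=dims[key])
        else:
            full_state[key] = shards[0][key]
    return full_state


if __name__ == '__main__':
    tp_size = 4
    full = OrderedDict(weight=torch.randn(8, 16), bias=torch.randn(8))
    dims = {'weight': -1, 'bias': 0}
    partition_states = {'weight': True, 'bias': False}
    shards = [partition_state_dict_sketch(full, dims, partition_states, r, tp_size) for r in range(tp_size)]
    restored = gather_state_dict_sketch(shards, dims, partition_states)
    assert torch.equal(restored['weight'], full['weight'])
    assert torch.equal(restored['bias'], full['bias'])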
@@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
                  skip_bias_add: bool = False,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        super().__init__()
         parallel_input = get_parallel_input()
         if not parallel_input:
-            self.layer = Linear1D_Col(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      gather_output=gather_output,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
+            layer = Linear1D_Col(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 gather_output=gather_output,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
         else:
-            self.layer = Linear1D_Row(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      parallel_input=parallel_input,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
+            layer = Linear1D_Row(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 parallel_input=parallel_input,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
+        super().__init__(layer)
 
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, input_: Tensor) -> Tensor:
-        return self.layer(input_)
+
+@LAYERS.register_module
+class LayerNorm1D(ColossalaiModule):
+    r"""
+    Layer Normalization for colossalai
+
+    :param normalized_shape: input shape from an expected input
+        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :type normalized_shape: int
+    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
+    :type eps: float, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
+        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+        super().__init__(norm)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)
 
 
 @LAYERS.register_module
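LayerNorm1D wraps the LayerNorm kernel and keeps its weight and bias replicated across the 1D tensor-parallel group, so its checkpoint path is simply: only tensor-parallel rank 0 contributes keys when saving, and broadcast_state_dict hands rank 0's copy to everyone when loading. A minimal single-process sketch of that pattern, with made-up helper names (the real code uses a collective over the PARALLEL_1D group):

# Toy sketch of the replicated-parameter checkpoint pattern used by LayerNorm1D:
# rank 0 owns the saved copy, every rank receives an identical copy on load.
from collections import OrderedDict

import torch


def save_replicated_sketch(module_state, rank):
    # only tensor-parallel rank 0 contributes to the checkpoint
    return OrderedDict(module_state) if rank == 0 else OrderedDict()


def broadcast_sketch(rank0_state, world_size):
    # stand-in for a broadcast: every rank gets rank 0's entries
    return [OrderedDict(rank0_state) for _ in range(world_size)]


if __name__ == '__main__':
    state = OrderedDict(weight=torch.ones(64), bias=torch.zeros(64))
    saved = save_replicated_sketch(state, rank=0)        # written to disk once
    per_rank = broadcast_sketch(saved, world_size=4)     # what each rank loads
    assert all(torch.equal(s['weight'], state['weight']) for s in per_rank)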
@@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
             num_partition = gpc.get_world_size(ParallelMode.TENSOR)
             set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: -1,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: False
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: -1,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: False
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
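For Classifier1D the weight has shape (num_classes, in_features) and is sharded along its last dimension (the hidden size), while the bias of length num_classes is kept replicated, which is why the hunk passes dims={weight: -1, bias: 0} and partition_states={weight: True, bias: False}. A quick shape check with made-up sizes:

# Shape check for the Classifier1D mapping: weight split on dim -1, bias replicated.
# Sizes are made up for illustration.
import torch

tp_size, num_classes, hidden = 4, 1000, 768
full_weight = torch.randn(num_classes, hidden)
shards = torch.chunk(full_weight, tp_size, dim=-1)            # each shard is (1000, 192)
assert shards[0].shape == (num_classes, hidden // tp_size)
assert torch.equal(torch.cat(shards, dim=-1), full_weight)    # gathering restores the full weight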
@@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: 0,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: True
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: 0,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: True
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
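VocabParallelClassifier1D shards the class (vocabulary) dimension instead, so both the (num_classes, hidden) weight and the (num_classes,) bias are split along dim 0 and both are marked partition_states=True. Illustrative shapes only:

# Shape check for VocabParallelClassifier1D: weight AND bias split on dim 0.
# Sizes are made up for illustration.
import torch

tp_size, vocab, hidden = 4, 50304, 768
weight, bias = torch.randn(vocab, hidden), torch.randn(vocab)
w_shards = torch.chunk(weight, tp_size, dim=0)   # each (12576, 768)
b_shards = torch.chunk(bias, tp_size, dim=0)     # each (12576,)
assert w_shards[0].shape[0] == vocab // tp_size
assert b_shards[0].shape[0] == vocab // tp_size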
@@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: 0,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: True
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: 0,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: True
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
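Linear1D_Col already stores its local weight as (out_features / tp_size, in_features), so a full checkpoint weight is split along dim 0 when loading, and the bias is split the same way. A small sketch with assumed sizes:

# Column-parallel linear: the full checkpoint weight (out, in) is split on dim 0,
# and each rank also keeps only its slice of the bias. Sizes are made up.
import torch

tp_size, in_features, out_features = 2, 512, 2048
full_w, full_b = torch.randn(out_features, in_features), torch.randn(out_features)
local_w = torch.chunk(full_w, tp_size, dim=0)[0]   # (1024, 512) on rank 0
local_b = torch.chunk(full_b, tp_size, dim=0)[0]   # (1024,) on rank 0
assert local_w.shape == (out_features // tp_size, in_features)
assert local_b.shape == (out_features // tp_size,)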
@@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={
+                                                               weight_key: -1,
+                                                               bias_key: 0
+                                                           },
+                                                           partition_states={
+                                                               weight_key: True,
+                                                               bias_key: False
+                                                           })
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={
+                                                            weight_key: -1,
+                                                            bias_key: 0
+                                                        },
+                                                        partition_states={
+                                                            weight_key: True,
+                                                            bias_key: False
+                                                        },
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
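Linear1D_Row is the mirror case: the local weight is (out_features, in_features / tp_size), so the full checkpoint weight is split along its last (input) dimension while the bias stays replicated, the same dims / partition_states mapping as Classifier1D above. Sketch with assumed sizes:

# Row-parallel linear: split the full weight on its last (input) dimension,
# keep the full bias on every rank. Sizes are made up.
import torch

tp_size, in_features, out_features = 2, 512, 2048
full_w, full_b = torch.randn(out_features, in_features), torch.randn(out_features)
local_w = torch.chunk(full_w, tp_size, dim=-1)[1]   # (2048, 256) on rank 1
local_b = full_b                                    # replicated, never chunked
assert local_w.shape == (out_features, in_features // tp_size)
assert local_b.shape == (out_features,)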
@@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: -1},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: -1},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
 
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
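Embedding1D keeps every vocabulary row on every rank and shards only the embedding dimension, so the single weight key is partitioned along dim -1. Sketch with assumed sizes:

# Embedding1D: the full table (vocab, embed_dim) is sharded along the embedding dim,
# so every rank still holds every vocabulary row. Sizes are made up.
import torch

tp_size, vocab, embed_dim = 4, 30522, 768
table = torch.randn(vocab, embed_dim)
local = torch.chunk(table, tp_size, dim=-1)[0]   # (30522, 192): all rows, fewer columns
assert local.shape == (vocab, embed_dim // tp_size)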
@@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
 
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
-            self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: 0},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: 0},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
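VocabParallelEmbedding1D shards vocabulary rows (dim 0) instead; each rank owns rows in [vocab_start_index, vocab_end_index) and masks out-of-range token ids in forward, which is also why _fill_padding_idx_with_zero above only zeroes the padding row on the rank whose slice contains it. A toy sketch of that row ownership (sizes made up):

# VocabParallelEmbedding1D: rows are sharded, so the padding row is zeroed only on
# the rank whose [start, end) slice contains padding_idx. Sizes are made up.
import torch

tp_size, vocab, embed_dim, padding_idx = 4, 1024, 64, 3
per_rank = vocab // tp_size
for rank in range(tp_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    local = torch.randn(per_rank, embed_dim)
    if start <= padding_idx < end:          # true only for rank 0 here
        with torch.no_grad():
            local[padding_idx - start].fill_(0)
        assert torch.all(local[padding_idx - start] == 0)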
@@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
         else:
             output = F.dropout(input_, self.p, self.training, self.inplace)
         return output
+
+
+@LAYERS.register_module
+class PatchEmbedding1D(ColossalaiModule):
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
+    :type img_size: int
+    :param patch_size: patch size
+    :type patch_size: int
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param flatten: whether to flatten output tensor, defaults to True
+    :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
+    """
+
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 embed_size: int,
+                 dtype: torch.dtype = None,
+                 flatten: bool = True,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 position_embed_initializer: Callable = init.zeros_()):
+        embed = VanillaPatchEmbedding(img_size,
+                                      patch_size,
+                                      in_chans,
+                                      embed_size,
+                                      dtype=dtype,
+                                      flatten=flatten,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer,
+                                      position_embed_initializer=position_embed_initializer)
+        super().__init__(embed)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            for key in param_keys:
+                param = state_dict.pop(key, None)
+                if param is not None:
+                    local_state[key] = param
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)
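PatchEmbedding1D wraps VanillaPatchEmbedding, so its four parameters (weight, bias, cls_token, pos_embed) are replicated across the 1D group and reuse the same rank-0-save / broadcast-load pattern as LayerNorm1D above; only the key list is longer. A minimal sketch of the key collection step, with a made-up prefix and shapes:

# Toy sketch of PatchEmbedding1D's replicated checkpoint keys.
# The prefix and shapes below are made up for illustration.
from collections import OrderedDict

import torch

prefix = 'patch_embed.'
state = OrderedDict({
    prefix + 'weight': torch.randn(768, 3, 16, 16),
    prefix + 'bias': torch.randn(768),
    prefix + 'cls_token': torch.randn(1, 1, 768),
    prefix + 'pos_embed': torch.randn(1, 197, 768),
})
param_keys = [prefix + k for k in ('weight', 'bias', 'cls_token', 'pos_embed')]
local_state = OrderedDict()
for key in param_keys:                        # what rank 0 pops out of the incoming state_dict
    param = state.pop(key, None)
    if param is not None:
        local_state[key] = param
assert not state and len(local_state) == 4    # all keys collected, nothing is sharded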