mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] reworked unified layers for ease of save/load states (#593)
parent
acae68eb04
commit
cd13b63832
|
@ -25,3 +25,11 @@ class ParallelLayer(nn.Module):
|
||||||
ParallelMode.PIPELINE)
|
ParallelMode.PIPELINE)
|
||||||
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
|
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
|
||||||
ParallelMode.PIPELINE)
|
ParallelMode.PIPELINE)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||||
|
error_msgs):
|
||||||
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||||
|
error_msgs)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||||
|
missing_keys.clear()
|
||||||
|
unexpected_keys.clear()
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from ..parallel_2d._operation import split_tensor_2d
|
from ..parallel_2d._operation import split_tensor_2d
|
||||||
|
@ -17,3 +18,21 @@ def partition_batch(input_) -> Tensor:
|
||||||
return _parallel_split_batch[tensor_parallel_mode](input_)
|
return _parallel_split_batch[tensor_parallel_mode](input_)
|
||||||
else:
|
else:
|
||||||
return input_
|
return input_
|
||||||
|
|
||||||
|
|
||||||
|
class ColossalaiModule(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
# copy values
|
||||||
|
self.__dict__ = module.__dict__.copy()
|
||||||
|
# copy methods
|
||||||
|
for name, attr in module.__class__.__dict__.items():
|
||||||
|
if name not in ['__init__', 'forward'] and callable(attr):
|
||||||
|
setattr(self, name, getattr(module, name))
|
||||||
|
self._forward_func = module.forward
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
def forward(self, *args):
|
||||||
|
return self._forward_func(*args)
|
||||||
|
|
|
@ -3,9 +3,10 @@ from colossalai.context import ParallelMode, seed
|
||||||
|
|
||||||
from ..parallel_1d import *
|
from ..parallel_1d import *
|
||||||
from ..utils import get_tensor_parallel_mode
|
from ..utils import get_tensor_parallel_mode
|
||||||
|
from ._utils import ColossalaiModule
|
||||||
|
|
||||||
|
|
||||||
class Dropout(nn.Module):
|
class Dropout(ColossalaiModule):
|
||||||
"""Dropout layer of colossalai.
|
"""Dropout layer of colossalai.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -13,16 +14,16 @@ class Dropout(nn.Module):
|
||||||
inplace (bool, optional): whether to do dropout in-place, default to be False.
|
inplace (bool, optional): whether to do dropout in-place, default to be False.
|
||||||
"""
|
"""
|
||||||
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
|
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
|
||||||
super().__init__()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
self.tensor_parallel = get_tensor_parallel_mode()
|
if tensor_parallel == "1d":
|
||||||
if self.tensor_parallel == '1d':
|
drop = Dropout1D(p, inplace)
|
||||||
self.drop = Dropout1D(p, inplace)
|
|
||||||
else:
|
else:
|
||||||
self.drop = nn.Dropout(p, inplace)
|
drop = nn.Dropout(p, inplace)
|
||||||
|
super().__init__(drop, tensor_parallel=tensor_parallel)
|
||||||
|
|
||||||
def forward(self, *args):
|
def forward(self, *args):
|
||||||
if self.tensor_parallel in [None, '1d']:
|
if self.tensor_parallel in [None, '1d']:
|
||||||
return self.drop(*args)
|
return self._forward_func(*args)
|
||||||
else:
|
else:
|
||||||
with seed(ParallelMode.TENSOR):
|
with seed(ParallelMode.TENSOR):
|
||||||
return self.drop(*args)
|
return self._forward_func(*args)
|
||||||
|
|
|
@ -5,14 +5,16 @@ from colossalai.utils import get_current_device
|
||||||
from torch import dtype, nn
|
from torch import dtype, nn
|
||||||
|
|
||||||
from ... import init as init
|
from ... import init as init
|
||||||
from ..parallel_1d import *
|
from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
|
||||||
from ..parallel_2d import *
|
from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
|
||||||
from ..parallel_2p5d import *
|
from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
|
||||||
from ..parallel_3d import *
|
from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
|
||||||
from ..utils import get_tensor_parallel_mode
|
from ..utils import get_tensor_parallel_mode
|
||||||
from ..vanilla import *
|
from ..vanilla import VanillaPatchEmbedding
|
||||||
|
from ._utils import ColossalaiModule
|
||||||
|
|
||||||
_parallel_embedding = {
|
_parallel_embedding = {
|
||||||
|
'1d': Embedding1D,
|
||||||
'2d': Embedding2D,
|
'2d': Embedding2D,
|
||||||
'2.5d': Embedding2p5D,
|
'2.5d': Embedding2p5D,
|
||||||
'3d': Embedding3D,
|
'3d': Embedding3D,
|
||||||
|
@ -27,14 +29,14 @@ _vocab_parallel_embedding = {
|
||||||
|
|
||||||
_parallel_patchembedding = {
|
_parallel_patchembedding = {
|
||||||
None: VanillaPatchEmbedding,
|
None: VanillaPatchEmbedding,
|
||||||
'1d': VanillaPatchEmbedding,
|
'1d': PatchEmbedding1D,
|
||||||
'2d': PatchEmbedding2D,
|
'2d': PatchEmbedding2D,
|
||||||
'2.5d': PatchEmbedding2p5D,
|
'2.5d': PatchEmbedding2p5D,
|
||||||
'3d': PatchEmbedding3D
|
'3d': PatchEmbedding3D
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
class Embedding(ColossalaiModule):
|
||||||
r"""Embedding for colossalai.
|
r"""Embedding for colossalai.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -73,14 +75,13 @@ class Embedding(nn.Module):
|
||||||
vocab_parallel_limit: int = 2048,
|
vocab_parallel_limit: int = 2048,
|
||||||
*args,
|
*args,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
|
if tensor_parallel is None:
|
||||||
self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
|
embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
|
||||||
**kwargs).to(dtype).to(get_current_device())
|
**kwargs).to(dtype).to(get_current_device())
|
||||||
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||||
elif num_embeddings <= vocab_parallel_limit:
|
elif num_embeddings <= vocab_parallel_limit:
|
||||||
self.embed = _parallel_embedding[tensor_parallel](
|
embed = _parallel_embedding[tensor_parallel](
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
padding_idx=padding_idx,
|
padding_idx=padding_idx,
|
||||||
|
@ -90,7 +91,7 @@ class Embedding(nn.Module):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.embed = _vocab_parallel_embedding[tensor_parallel](
|
embed = _vocab_parallel_embedding[tensor_parallel](
|
||||||
num_embeddings,
|
num_embeddings,
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
padding_idx=padding_idx,
|
padding_idx=padding_idx,
|
||||||
|
@ -99,16 +100,10 @@ class Embedding(nn.Module):
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
super().__init__(embed)
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.embed.weight
|
|
||||||
|
|
||||||
def forward(self, *args):
|
|
||||||
return self.embed(*args)
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbedding(nn.Module):
|
class PatchEmbedding(ColossalaiModule):
|
||||||
"""2D Image to Patch Embedding.
|
"""2D Image to Patch Embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -141,9 +136,8 @@ class PatchEmbedding(nn.Module):
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
position_embed_initializer: Callable = init.zeros_()
|
position_embed_initializer: Callable = init.zeros_()
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
self.embed = _parallel_patchembedding[tensor_parallel](
|
embed = _parallel_patchembedding[tensor_parallel](
|
||||||
img_size,
|
img_size,
|
||||||
patch_size,
|
patch_size,
|
||||||
in_chans,
|
in_chans,
|
||||||
|
@ -154,22 +148,4 @@ class PatchEmbedding(nn.Module):
|
||||||
bias_initializer=bias_initializer,
|
bias_initializer=bias_initializer,
|
||||||
position_embed_initializer=position_embed_initializer,
|
position_embed_initializer=position_embed_initializer,
|
||||||
)
|
)
|
||||||
|
super().__init__(embed)
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.embed.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bias(self):
|
|
||||||
return self.embed.bias
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pos_embed(self):
|
|
||||||
return self.embed.pos_embed
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cls_token(self):
|
|
||||||
return self.embed.cls_token
|
|
||||||
|
|
||||||
def forward(self, *args):
|
|
||||||
return self.embed(*args)
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from ..parallel_2p5d import *
|
||||||
from ..parallel_3d import *
|
from ..parallel_3d import *
|
||||||
from ..utils import get_tensor_parallel_mode
|
from ..utils import get_tensor_parallel_mode
|
||||||
from ..vanilla import *
|
from ..vanilla import *
|
||||||
|
from ._utils import ColossalaiModule
|
||||||
|
|
||||||
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||||
|
|
||||||
|
@ -31,7 +32,7 @@ _vocab_parallel_classifier = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(ColossalaiModule):
|
||||||
"""Linear layer of colossalai.
|
"""Linear layer of colossalai.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -71,41 +72,30 @@ class Linear(nn.Module):
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__()
|
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
if tensor_parallel is None:
|
if tensor_parallel is None:
|
||||||
self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
|
layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
|
||||||
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
|
weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
|
||||||
if self.layer.bias is not None:
|
if layer.bias is not None:
|
||||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
bias_initializer(layer.bias, fan_in=in_features)
|
||||||
else:
|
else:
|
||||||
linear_cls = _parallel_linear[tensor_parallel]
|
linear_cls = _parallel_linear[tensor_parallel]
|
||||||
gather_output = kwargs.pop('gather_output', None)
|
gather_output = kwargs.pop('gather_output', None)
|
||||||
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
|
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
|
||||||
kwargs['gather_output'] = gather_output
|
kwargs['gather_output'] = gather_output
|
||||||
self.layer = linear_cls(
|
layer = linear_cls(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
weight_initializer=weight_initializer,
|
weight_initializer=weight_initializer,
|
||||||
bias_initializer=bias_initializer,
|
bias_initializer=bias_initializer,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
super().__init__(layer)
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.layer.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bias(self):
|
|
||||||
return self.layer.bias
|
|
||||||
|
|
||||||
def forward(self, *args):
|
|
||||||
return self.layer(*args)
|
|
||||||
|
|
||||||
|
|
||||||
class Classifier(nn.Module):
|
class Classifier(ColossalaiModule):
|
||||||
"""Classifier layer of colossalai.
|
"""Classifier layer of colossalai.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -132,10 +122,9 @@ class Classifier(nn.Module):
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
vocab_parallel_limit: int = 2048) -> None:
|
vocab_parallel_limit: int = 2048) -> None:
|
||||||
super().__init__()
|
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
if num_classes <= vocab_parallel_limit or tensor_parallel is None:
|
if num_classes <= vocab_parallel_limit or tensor_parallel is None:
|
||||||
self.layer = _parallel_classifier[tensor_parallel](
|
layer = _parallel_classifier[tensor_parallel](
|
||||||
in_features,
|
in_features,
|
||||||
num_classes,
|
num_classes,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
|
@ -145,7 +134,7 @@ class Classifier(nn.Module):
|
||||||
bias_initializer=bias_initializer,
|
bias_initializer=bias_initializer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.layer = _vocab_parallel_classifier[tensor_parallel](
|
layer = _vocab_parallel_classifier[tensor_parallel](
|
||||||
in_features,
|
in_features,
|
||||||
num_classes,
|
num_classes,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
|
@ -154,14 +143,4 @@ class Classifier(nn.Module):
|
||||||
weight_initializer=weight_initializer,
|
weight_initializer=weight_initializer,
|
||||||
bias_initializer=bias_initializer,
|
bias_initializer=bias_initializer,
|
||||||
)
|
)
|
||||||
|
super().__init__(layer)
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.layer.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bias(self):
|
|
||||||
return self.layer.bias
|
|
||||||
|
|
||||||
def forward(self, *args):
|
|
||||||
return self.layer(*args)
|
|
||||||
|
|
|
@ -1,24 +1,17 @@
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from colossalai import kernel
|
|
||||||
|
|
||||||
from ... import init as init
|
from ..parallel_1d import LayerNorm1D
|
||||||
from ..parallel_1d import *
|
from ..parallel_2d import LayerNorm2D
|
||||||
from ..parallel_2d import *
|
from ..parallel_2p5d import LayerNorm2p5D
|
||||||
from ..parallel_2p5d import *
|
from ..parallel_3d import LayerNorm3D
|
||||||
from ..parallel_3d import *
|
|
||||||
from ..utils import get_tensor_parallel_mode
|
from ..utils import get_tensor_parallel_mode
|
||||||
from ..vanilla import *
|
from ._utils import ColossalaiModule
|
||||||
|
|
||||||
_parallel_layernorm = {
|
_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||||
'1d': kernel.LayerNorm,
|
|
||||||
'2d': LayerNorm2D,
|
|
||||||
'2.5d': LayerNorm2p5D,
|
|
||||||
'3d': LayerNorm3D
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(ColossalaiModule):
|
||||||
r"""Layer Normalization for colossalai.
|
r"""Layer Normalization for colossalai.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -31,20 +24,9 @@ class LayerNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
|
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
|
||||||
super().__init__()
|
|
||||||
tensor_parallel = get_tensor_parallel_mode()
|
tensor_parallel = get_tensor_parallel_mode()
|
||||||
if tensor_parallel is None:
|
if tensor_parallel is None:
|
||||||
self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
|
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
|
||||||
else:
|
else:
|
||||||
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||||
|
super().__init__(norm)
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
return self.norm.weight
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bias(self):
|
|
||||||
return self.norm.bias
|
|
||||||
|
|
||||||
def forward(self, *args):
|
|
||||||
return self.norm(*args)
|
|
||||||
|
|
Loading…
Reference in New Issue