From cd13b63832e044ba3eb20c9f60e0122f6aa85790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Fri, 1 Apr 2022 16:49:56 +0800 Subject: [PATCH] [model checkpoint] reworked unified layers for ease of save/load states (#593) --- colossalai/nn/layer/base_layer.py | 8 +++ .../nn/layer/colossalai_layer/_utils.py | 19 ++++++ .../nn/layer/colossalai_layer/dropout.py | 17 ++--- .../nn/layer/colossalai_layer/embedding.py | 62 ++++++------------- .../nn/layer/colossalai_layer/linear.py | 61 ++++++------------ .../layer/colossalai_layer/normalization.py | 38 +++--------- 6 files changed, 85 insertions(+), 120 deletions(-) diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index fd0d6ef5e..041f0fdf0 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -25,3 +25,11 @@ class ParallelLayer(nn.Module): ParallelMode.PIPELINE) self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( ParallelMode.PIPELINE) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + if gpc.get_local_rank(ParallelMode.TENSOR) != 0: + missing_keys.clear() + unexpected_keys.clear() diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py index 6271667cc..6f23def9c 100644 --- a/colossalai/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -1,3 +1,4 @@ +import torch.nn as nn from torch import Tensor from ..parallel_2d._operation import split_tensor_2d @@ -17,3 +18,21 @@ def partition_batch(input_) -> Tensor: return _parallel_split_batch[tensor_parallel_mode](input_) else: return input_ + + +class ColossalaiModule(nn.Module): + + def __init__(self, module: nn.Module, **kwargs): + super().__init__() + # copy values + self.__dict__ = module.__dict__.copy() + # copy methods + for name, attr in module.__class__.__dict__.items(): + if name not in ['__init__', 'forward'] and callable(attr): + setattr(self, name, getattr(module, name)) + self._forward_func = module.forward + for k, v in kwargs.items(): + setattr(self, k, v) + + def forward(self, *args): + return self._forward_func(*args) diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py index b08a5abb7..0df6698d5 100644 --- a/colossalai/nn/layer/colossalai_layer/dropout.py +++ b/colossalai/nn/layer/colossalai_layer/dropout.py @@ -3,9 +3,10 @@ from colossalai.context import ParallelMode, seed from ..parallel_1d import * from ..utils import get_tensor_parallel_mode +from ._utils import ColossalaiModule -class Dropout(nn.Module): +class Dropout(ColossalaiModule): """Dropout layer of colossalai. Args: @@ -13,16 +14,16 @@ class Dropout(nn.Module): inplace (bool, optional): whether to do dropout in-place, default to be False. 
""" def __init__(self, p: float = 0.5, inplace: bool = False) -> None: - super().__init__() - self.tensor_parallel = get_tensor_parallel_mode() - if self.tensor_parallel == '1d': - self.drop = Dropout1D(p, inplace) + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel == "1d": + drop = Dropout1D(p, inplace) else: - self.drop = nn.Dropout(p, inplace) + drop = nn.Dropout(p, inplace) + super().__init__(drop, tensor_parallel=tensor_parallel) def forward(self, *args): if self.tensor_parallel in [None, '1d']: - return self.drop(*args) + return self._forward_func(*args) else: with seed(ParallelMode.TENSOR): - return self.drop(*args) + return self._forward_func(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py index 32fd94d4c..e5c9c46e0 100644 --- a/colossalai/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -5,14 +5,16 @@ from colossalai.utils import get_current_device from torch import dtype, nn from ... import init as init -from ..parallel_1d import * -from ..parallel_2d import * -from ..parallel_2p5d import * -from ..parallel_3d import * +from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D +from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D +from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D +from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D from ..utils import get_tensor_parallel_mode -from ..vanilla import * +from ..vanilla import VanillaPatchEmbedding +from ._utils import ColossalaiModule _parallel_embedding = { + '1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D, @@ -27,14 +29,14 @@ _vocab_parallel_embedding = { _parallel_patchembedding = { None: VanillaPatchEmbedding, - '1d': VanillaPatchEmbedding, + '1d': PatchEmbedding1D, '2d': PatchEmbedding2D, '2.5d': PatchEmbedding2p5D, '3d': PatchEmbedding3D } -class Embedding(nn.Module): +class Embedding(ColossalaiModule): r"""Embedding for colossalai. 
Args: @@ -73,14 +75,13 @@ class Embedding(nn.Module): vocab_parallel_limit: int = 2048, *args, **kwargs) -> None: - super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit): - self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, - **kwargs).to(dtype).to(get_current_device()) - weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + if tensor_parallel is None: + embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, + **kwargs).to(dtype).to(get_current_device()) + weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: - self.embed = _parallel_embedding[tensor_parallel]( + embed = _parallel_embedding[tensor_parallel]( num_embeddings, embedding_dim, padding_idx=padding_idx, @@ -90,7 +91,7 @@ class Embedding(nn.Module): **kwargs, ) else: - self.embed = _vocab_parallel_embedding[tensor_parallel]( + embed = _vocab_parallel_embedding[tensor_parallel]( num_embeddings, embedding_dim, padding_idx=padding_idx, @@ -99,16 +100,10 @@ class Embedding(nn.Module): *args, **kwargs, ) - - @property - def weight(self): - return self.embed.weight - - def forward(self, *args): - return self.embed(*args) + super().__init__(embed) -class PatchEmbedding(nn.Module): +class PatchEmbedding(ColossalaiModule): """2D Image to Patch Embedding. Args: @@ -141,9 +136,8 @@ class PatchEmbedding(nn.Module): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_() ) -> None: - super().__init__() tensor_parallel = get_tensor_parallel_mode() - self.embed = _parallel_patchembedding[tensor_parallel]( + embed = _parallel_patchembedding[tensor_parallel]( img_size, patch_size, in_chans, @@ -154,22 +148,4 @@ class PatchEmbedding(nn.Module): bias_initializer=bias_initializer, position_embed_initializer=position_embed_initializer, ) - - @property - def weight(self): - return self.embed.weight - - @property - def bias(self): - return self.embed.bias - - @property - def pos_embed(self): - return self.embed.pos_embed - - @property - def cls_token(self): - return self.embed.cls_token - - def forward(self, *args): - return self.embed(*args) + super().__init__(embed) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py index f98156500..35e6a783c 100644 --- a/colossalai/nn/layer/colossalai_layer/linear.py +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -12,6 +12,7 @@ from ..parallel_2p5d import * from ..parallel_3d import * from ..utils import get_tensor_parallel_mode from ..vanilla import * +from ._utils import ColossalaiModule _parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} @@ -31,7 +32,7 @@ _vocab_parallel_classifier = { } -class Linear(nn.Module): +class Linear(ColossalaiModule): """Linear layer of colossalai. 
Args: @@ -71,41 +72,30 @@ class Linear(nn.Module): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), **kwargs) -> None: - super().__init__() tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device()) - weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features) - if self.layer.bias is not None: - bias_initializer(self.layer.bias, fan_in=in_features) + layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device()) + weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features) + if layer.bias is not None: + bias_initializer(layer.bias, fan_in=in_features) else: linear_cls = _parallel_linear[tensor_parallel] gather_output = kwargs.pop('gather_output', None) if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available kwargs['gather_output'] = gather_output - self.layer = linear_cls( - in_features, - out_features, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs, - ) - - @property - def weight(self): - return self.layer.weight - - @property - def bias(self): - return self.layer.bias - - def forward(self, *args): - return self.layer(*args) + layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) + super().__init__(layer) -class Classifier(nn.Module): +class Classifier(ColossalaiModule): """Classifier layer of colossalai. Args: @@ -132,10 +122,9 @@ class Classifier(nn.Module): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), vocab_parallel_limit: int = 2048) -> None: - super().__init__() tensor_parallel = get_tensor_parallel_mode() if num_classes <= vocab_parallel_limit or tensor_parallel is None: - self.layer = _parallel_classifier[tensor_parallel]( + layer = _parallel_classifier[tensor_parallel]( in_features, num_classes, weight=weight, @@ -145,7 +134,7 @@ class Classifier(nn.Module): bias_initializer=bias_initializer, ) else: - self.layer = _vocab_parallel_classifier[tensor_parallel]( + layer = _vocab_parallel_classifier[tensor_parallel]( in_features, num_classes, weight=weight, @@ -154,14 +143,4 @@ class Classifier(nn.Module): weight_initializer=weight_initializer, bias_initializer=bias_initializer, ) - - @property - def weight(self): - return self.layer.weight - - @property - def bias(self): - return self.layer.bias - - def forward(self, *args): - return self.layer(*args) + super().__init__(layer) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py index 2e147d9e2..4c6150bd5 100644 --- a/colossalai/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -1,24 +1,17 @@ from colossalai.utils import get_current_device from torch import nn -from colossalai import kernel -from ... 
import init as init -from ..parallel_1d import * -from ..parallel_2d import * -from ..parallel_2p5d import * -from ..parallel_3d import * +from ..parallel_1d import LayerNorm1D +from ..parallel_2d import LayerNorm2D +from ..parallel_2p5d import LayerNorm2p5D +from ..parallel_3d import LayerNorm3D from ..utils import get_tensor_parallel_mode -from ..vanilla import * +from ._utils import ColossalaiModule -_parallel_layernorm = { - '1d': kernel.LayerNorm, - '2d': LayerNorm2D, - '2.5d': LayerNorm2p5D, - '3d': LayerNorm3D -} +_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D} -class LayerNorm(nn.Module): +class LayerNorm(ColossalaiModule): r"""Layer Normalization for colossalai. Args: @@ -31,20 +24,9 @@ class LayerNorm(nn.Module): """ def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None: - super().__init__() tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) else: - self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) - - @property - def weight(self): - return self.norm.weight - - @property - def bias(self): - return self.norm.bias - - def forward(self, *args): - return self.norm(*args) + norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + super().__init__(norm)
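
The main structural change above is ColossalaiModule (colossalai/nn/layer/colossalai_layer/_utils.py): instead of holding the concrete implementation under a submodule attribute such as self.layer, self.embed or self.norm, which produced checkpoint keys like "layer.weight" and required a set of forwarding properties, the wrapper copies the wrapped module's __dict__ so that its parameters and buffers register directly on the wrapper. The following is a minimal standalone sketch of that idea, using illustrative class names rather than the real ColossalAI classes, to show how the state-dict keys flatten:

import torch.nn as nn


class WrappedAsSubmodule(nn.Module):
    """Old pattern: keep the real layer under self.layer -> keys like 'layer.weight'."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)


class WrappedByDictCopy(nn.Module):
    """New pattern: copy the wrapped module's __dict__ -> keys like 'weight'."""

    def __init__(self):
        super().__init__()
        module = nn.Linear(4, 4)
        # _parameters, _buffers and _modules are instance attributes of module,
        # so a shallow copy of its __dict__ registers them on the wrapper itself.
        self.__dict__ = module.__dict__.copy()
        self._forward_func = module.forward

    def forward(self, *args):
        return self._forward_func(*args)


print(list(WrappedAsSubmodule().state_dict().keys()))  # ['layer.weight', 'layer.bias']
print(list(WrappedByDictCopy().state_dict().keys()))   # ['weight', 'bias']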
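
A payoff of the flattened keys is checkpoint compatibility with the corresponding torch.nn modules: when no tensor parallelism is configured, the unified Linear, LayerNorm, Embedding and the other reworked layers now produce and accept the same state-dict keys as their vanilla counterparts, so no key remapping is needed in either direction. A hedged usage sketch, assuming colossalai has been launched without tensor parallelism and that colossalai.nn exposes the layers touched in this patch:

import torch
import colossalai.nn as col_nn

reference = torch.nn.Linear(256, 256)
torch.save(reference.state_dict(), 'linear.pt')   # keys: 'weight', 'bias'

layer = col_nn.Linear(256, 256)
# Before this patch the wrapper expected 'layer.weight' / 'layer.bias';
# with ColossalaiModule the vanilla checkpoint loads as-is.
layer.load_state_dict(torch.load('linear.pt'))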
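
The _load_from_state_dict override added to ParallelLayer (colossalai/nn/layer/base_layer.py) handles the reporting side of loading under tensor parallelism: every rank still runs the normal copy logic, but only local tensor-parallel rank 0 keeps its missing_keys/unexpected_keys lists, so strict loading does not raise duplicated or spurious key complaints on the other ranks of the group. The sketch below restates that hook on a standalone layer purely for clarity; it assumes an initialized ColossalAI parallel context and is not itself part of the patch:

import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class RankAwareLinear(nn.Linear):
    """Toy layer reusing the rank-aware loading hook from ParallelLayer."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
            # Only tensor-parallel rank 0 keeps its report; the other ranks stay
            # silent so load_state_dict(strict=True) does not fail on each rank.
            missing_keys.clear()
            unexpected_keys.clear()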
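
One detail visible in the Linear hunk above is how gather_output is forwarded: the selected parallel class is inspected with inspect.signature and the keyword is only passed through when that constructor actually declares it, so implementations that expose the option and those that do not share a single dispatch path. A generic, self-contained sketch of the pattern with placeholder classes (not the real parallel linear implementations):

import inspect


class LinearWithGather:
    """Stands in for a parallel linear that supports gather_output."""

    def __init__(self, in_features, out_features, gather_output=False):
        self.gather_output = gather_output


class LinearWithoutGather:
    """Stands in for a parallel linear without that option."""

    def __init__(self, in_features, out_features):
        pass


def build_linear(linear_cls, in_features, out_features, **kwargs):
    gather_output = kwargs.pop('gather_output', None)
    # Forward the option only if the target constructor declares it.
    if 'gather_output' in inspect.signature(linear_cls.__init__).parameters:
        kwargs['gather_output'] = gather_output
    return linear_cls(in_features, out_features, **kwargs)


print(build_linear(LinearWithGather, 8, 8, gather_output=True).gather_output)        # True
print(hasattr(build_linear(LinearWithoutGather, 8, 8, gather_output=True),
              'gather_output'))                                                      # False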