mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] refactored the shardformer layer structure (#4053)

parent 58df720570
commit f22ddacef0

@@ -2,6 +2,9 @@ import re
 def get_obj_list_element(obj, a):
+    r"""
+    Get the element of the list in the object
+    """
     re_pattern = r'\[\d+\]'
     prog = re.compile(re_pattern)
     result = prog.search(a)

@@ -1,17 +1,10 @@
 from .dropout import Dropout1D
-from .embedding1d import Embedding1D
-from .layernorm1d import LayerNorm1D
-from .linear1d import Linear1D_Col, Linear1D_Row
-from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row
-from .vocabparallelembedding1d import VocabParallelEmbedding1D
+from .embedding import Embedding1D, VocabParallelEmbedding1D
+from .linear import Linear1D_Col, Linear1D_Row
+from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
+from .loss import cross_entropy_1d
 
 __all__ = [
-    "Embedding1D",
-    "VocabParallelEmbedding1D",
-    "Linear1D_Col",
-    "Linear1D_Row",
-    "LinearConv1D_Col",
-    "LinearConv1D_Row",
-    "LayerNorm1D",
-    "Dropout1D",
+    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
+    "Dropout1D", "cross_entropy_1d"
 ]

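For context, the consolidation above means downstream code imports every sharded layer from the package root rather than from per-class submodules; a minimal sketch of the new import surface, taken directly from the `__all__` above:

```python
# New flat import surface after the refactor; no per-class submodule paths.
from colossalai.shardformer.layer import (
    Dropout1D,
    Embedding1D,
    Linear1D_Col,
    Linear1D_Row,
    VocabParallelEmbedding1D,
    cross_entropy_1d,
)
```
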
@@ -1,8 +1,6 @@
 import torch
 import torch.distributed as dist
 
-from colossalai.core import global_context as gpc
-
 try:
     import fused_mix_prec_layer_norm_cuda
 except:

@@ -4,9 +4,11 @@ import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 
-from .parallelmodule import ParallelModule
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset
 
+__all__ = ['Dropout1D']
+
 
 class Dropout1D(ParallelModule, nn.Dropout):
     """

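The point of `Dropout1D` is that each tensor-parallel rank draws an independent dropout mask derived from a common base seed. A minimal standalone sketch of that idea using plain torch RNG forking (the `base_seed + rank` offset here is an illustrative assumption, not the exact scheme of `create_randomizer_with_offset`):

```python
import torch
import torch.nn.functional as F

def dropout_with_rank_offset(x: torch.Tensor, p: float, base_seed: int, rank: int) -> torch.Tensor:
    # Fork the CPU RNG so each "rank" draws a different mask from its
    # offset seed, without disturbing the global RNG state.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(base_seed + rank)
        return F.dropout(x, p=p, training=True)

x = torch.ones(8, 8)
out0 = dropout_with_rank_offset(x, 0.5, base_seed=42, rank=0)
out1 = dropout_with_rank_offset(x, 0.5, base_seed=42, rank=1)
assert not torch.equal(out0, out1)                                  # different ranks, different masks
assert torch.equal(out0, dropout_with_rank_offset(x, 0.5, 42, 0))   # same rank is reproducible
```
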
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from collections import OrderedDict
 from typing import Callable, List, Union
 
 import torch

@@ -12,26 +11,148 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from colossalai.context import ParallelMode, seed
 from colossalai.nn import init as init
-from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_rowwise
-from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
 from colossalai.utils.cuda import get_current_device
 
-from ._operation import reduce_input
-from .parallelmodule import ParallelModule
+from ._operation import gather_forward_split_backward, reduce_input
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
+__all__ = ['Embedding1D', 'VocabParallelEmbedding1D']
 
 
-class VocabParallelEmbedding1D(ParallelLayer):
+class Embedding1D(ParallelModule):
+    r"""Embedding for 1D parallelism.
+
+    Args:
+        num_embeddings (int): number of embeddings.
+        embedding_dim (int): dimension of embedding.
+        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
+            therefore, the embedding vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”, defaults to None.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to normal initializer.
+
+    The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
+    ::
+
+        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
+                    renormalized to have norm max_norm. Note: this will modify weight in-place.
+        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
+        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
+                    of frequency of the words in the mini-batch. Default False.
+        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
+
+    More details about ``args`` and ``kwargs`` can be found in
+    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
+
+    For more details about ``initializer``, please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: int = None,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
+                 gather_output: bool = True,
+                 weight_initializer: Callable = init.normal_(),
+                 *args,
+                 **kwargs):
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.process_group = process_group
+        self.num_partitions = dist.get_world_size(process_group)
+        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
+
+        self.padding_idx = padding_idx
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+        self.gather_output = gather_output
+
+        if device is None:
+            device = get_current_device()
+
+        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+
+    @staticmethod
+    def from_native_module(module: nn.Embedding,
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
+        r"""
+        Build a 1D parallelized Embedding from a native nn.Embedding module.
+        """
+        # get the attributes
+        num_embedding = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        max_norm = module.max_norm
+        norm_type = module.norm_type
+        scale_grad_by_freq = module.scale_grad_by_freq
+        sparse = module.sparse
+        dtype = module.weight.dtype
+        device = module.weight.device
+
+        # sparse is not supported yet
+        if sparse:
+            raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")
+
+        embedding = Embedding1D(num_embeddings=num_embedding,
+                                embedding_dim=embedding_dim,
+                                padding_idx=padding_idx,
+                                process_group=process_group,
+                                dtype=dtype,
+                                device=device,
+                                max_norm=max_norm,
+                                norm_type=norm_type,
+                                scale_grad_by_freq=scale_grad_by_freq,
+                                sparse=sparse,
+                                *args,
+                                **kwargs)
+
+        # copy the weight
+        with torch.no_grad():
+            sharded_weight = shard_colwise(module.weight.data, process_group)
+            embedding.weight.copy_(sharded_weight)
+
+        return embedding
+
+    def reset_parameters(self, weight_initializer) -> None:
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
+
+
+class VocabParallelEmbedding1D(ParallelModule):
     r"""Embedding parallelized in the vocabulary dimension.
 
     Args:

@@ -93,9 +214,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer)
+        self.reset_parameters(weight_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,

@@ -132,7 +251,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
         return vocab_embedding_1d
 
     def reset_parameters(self, weight_initializer) -> None:
-        with seed(ParallelMode.TENSOR):
+        with self.randomizer.fork_rng(enable_cpu=True):
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()

@@ -143,16 +262,6 @@ class VocabParallelEmbedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
-    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        local_state = OrderedDict({weight_key: self.weight})
-        local_state = gather_tensor_parallel_state_dict(local_state,
-                                                        ParallelMode.PARALLEL_1D,
-                                                        dims={weight_key: 0},
-                                                        partition_states={weight_key: True},
-                                                        keep_vars=keep_vars)
-        destination.update(local_state)
-
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)

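A usage sketch for the conversion path above; it has to run inside an initialized `torch.distributed` job (e.g. launched via `colossalai.launch`), since `Embedding1D.__init__` queries the world size of the group:

```python
import torch.nn as nn
from colossalai.shardformer.layer import Embedding1D

def shard_embedding() -> Embedding1D:
    native = nn.Embedding(1000, 128)
    # process_group=None falls back to the default group. Each rank keeps a
    # column shard of the (1000, 128) weight via shard_colwise; with the default
    # gather_output=True, forward() all-gathers so outputs match nn.Embedding.
    return Embedding1D.from_native_module(native, process_group=None)
```
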
@@ -1,157 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Callable, List, Union
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-
-from colossalai.nn import init as init
-from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise
-from colossalai.utils.cuda import get_current_device
-
-from ._operation import gather_forward_split_backward
-from .parallelmodule import ParallelModule
-from .utils import create_randomizer_with_offset
-
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
-
-
-class Embedding1D(ParallelModule):
-    r"""Embedding for 1D parallelism.
-
-    Args:
-        num_embeddings (int): number of embeddings.
-        embedding_dim (int): dimension of embedding.
-        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
-            therefore, the embedding vector at padding_idx is not updated during training,
-            i.e. it remains as a fixed “pad”, defaults to None.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-        weight_initializer (:class:`typing.Callable`, optional):
-            The initializer of weight, defaults to normal initializer.
-
-    The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
-    ::
-
-        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
-                    renormalized to have norm max_norm. Note: this will modify weight in-place.
-        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
-        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
-                    of frequency of the words in the mini-batch. Default False.
-        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
-
-    More details about ``args`` and ``kwargs`` can be found in
-    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
-
-    For more details about ``initializer``, please refer to
-    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
-    """
-
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None,
-                 process_group: ProcessGroup = None,
-                 gather_output: bool = True,
-                 weight_initializer: Callable = init.normal_(),
-                 *args,
-                 **kwargs):
-        super().__init__()
-
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
-
-        self.padding_idx = padding_idx
-        self.embed_args = args
-        self.embed_kwargs = kwargs
-        self.gather_output = gather_output
-
-        if device is None:
-            device = get_current_device()
-
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
-
-        # offset the seed with randomizer index and rank
-        seed = torch.random.initial_seed()
-        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer)
-
-    @staticmethod
-    def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
-                           *args,
-                           **kwargs) -> "Embedding1D":
-        r"""
-        Build a 1D parallelized Embedding from a native nn.Embedding module.
-        """
-        # get the attributes
-        num_embedding = module.num_embeddings
-        embedding_dim = module.embedding_dim
-        padding_idx = module.padding_idx
-        max_norm = module.max_norm
-        norm_type = module.norm_type
-        scale_grad_by_freq = module.scale_grad_by_freq
-        sparse = module.sparse
-        dtype = module.weight.dtype
-        device = module.weight.device
-
-        # sparse is not supported yet
-        if sparse:
-            raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")
-
-        embedding = Embedding1D(num_embeddings=num_embedding,
-                                embedding_dim=embedding_dim,
-                                padding_idx=padding_idx,
-                                process_group=process_group,
-                                dtype=dtype,
-                                device=device,
-                                max_norm=max_norm,
-                                norm_type=norm_type,
-                                scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse,
-                                *args,
-                                **kwargs)
-
-        # copy the weight
-        with torch.no_grad():
-            sharded_weight = shard_colwise(module.weight.data, process_group)
-            embedding.weight.copy_(sharded_weight)
-
-        return embedding
-
-    def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embedding_dim
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-        self._fill_padding_idx_with_zero()
-
-    def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input_: Tensor) -> Tensor:
-        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-
-        if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-            return output
-        else:
-            return output_parallel

@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from collections import OrderedDict
-
-from colossalai.context import ParallelMode, seed
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
-from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule
-from colossalai.utils.checkpointing import broadcast_state_dict
-
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
-
-
-class LayerNorm1D(ColossalaiModule):
-    r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-            \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-    """
-
-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
-        super().__init__(norm)
-
-    def _load_from_state_dict(self, state_dict, prefix, *args):
-        local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
-            # weight
-            weight = state_dict.pop(weight_key, None)
-            if weight is not None:
-                local_state[weight_key] = weight
-            # bias
-            bias = state_dict.pop(bias_key, None)
-            if bias is not None:
-                local_state[bias_key] = bias
-
-        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
-        super()._load_from_state_dict(local_state, prefix, *args)
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
-            super()._save_to_state_dict(destination, prefix, keep_vars)

@@ -23,15 +23,10 @@ from ._operation import (
     reduce_input,
     split_forward_gather_backward,
 )
-from .parallelmodule import ParallelModule
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
 __all__ = ['Linear1D_Col', 'Linear1D_Row']
 
 
 class Linear1D_Col(ParallelModule):

@@ -104,8 +99,8 @@ class Linear1D_Col(ParallelModule):
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer, bias_initializer)
+        # init weights
+        self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,

@@ -146,10 +141,11 @@ class Linear1D_Col(ParallelModule):
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
-        fan_in, fan_out = self.in_features, self.out_features
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-        if self.bias is not None:
-            bias_initializer(self.bias, fan_in=fan_in)
+        with self.randomizer.fork_rng(enable_cpu=True):
+            fan_in, fan_out = self.in_features, self.out_features
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+            if self.bias is not None:
+                bias_initializer(self.bias, fan_in=fan_in)
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         assert input_.shape[-1] == self.weight.shape[-1], \

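Moving the `fork_rng` wrapper from `__init__` into `reset_parameters` means any later re-initialization also runs in the rank-offset RNG stream and leaves the surrounding RNG state untouched. A standalone sketch of that invariant, with plain torch `fork_rng` standing in for the randomizer from this diff:

```python
import torch

weight = torch.empty(4, 4)
state_before = torch.random.get_rng_state()

with torch.random.fork_rng(devices=[]):
    torch.manual_seed(1234 + 0)    # base seed plus a per-rank offset (illustrative)
    torch.nn.init.normal_(weight)  # shard init consumes only the forked stream

# The global RNG is exactly where it was before the init.
assert torch.equal(state_before, torch.random.get_rng_state())
```
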
@@ -23,19 +23,15 @@ from ._operation import (
     reduce_input,
     split_forward_gather_backward,
 )
-from .parallelmodule import ParallelModule
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
 __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
 
 
 class LinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
     Specially created for HuggingFace's GPT2 model.
 
     The linear layer is defined as :math:`Y = XA + b`, where :math:`A` is parallelized along
     its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is adapted to fit the `Conv1D` layer in HuggingFace's GPT2 model.

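For background on why GPT2 needs these dedicated classes: HuggingFace's `Conv1D` stores its weight transposed relative to `nn.Linear` and computes `y = x @ W + b`. A standalone illustration of the layout difference:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)          # (batch, in_features)
w_conv1d = torch.randn(8, 16)  # Conv1D-style weight: (in_features, out_features)
b = torch.zeros(16)

y_conv1d = x @ w_conv1d + b               # what HF Conv1D computes
y_linear = F.linear(x, w_conv1d.t(), b)   # nn.Linear expects (out, in)
assert torch.allclose(y_conv1d, y_linear, atol=1e-6)
```
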
@@ -104,8 +100,8 @@ class LinearConv1D_Col(ParallelModule):
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer, bias_initializer)
+        # init weights
+        self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,

@@ -162,10 +158,11 @@ class LinearConv1D_Col(ParallelModule):
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
-        fan_in, fan_out = self.in_features, self.out_features
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-        if self.bias is not None:
-            bias_initializer(self.bias, fan_in=fan_in)
+        with self.randomizer.fork_rng(enable_cpu=True):
+            fan_in, fan_out = self.in_features, self.out_features
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+            if self.bias is not None:
+                bias_initializer(self.bias, fan_in=fan_in)
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         assert input_.shape[-1] == self.weight.shape[-1], \

@@ -192,6 +189,7 @@ class LinearConv1D_Col(ParallelModule):
 
 class LinearConv1D_Row(ParallelModule):
     r""" Linear layer with row parallelism
+    Specially created for HuggingFace's GPT2 model.
 
     Args:
         in_features (int): size of each input sample.

@@ -260,8 +258,8 @@ class LinearConv1D_Row(ParallelModule):
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer, bias_initializer)
+        # init weights
+        self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,

@@ -320,20 +318,21 @@ class LinearConv1D_Row(ParallelModule):
             self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
-        fan_in, fan_out = self.in_features, self.out_features
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        with self.randomizer.fork_rng(enable_cpu=True):
+            fan_in, fan_out = self.in_features, self.out_features
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
 
-        if self.bias is not None:
-            bias_initializer(self.bias, fan_in=fan_in)
-            if self.process_group is None:
-                src_rank = 0
-            else:
-                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+            if self.bias is not None:
+                bias_initializer(self.bias, fan_in=fan_in)
+                if self.process_group is None:
+                    src_rank = 0
+                else:
+                    src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
 
-            origin_device = self.bias.device
-            self.bias = self.bias.cuda()
-            dist.broadcast(self.bias, src=src_rank, group=self.process_group)
-            self.bias = self.bias.to(origin_device)
+                origin_device = self.bias.device
+                self.bias = self.bias.cuda()
+                dist.broadcast(self.bias, src=src_rank, group=self.process_group)
+                self.bias = self.bias.to(origin_device)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.

@@ -1,10 +1,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 
 __all__ = ['DistCrossEntropy', 'cross_entropy_1d']
 
 
 class DistCrossEntropy(Function):
     r"""

@@ -7,15 +7,7 @@ from typing import List, Union
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 
-from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
 __all__ = ['ParallelModule']
 
 
 class ParallelModule(nn.Module, ABC):

@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Type, Union
 
 import torch.nn as nn

@@ -4,7 +4,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D
 
-from ..utils import getattr_, setattr_
+from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -1,12 +1,7 @@
 from typing import Type, Union
 
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D
 
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -1,5 +1,3 @@
 import torch
 import torch.nn as nn
 from transformers import T5ForConditionalGeneration
 from transformers.models.t5.modeling_t5 import (
     T5Attention,

@@ -4,9 +4,9 @@ import torch.nn as nn
 
 from colossalai.cluster.process_group_manager import ProcessGroupManager
 
+from .._utils import getattr_, setattr_
 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy, SubModuleReplacementDescription
-from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig
 
 __all__ = ['ModelSharder', 'shard_model']

@@ -1 +0,0 @@
-from .utils import getattr_, hasattr_, setattr_

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
+from colossalai.shardformer.layer import cross_entropy_1d
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)

@@ -25,7 +25,7 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
     org_loss = F.cross_entropy(org_pred, org_labels)
 
     dist_pred = pred.chunk(world_size, -1)[rank]
-    dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
+    dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
 
     assert torch.allclose(org_loss, dist_loss,
                           atol=1e-5), f"dist cross entropy loss is not equal to origin loss\n{org_loss}\n{dist_loss}"

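`cross_entropy_1d` computes the loss while the vocabulary dimension stays sharded across ranks. A single-process sketch of the identity this test relies on: the full log-sum-exp decomposes into a log-sum-exp of per-shard log-sum-exps (the real implementation performs the recombination with collectives):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                # (batch, vocab)
labels = torch.randint(0, 10, (4,))
full_loss = F.cross_entropy(logits, labels)

shards = logits.chunk(2, dim=-1)           # emulate two vocab shards
per_shard_lse = torch.stack([s.logsumexp(dim=-1) for s in shards])
lse = torch.logsumexp(per_shard_lse, dim=0)               # = logsumexp over the full vocab
target_logit = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
sharded_loss = (lse - target_logit).mean()

assert torch.allclose(full_loss, sharded_loss, atol=1e-5)
```
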
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
-from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.shardformer.layer import Dropout1D
 from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn

@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import Embedding1D
+from colossalai.shardformer.layer import Embedding1D
 from colossalai.testing import rerun_if_address_is_in_use, spawn

@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.testing import rerun_if_address_is_in_use, spawn

@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
+from colossalai.shardformer.layer import VocabParallelEmbedding1D
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -1,51 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-
-
-def check_dropout(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
-
-    # prepare data
-    input = torch.randn(5, 4).to('cuda')
-    dropout = Dropout1D(p=0.4).to('cuda')
-    output_list = []
-    # compare the dropout pattern in each device
-    for i in range(2):
-        output = dropout(input)
-        output_list.append(output)
-        dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)]
-        torch.distributed.all_gather(dist_output_list, output)
-        for j in range(world_size):
-            for k in range(world_size):
-                if j != k:
-                    mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0)
-                    assert torch.all(
-                        mask
-                    ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}"
-    # compare the dropout pattern in local device
-    for i in range(len(output_list)):
-        for j in range(len(output_list)):
-            if i != j:
-                mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0)
-                assert torch.all(
-                    mask
-                ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}"
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_dropout():
-    spawn(check_dropout, 2)
-
-
-if __name__ == '__main__':
-    test_dropout()

@@ -1,78 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer
-from colossalai.shardformer.shard.shard_config import ShardConfig
-from colossalai.shardformer.shard.slicer import Slicer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-
-
-def check_slicer(rank, world_size, port, in_feature, out_feature):
-    disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
-    # initialize slicer
-    shardconfig = ShardConfig(rank=rank, world_size=world_size)
-    slicer = Slicer(shardconfig)
-    # initialize test data
-    weight = torch.randn(in_feature, out_feature)
-    bias = torch.randn(out_feature)
-    policy_layer_cls_list = [Layer, Col_Layer, Row_Layer]
-    n_cast_list = [None, 2, 3, 4]
-    # weight and bias
-    for n_cast in n_cast_list:
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast)
-        expected_sliced_weight = weight
-        expected_sliced_bias = bias
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast)
-        if n_cast is None:
-            expected_sliced_weight = weight.chunk(world_size, dim=0)[rank]
-            expected_sliced_bias = bias.chunk(world_size)[rank]
-        else:
-            chunks = weight.chunk(world_size * n_cast, dim=0)
-            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0)
-            chunks = bias.chunk(world_size * n_cast, dim=0)
-            expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)])
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
-
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast)
-        if n_cast is None:
-            expected_sliced_weight = weight.chunk(world_size, dim=1)[rank]
-            expected_sliced_bias = bias
-        else:
-            chunks = weight.chunk(world_size * n_cast, dim=1)
-            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1)
-            expected_sliced_bias = bias
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_slicer():
-    args = dict(in_feature=24, out_feature=48)
-    spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature'])
-
-
-if __name__ == '__main__':
-    test_slicer()