[shardformer] add gpt2 test and layer class refactor (#4041)

* add gpt2 test and layer class refactor

* add dropout in gpt2 policy
pull/4157/head
FoolPlayer 2023-06-20 11:45:16 +08:00 committed by Frank Lee
parent d857f3dbba
commit 4021b9a8a2
14 changed files with 1400 additions and 840 deletions


@@ -0,0 +1,17 @@
from .dropout import Dropout1D
from .embedding1d import Embedding1D
from .layernorm1d import LayerNorm1D
from .linear1d import Linear1D_Col, Linear1D_Row
from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row
from .vocabparallelembedding1d import VocabParallelEmbedding1D
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"Linear1D_Col",
"Linear1D_Row",
"LinearConv1D_Col",
"LinearConv1D_Row",
"LayerNorm1D",
"Dropout1D",
]
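
As a rough usage sketch: each class exported here provides a `from_native_module` constructor (see the implementations below) that builds the sharded layer from a plain PyTorch module. The import path below is an assumption (file paths are not shown in this diff view), and the call requires an already-initialized distributed environment, e.g. launched with `torchrun` or `colossalai.launch`:

# Sketch only; the package path is assumed and torch.distributed must already
# be initialized (e.g. run under torchrun with one process per GPU).
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

attn_qkv = nn.Linear(768, 3 * 768)
attn_out = nn.Linear(768, 768)

# Shard the QKV projection column-wise and the output projection row-wise,
# using the default (global) process group.
qkv_sharded = Linear1D_Col.from_native_module(attn_qkv, process_group=None)
out_sharded = Linear1D_Row.from_native_module(attn_out, process_group=None)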


@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from .layers import ParallelModule
from .parallelmodule import ParallelModule
from .utils import create_randomizer_with_offset


@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Callable, List, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise
from colossalai.utils.cuda import get_current_device
from ._operation import gather_forward_split_backward
from .parallelmodule import ParallelModule
from .utils import create_randomizer_with_offset
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class Embedding1D(ParallelModule):
r"""Embedding for 1D parallelism.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed pad, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` can be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
# self.gather_output = gather_output
if device is None:
device = get_current_device()
self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
# get the attributes
num_embedding = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
max_norm = module.max_norm
norm_type = module.norm_type
scale_grad_by_freq = module.scale_grad_by_freq
sparse = module.sparse
dtype = module.weight.dtype
device = module.weight.device
# sparse is not supported yet
if sparse:
raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")
embedding = Embedding1D(num_embeddings=num_embedding,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
process_group=process_group,
dtype=dtype,
device=device,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
# copy the weight
with torch.no_grad():
sharded_weight = shard_colwise(module.weight.data, process_group)
embedding.weight.copy_(sharded_weight)
return embedding
def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
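
To illustrate why gathering the column-sharded lookups reproduces a full embedding, here is a minimal single-process sketch in plain PyTorch; it simulates the ranks with `torch.chunk` and replaces `gather_forward_split_backward` with a simple concatenation, so it only demonstrates the forward-pass equivalence:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, embedding_dim, world_size = 8, 6, 2
weight = torch.randn(num_embeddings, embedding_dim)
tokens = torch.randint(0, num_embeddings, (4, 3))

# Reference: lookup against the full (unsharded) table.
full = F.embedding(tokens, weight)

# Column sharding: each simulated rank holds a slice of the embedding dimension.
shards = torch.chunk(weight, world_size, dim=-1)
partial = [F.embedding(tokens, w) for w in shards]
gathered = torch.cat(partial, dim=-1)    # stands in for the all-gather on dim=-1

assert torch.equal(gathered, full)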


@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from collections import OrderedDict
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule
from colossalai.utils.checkpointing import broadcast_state_dict
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)


@@ -1,722 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Callable, List, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule
from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, reduce_grad, set_parallel_input
from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
split_forward_gather_backward,
)
from .utils import create_randomizer_with_offset
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class ParallelModule(nn.Module, ABC):
@abstractmethod
def from_native_module(module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
"""
Convert a native PyTorch module to a parallelized module.
Args:
module (nn.Module): the module to be converted.
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
If this is a list, the process group at the ith index of the list will correspond to the process group
in the ith axis of the device mesh. Defaults to None, which means the global process group.
"""
pass
class Linear1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of nn.Linear is stored transposed, with shape (out_features, in_features),
# so sharding it row-wise realizes column parallelism
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
class Linear1D_Row(ParallelModule):
r""" Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, the input is assumed to be already split across the ranks, defaults to ``True``.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of nn.Linear is stored transposed, with shape (out_features, in_features),
# so sharding it column-wise realizes row parallelism
sharded_weight = shard_colwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
linear_1d.bias.copy_(module.bias.data)
return linear_1d
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=self.process_group,
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
class Embedding1D(ParallelModule):
r"""Embedding for 1D parallelism.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed pad, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` can be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
if device is None:
device = get_current_device()
self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
# get the attributes
num_embedding = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
max_norm = module.max_norm
norm_type = module.norm_type
scale_grad_by_freq = module.scale_grad_by_freq
sparse = module.sparse
dtype = module.weight.dtype
device = module.weight.device
# sparse is not supported yet
if sparse:
raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")
embedding = Embedding1D(num_embeddings=num_embedding,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
process_group=process_group,
dtype=dtype,
device=device,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
*args,
**kwargs)
# copy the weight
with torch.no_grad():
sharded_weight = shard_colwise(module.weight.data, process_group)
embedding.weight.copy_(sharded_weight)
return embedding
def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel
class VocabParallelEmbedding1D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed pad, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` can be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
# self.reset_parameters(weight_initializer)
# self._set_tensor_parallel_attributes()
# set_parallel_input(False)
# env.vocab_parallel = True
@staticmethod
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# ensure only one process group is used
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
# create the parallel module
vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
process_group=process_group,
*args,
**kwargs)
with torch.no_grad():
# shard and slice the weight along the vocabulary(num_embeddings) dimension
# the shape of the weight is (num_embeddings, embedding_dim)
shard_weight = shard_rowwise(module.weight.data, process_group)
vocab_embedding_1d.weight.data.copy_(shard_weight)
return vocab_embedding_1d
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, self.process_group)
return output
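
The masking logic above (shard the vocabulary, zero out rows for tokens owned by other ranks, then all-reduce) can be checked with a minimal single-process sketch in plain PyTorch; the loop below simulates the ranks and an ordinary sum stands in for `reduce_input`:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embedding_dim, world_size = 8, 4, 2
weight = torch.randn(vocab_size, embedding_dim)
tokens = torch.randint(0, vocab_size, (5,))

full = F.embedding(tokens, weight)          # reference lookup on the full table

per_rank = vocab_size // world_size
partials = []
for rank in range(world_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    local_weight = weight[start:end]                      # this rank's vocab slice
    mask = (tokens < start) | (tokens >= end)             # tokens owned by other ranks
    masked = tokens - start
    masked[mask] = 0                                      # any valid local index works
    out = F.embedding(masked, local_weight)
    out[mask, :] = 0.0                                    # zero the foreign-token rows
    partials.append(out)

# The all-reduce in the real layer sums the partial outputs across ranks.
assert torch.equal(sum(partials), full)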


@@ -0,0 +1,346 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Callable, List, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.cuda import get_current_device
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
split_forward_gather_backward,
)
from .parallelmodule import ParallelModule
from .utils import create_randomizer_with_offset
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class Linear1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of nn.Linear is stored transposed, with shape (out_features, in_features),
# so sharding it row-wise realizes column parallelism
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
class Linear1D_Row(ParallelModule):
r""" Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, the input is assumed to be already split across the ranks, defaults to ``True``.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of nn.Linear is stored transposed, with shape (out_features, in_features),
# so sharding it column-wise realizes row parallelism
sharded_weight = shard_colwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
linear_1d.bias.copy_(module.bias.data)
return linear_1d
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=self.process_group,
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
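
A minimal single-process sketch (plain PyTorch, simulating the ranks with `torch.chunk`) of why the two schemes above reproduce a full linear layer: column parallelism concatenates per-rank outputs, while row parallelism sums per-rank partial products (which is what `reduce_input` does) and adds the bias once:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features, world_size = 8, 6, 2
x = torch.randn(4, in_features)
weight = torch.randn(out_features, in_features)   # nn.Linear layout: (out, in)
bias = torch.randn(out_features)

full = F.linear(x, weight, bias)

# Column parallelism (Linear1D_Col): shard the output dimension. Because nn.Linear
# stores W as (out, in), this is a row-wise shard of the stored tensor.
w_col = torch.chunk(weight, world_size, dim=0)
b_col = torch.chunk(bias, world_size, dim=0)
col_out = torch.cat([F.linear(x, w, b) for w, b in zip(w_col, b_col)], dim=-1)
assert torch.allclose(col_out, full, atol=1e-6)

# Row parallelism (Linear1D_Row): shard the input dimension; the per-rank partial
# products are summed (the all-reduce) and the bias is added once afterwards.
w_row = torch.chunk(weight, world_size, dim=1)
x_row = torch.chunk(x, world_size, dim=-1)
row_out = sum(F.linear(xi, wi) for xi, wi in zip(x_row, w_row)) + bias
assert torch.allclose(row_out, full, atol=1e-6)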


@@ -0,0 +1,377 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Callable, List, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.cuda import get_current_device
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
split_forward_gather_backward,
)
from .parallelmodule import ParallelModule
from .utils import create_randomizer_with_offset
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class LinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is a drop-in replacement for the `Conv1D` layer in HuggingFace's GPT-2 implementation.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a HuggingFace `Conv1D` layer (as used in GPT-2) to a parallelized linear layer.
"""
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = LinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of the HuggingFace Conv1D layer is transposed relative to nn.Linear,
# so column parallelism corresponds to a column-wise shard of the stored tensor
# first rearrange the order of the weight and bias chunks
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_cast)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1)
rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
sharded_weight = shard_colwise(rearanged_weight, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
if bias:
bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0)
rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
sharded_bias = shard_colwise(rearanged_bias, process_group)
linear_1d.bias.copy_(sharded_bias.contiguous())
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in LinearConv1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
class LinearConv1D_Row(ParallelModule):
r""" Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, the input is assumed to be already split across the ranks, defaults to ``True``.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a HuggingFace `Conv1D` layer (as used in GPT-2) to a parallelized linear layer.
"""
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = LinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of the HuggingFace Conv1D layer is transposed relative to nn.Linear,
# so row parallelism corresponds to a row-wise shard of the stored tensor
# first rearrange the order of the weight and bias chunks
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_cast)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0)
rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0)
sharded_weight = shard_rowwise(rearanged_weight, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
if bias:
bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0)
rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
linear_1d.bias.copy_(rearanged_bias.contiguous())
return linear_1d
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in LinearConv1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
'Invalid shapes in LinearConv1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in LinearConv1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=self.process_group,
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
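# A minimal usage sketch (not part of this diff): converting a GPT-2 style Conv1D projection
# into a row-parallel layer. It assumes torch.distributed is already initialized so the default
# process group can be used; Conv1D stores its weight transposed (in_features x out_features),
# which is the layout from_native_module expects above.
import torch.distributed as dist
from transformers.pytorch_utils import Conv1D    # location may differ in older transformers versions

if dist.is_initialized():
    c_proj = Conv1D(nf=768, nx=768)                          # e.g. a GPT2 attn.c_proj projection
    row_parallel = LinearConv1D_Row.from_native_module(
        c_proj,
        process_group=None,                                  # None falls back to the default (global) group
        n_cast=1,                                            # a single projection, nothing fused to un-interleave
    )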

View File

@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import List, Union
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class ParallelModule(nn.Module, ABC):
@abstractmethod
def from_native_module(module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
"""
Convert a native PyTorch module to a parallelized module.
Args:
module (nn.Module): the module to be converted.
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
If this is a list, the process group at the ith index of the list will correspond to the process group
in the ith axis of the device mesh. Defaults to None, which means the global process group.
"""
pass
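# A minimal sketch (illustrative only) of what a concrete subclass is expected to provide,
# following the convention used by the layers in this commit: a static constructor that
# consumes a native module and returns its parallel counterpart.
class NoOpParallelModule(ParallelModule):
    """Illustrative subclass that wraps a module without sharding anything."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    @staticmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        # a real implementation would shard `module`'s parameters across `process_group` here
        return NoOpParallelModule(module)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)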

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from collections import OrderedDict
from typing import Callable, List, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.context import ParallelMode, seed
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_rowwise
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict
from ._operation import reduce_input
from .parallelmodule import ParallelModule
from .utils import create_randomizer_with_offset
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
class VocabParallelEmbedding1D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed pad, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` could be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch embedding module to a parallelized embedding module.
"""
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# ensure only one process group is used
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
# create the parallel module
vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
process_group=process_group,
*args,
**kwargs)
with torch.no_grad():
# shard and slice the weight along the vocabulary(num_embeddings) dimension
# the shape of the weight is (num_embeddings, embedding_dim)
shard_weight = shard_rowwise(module.weight.data, process_group)
vocab_embedding_1d.weight.data.copy_(shard_weight)
return vocab_embedding_1d
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, self.process_group)
return output
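# An illustration (not part of this diff) of why mask + all-reduce recovers the full embedding:
# each rank owns one slice of the vocabulary, zeroes out the tokens it does not own, and the
# element-wise sum across ranks fills in every position exactly once. Plain torch, no process
# group; vocab_size=8 split over 2 ranks is an arbitrary choice for the demo.
import torch
import torch.nn.functional as F

vocab_size, embed_dim, world_size = 8, 4, 2
full_weight = torch.randn(vocab_size, embed_dim)
input_ids = torch.tensor([1, 5, 7])

partials = []
for rank in range(world_size):
    start = rank * (vocab_size // world_size)
    end = start + vocab_size // world_size
    local_weight = full_weight[start:end]
    mask = (input_ids < start) | (input_ids >= end)
    masked = input_ids.clone() - start
    masked[mask] = 0
    out = F.embedding(masked, local_weight)
    out[mask, :] = 0.
    partials.append(out)

# the reduce_input step above is just this element-wise sum across ranks
assert torch.allclose(sum(partials), F.embedding(input_ids, full_weight))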

View File

@ -56,6 +56,8 @@ _POLICY_LIST = {
PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model":
PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
}
@ -99,4 +101,3 @@ def get_autopolicy(model: nn.Module) -> Policy:
else:
policy = import_policy(policy_location)
return policy()
return policy()
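# A short sketch (not part of this diff) of the lookup this registration enables; it assumes
# get_autopolicy resolves the fully qualified class name of the model to the entry added above.
from transformers import GPT2Config, GPT2Model

gpt2 = GPT2Model(GPT2Config())
policy = get_autopolicy(gpt2)    # returns an instance of GPT2ModelPolicy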

View File

@ -1,7 +1,7 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.shardformer.layer.layers as col_nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from ..utils import getattr_, setattr_
@ -87,15 +87,9 @@ class BertPolicy(Policy):
def new_model_class(self):
# do nothing
return None
return self.model
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
@ -127,6 +121,15 @@ class BertForPretrainingPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
@ -149,6 +152,15 @@ class BertForMaskedLMPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@ -171,6 +183,15 @@ class BertLMHeadModelPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

View File

@ -1,126 +1,101 @@
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Type, Union
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer.layers as col_nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class GPT2Policy(Policy):
@staticmethod
def argument_policy(config, world_size):
def preprocess(self):
r"""
Reshape the embedding layer so that the vocabulary size is divisible by the tensor parallel size.
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
return {
GPT2Model:
Argument(attr_dict={}, param_funcs=[
GPT2Policy.embedding,
]),
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
]),
GPT2Block:
Argument(
attr_dict={
# 1. reduce hidden size
"attn.embed_dim": config.hidden_size // world_size,
"attn.split_size": config.hidden_size // world_size,
"crossattention.embed_dim": config.hidden_size // world_size,
"crossattention.split_size": config.hidden_size // world_size,
# 2. reduce number of heads
"attn.num_heads": config.num_attention_heads // world_size,
"crossattention.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[
GPT2Policy.attn_in,
GPT2Policy.attn_out,
GPT2Policy.mlp_in,
GPT2Policy.mlp_out,
]),
ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.LinearConv1D_Row,
kwargs={
"n_cast": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.LinearConv1D_Row,
kwargs={
"n_cast": 1,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.Dropout1D,
),
])
}
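# Why n_cast=3 for attn.c_attn (illustrative, not part of this diff): the fused QKV weight
# concatenates [Q | K | V] along its output dimension, so before sharding across
# tensor_parallel_size ranks each block is chunked and re-interleaved so that every rank ends
# up with a matching (Q_i, K_i, V_i) triple. This mirrors the chunk/reorder logic in
# LinearConv1D_Col / LinearConv1D_Row.from_native_module.
import torch

world_size, n_cast = 2, 3
order = torch.arange(world_size * n_cast)                                  # [0, 1, 2, 3, 4, 5]
new_order = torch.cat([order[i::world_size] for i in range(world_size)])   # [0, 2, 4, 1, 3, 5]
# chunks 0/1 belong to Q, 2/3 to K, 4/5 to V, so after reordering and a 2-way shard
# rank 0 holds (Q0, K0, V0) and rank 1 holds (Q1, K1, V1)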
@staticmethod
def attn_in() -> List:
return [
Col_Layer(suffix="attn.c_attn",
weight="weight",
bias="bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(suffix="crossattention.c_attn",
weight="weight",
bias="bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(suffix="crossattention.q_attn",
weight="weight",
bias="bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
]
def new_model_class(self):
@staticmethod
def attn_out() -> List:
return [
Row_Layer(suffix="attn.c_proj",
weight="weight",
bias="bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(suffix="crossattention.c_proj",
weight="weight",
bias="bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
]
return self.model
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True,
replace_layer=col_nn.Linear1D_Col),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(suffix="mlp.c_proj",
weight="weight",
bias="bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def embedding() -> List:
return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
def postprocess(self):
return self.model
from transformers import GPT2LMHeadModel
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
class GPT2LMHeadModelPolicy(GPT2Policy):
@staticmethod
def argument_policy(config, world_size):
base_argument = GPT2Policy.argument_policy(config, world_size)
argument = {
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
GPT2LMHeadModelPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def unembedding() -> List:
return [
Col_Layer(suffix="lm_head",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]
def __init__(self) -> None:
super().__init__()

View File

@ -108,7 +108,7 @@ def check_bert(rank, world_size, port):
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
for model_fn in forward_list:
org_model, sharded_model = build_model(model_fn)
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
if model_fn in backward_lsit:

View File

@ -0,0 +1,118 @@
import copy
import os
import pytest
import torch
from transformers import AutoTokenizer, GPT2Config, GPT2Model
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def build_model(world_size, model_fn):
config = GPT2Config()
config.attn_pdrop = 0
config.embd_pdrop = 0
config.resid_pdrop = 0
config.summary_first_dropout = 0
org_model = model_fn(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=world_size,)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# original model
org_model.eval()
org_out = org_model(**tokenized_input)
# sharded model
sharded_model.eval()
shard_out = sharded_model(**tokenized_input)
assert torch.allclose(
org_out[0], shard_out[0],
atol=1e-5), f"sharded model output does not match the original model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
# tokenized_input['labels'] = labels
# original model
org_model.train()
org_out = org_model(**tokenized_input)
org_loss = org_out.loss
org_loss.backward()
org_grad = org_model.h[0].attn.c_attn.weight.grad
# sharded model
sharded_model.train()
shard_out = sharded_model(**tokenized_input)
shard_loss = shard_out.loss
shard_loss.backward()
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"sharded model loss does not match the original model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"sharded model grad does not match the original model grad\n{org_grad}\n{all_shard_grad}"
def check_gpt2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
forward_list = [
GPT2Model,
# TODO: other model variants do not work yet
]
backward_list = []
for model_fn in forward_list:
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
if model_fn in backward_list:
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_gpt2():
spawn(check_gpt2, 2)
if __name__ == "__main__":
test_gpt2()
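# A quick way to exercise this test outside the CI entrypoint (assumed invocation; the exact
# file path may differ in the repository): run `pytest -q` on this file, or execute it directly
# with `python`. spawn(check_gpt2, 2) launches the two ranks itself, so at least two CUDA
# devices are required but no torchrun wrapper is needed.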