mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] supported bloom model (#4098)
parent 8af29ee47a
commit b1c2901530
@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer:
- [x] BERT
- [x] T5
- [x] LlaMa
- [ ] GPT2
- [ ] BLOOM
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] SwinTransformer
- [ ] SwinTransformer V2
- [ ] Audio
- [ ] To be added
- [ ] Whisper
- [ ] Multi-modal
- [ ] To be added
@ -1,11 +1,12 @@
from .dropout import Dropout1D
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row
from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
from .loss import cross_entropy_1d
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
    'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
    'FusedLayerNorm'
]
@ -7,10 +7,10 @@ from torch.distributed import ProcessGroup
|
|||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ['Dropout1D']
|
||||
__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput']
|
||||
|
||||
|
||||
class Dropout1D(ParallelModule, nn.Dropout):
|
||||
class DropoutForParallelInput(ParallelModule, nn.Dropout):
|
||||
"""
|
||||
The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
|
||||
randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
|
||||
|
@ -32,13 +32,50 @@ class Dropout1D(ParallelModule, nn.Dropout):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Dropout,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput":
|
||||
"""
|
||||
Create a Dropout1D layer from a native dropout layer.
|
||||
Create a DropoutForParallelInput layer from a native dropout layer.
|
||||
"""
|
||||
p = module.p
|
||||
inplace = module.inplace
|
||||
return Dropout1D(p=p, inplace=inplace, process_group=process_group)
|
||||
return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group)
|
||||
|
||||
def forward(self, input):
|
||||
with self.randomizer.fork_rng():
|
||||
input = super().forward(input)
|
||||
return input
|
||||
|
||||
|
||||
class DropoutForReplicatedInput(ParallelModule, nn.Dropout):
|
||||
"""
|
||||
The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
|
||||
randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
|
||||
and applied on the same position of different ranks, leading to poor convergence performance.
|
||||
|
||||
Args:
|
||||
p (float): probability of an element to be zeroed. Defaults to 0.5.
|
||||
inplace (bool): If set to True, will do this operation in-place. Defaults to False.
|
||||
process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
|
||||
# init with nn.Dropout
|
||||
super(nn.Dropout, self).__init__(p=p, inplace=inplace)
|
||||
|
||||
# offset the seed with randomizer index only
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Dropout,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput":
|
||||
"""
|
||||
Create a Dropout1D layer from a native dropout layer.
|
||||
"""
|
||||
p = module.p
|
||||
inplace = module.inplace
|
||||
return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group)
|
||||
|
||||
def forward(self, input):
|
||||
with self.randomizer.fork_rng():
|
||||
|
|
|
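The two dropout variants above differ only in how their RNG is seeded: DropoutForParallelInput offsets the seed by the process rank so every shard draws a different mask, while DropoutForReplicatedInput drops the rank offset so replicated activations see the same mask on every rank. A minimal single-process sketch of that seeding idea, using plain torch.Generator objects instead of the Randomizer utility (the helper below is illustrative, not the ColossalAI API):

```python
import torch

def make_dropout_mask(shape, p, seed, rank, offset_by_rank):
    # Offsetting the seed by the rank gives each shard its own mask (parallel input);
    # keeping the base seed gives every rank the same mask (replicated input).
    gen = torch.Generator().manual_seed(seed + (rank if offset_by_rank else 0))
    return (torch.rand(shape, generator=gen) > p).float() / (1 - p)

base_seed, p = 1234, 0.5
per_rank = [make_dropout_mask((4, 128), p, base_seed, rank, offset_by_rank=True) for rank in (0, 1)]
assert not torch.equal(per_rank[0], per_rank[1])      # different masks across ranks

replicated = [make_dropout_mask((4, 128), p, base_seed, rank, offset_by_rank=False) for rank in (0, 1)]
assert torch.equal(replicated[0], replicated[1])      # identical masks across ranks
```

In the real layers this draw happens inside randomizer.fork_rng(), so the global RNG state of the process is left untouched.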
@ -277,6 +277,7 @@ class Linear1D_Row(ParallelModule):
|
|||
def chunk_weight(self):
|
||||
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
@ -289,9 +290,10 @@ class Linear1D_Row(ParallelModule):
|
|||
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
|
||||
|
||||
origin_device = self.bias.device
|
||||
self.bias = self.bias.cuda()
|
||||
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
|
||||
self.bias = self.bias.to(origin_device)
|
||||
bias = self.bias.cuda()
|
||||
dist.broadcast(bias, src=src_rank, group=self.process_group)
|
||||
bias = bias.to(origin_device)
|
||||
self.bias.copy_(bias)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
|
|
|
@ -31,12 +31,25 @@ from ._operation import (
|
|||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
|
||||
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
|
||||
|
||||
# ====================================
|
||||
# For GPT Only
|
||||
# ====================================
|
||||
|
||||
|
||||
def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
|
||||
def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
|
||||
n_fused: int,
|
||||
process_group: ProcessGroup,
|
||||
is_transposed: bool = False):
|
||||
"""
|
||||
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function will split it into [Q1, K1, V1] and [Q2, K2, V2].
|
||||
|
||||
Args:
|
||||
qkv (torch.Tensor): The fused qkv tensor.
|
||||
n_fused (int): The number of items fused together, defaults to 3 (query, key and value).
|
||||
process_group (ProcessGroup): The process group for distributed communication.
|
||||
is_transposed (bool): the tensor is generally of shape (out_features, in_features). Set this to True if the tensor is of shape (in_features, out_features).
|
||||
"""
|
||||
# get the number of slice for the fused qkv
|
||||
rank = dist.get_rank(group=process_group)
|
||||
|
@ -48,7 +61,10 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
|
|||
# [Q, K, V]
|
||||
# to
|
||||
# [Q1, Q2, K1, K2, V1, V2]
|
||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
|
||||
if is_transposed:
|
||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
|
||||
else:
|
||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
|
||||
|
||||
# rearrange the slice into the final order
|
||||
# from
|
||||
|
@ -56,13 +72,26 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
|
|||
# to
|
||||
# [Q1, K1, V1], [Q2, K2, V2]
|
||||
weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
|
||||
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
|
||||
|
||||
if is_transposed:
|
||||
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
|
||||
else:
|
||||
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0)
|
||||
return weight_of_current_rank
|
||||
|
||||
|
||||
def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
|
||||
def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
|
||||
n_fused: int,
|
||||
process_group: ProcessGroup,
|
||||
is_transposed: bool = False):
|
||||
"""
|
||||
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
||||
|
||||
Args:
|
||||
qkv (torch.Tensor): The fused qkv tensor.
|
||||
n_fused (int): The number of items fused together, defaults to 3 (query, key and value).
|
||||
process_group (ProcessGroup): The process group for distributed communication.
|
||||
is_transposed (bool): the tensor is generally of shape (out_features, in_features). Set this to True if the tensor is of shape (in_features, out_features).
|
||||
"""
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
|
||||
|
@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
|
|||
qkv = qkv.cuda()
|
||||
gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
|
||||
dist.all_gather(gather_list, qkv, group=process_group)
|
||||
gather_weight = torch.cat(gather_list, dim=-1)
|
||||
|
||||
if is_transposed:
|
||||
gather_weight = torch.cat(gather_list, dim=-1)
|
||||
else:
|
||||
gather_weight = torch.cat(gather_list, dim=0)
|
||||
gather_weight = gather_weight.to(origin_device)
|
||||
qkv = qkv.to(origin_device)
|
||||
|
||||
|
@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
|
|||
# [Q1, K1, V1, Q2, K2, V2]
|
||||
# to
|
||||
# [Q1, Q2, K1, K2, V1, V2]
|
||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
|
||||
if is_transposed:
|
||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
|
||||
else:
|
||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
|
||||
|
||||
reordered_chunk_list = []
|
||||
for i in range(n_fused):
|
||||
reordered_chunk_list.extend(weight_chunks[i::n_fused])
|
||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
||||
|
||||
if is_transposed:
|
||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
||||
else:
|
||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0)
|
||||
return reordered_gather_weight
|
||||
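To make the chunk-and-reorder logic concrete, here is a standalone sketch of the GPT-2-style fused-QKV split for a transposed Conv1D weight, with the rank and world size passed in directly instead of being read from a process group (illustrative only, not the shardformer API):

```python
import torch

def split_fused_qkv_gpt2_style_demo(qkv: torch.Tensor, n_fused: int, rank: int, world_size: int):
    # qkv columns are laid out as [Q, K, V]; chunking into world_size * n_fused pieces
    # gives [Q1, Q2, K1, K2, V1, V2] for world_size == 2 and n_fused == 3.
    chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
    # every rank keeps one chunk from each of Q, K and V:
    # rank 0 -> [Q1, K1, V1], rank 1 -> [Q2, K2, V2]
    order = list(range(world_size * n_fused))
    mine = [chunks[i] for i in order[rank::world_size]]
    return torch.cat(mine, dim=-1)

in_features, out_features, world_size = 8, 24, 2      # out_features = 3 * hidden for fused QKV
weight = torch.arange(in_features * out_features, dtype=torch.float32).reshape(in_features, out_features)
shard0 = split_fused_qkv_gpt2_style_demo(weight, n_fused=3, rank=0, world_size=world_size)
shard1 = split_fused_qkv_gpt2_style_demo(weight, n_fused=3, rank=1, world_size=world_size)
assert shard0.shape == shard1.shape == (in_features, out_features // world_size)
```

The gather function above is the exact inverse: it all-gathers the per-rank [Q_i, K_i, V_i] blocks and reorders them back into [Q1, Q2, K1, K2, V1, V2].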
|
||||
|
||||
class LinearConv1D_Col(ParallelModule):
|
||||
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
r"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
|
@ -154,10 +195,10 @@ class LinearConv1D_Col(ParallelModule):
|
|||
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv(tensor, self.n_fused, self.process_group)
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv(tensor, 3, self.process_group)
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)
|
||||
|
||||
with torch.no_grad():
|
||||
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
|
||||
|
@ -202,21 +243,27 @@ class LinearConv1D_Col(ParallelModule):
|
|||
f'Expected only one process group, got {len(process_group)}.'
|
||||
process_group = process_group[0]
|
||||
|
||||
linear_1d = LinearConv1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
|
||||
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=True)
|
||||
linear_1d.weight.data.copy_(sharded_weight.data)
|
||||
|
||||
if bias:
|
||||
sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
|
||||
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=True)
|
||||
linear_1d.bias.data.copy_(sharded_bias.data)
|
||||
|
||||
return linear_1d
|
||||
|
@ -254,7 +301,7 @@ class LinearConv1D_Col(ParallelModule):
|
|||
return output
|
||||
|
||||
|
||||
class LinearConv1D_Row(ParallelModule):
|
||||
class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
r""" Linear layer with row parallelism.
|
||||
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.
|
||||
|
||||
|
@ -345,13 +392,13 @@ class LinearConv1D_Row(ParallelModule):
|
|||
f'Expected only one process group, got {len(process_group)}.'
|
||||
process_group = process_group[0]
|
||||
|
||||
linear_1d = LinearConv1D_Row(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
|
@ -3,6 +3,7 @@ from contextlib import contextmanager
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_global_rank
|
||||
|
||||
|
||||
class Randomizer:
|
||||
|
@ -112,27 +113,90 @@ class Randomizer:
|
|||
|
||||
"""
|
||||
idx = Randomizer._INDEX
|
||||
Randomizer._INDEX += 1
|
||||
return idx
|
||||
|
||||
@staticmethod
|
||||
def increment_index():
|
||||
"""
|
||||
Increment the index of the randomizer by one.
|
||||
"""
|
||||
Randomizer._INDEX += 1
|
||||
|
||||
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
|
||||
@staticmethod
|
||||
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
|
||||
"""
|
||||
Return whether the randomizer index is synchronized across processes.
|
||||
"""
|
||||
index = Randomizer.index()
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
dist.all_gather(gathered_index, index_tensor, process_group)
|
||||
|
||||
# make sure all the gathered index are the same
|
||||
for i in range(1, dist.get_world_size(process_group)):
|
||||
if gathered_index[i] != gathered_index[0]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def synchronize_index(process_group: ProcessGroup = None):
|
||||
"""
|
||||
All gather the index and pick the largest value.
|
||||
"""
|
||||
index = Randomizer.index()
|
||||
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
dist.all_gather(gathered_index, index_tensor, process_group)
|
||||
|
||||
# pick the largest index
|
||||
for i in range(1, dist.get_world_size(process_group)):
|
||||
if gathered_index[i] > index_tensor:
|
||||
index_tensor = gathered_index[i]
|
||||
|
||||
# set the index
|
||||
Randomizer._INDEX = index_tensor.item()
|
||||
|
||||
|
||||
def create_randomizer_with_offset(seed: int,
|
||||
process_group: ProcessGroup = None,
|
||||
offset_by_rank: bool = True,
|
||||
offset_by_index: bool = True):
|
||||
"""
|
||||
Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.
|
||||
|
||||
Args:
|
||||
seed (int): The base random seed to set.
|
||||
enable_cpu (bool): fork the CPU RNG state as well.
|
||||
process_group (ProcessGroup): the process group to get the rank from.
|
||||
offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
|
||||
offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.
|
||||
|
||||
Returns:
|
||||
Randomizer: the randomizer with offset.
|
||||
"""
|
||||
offset = Randomizer.index()
|
||||
base_seed = seed
|
||||
|
||||
if dist.is_initialized():
|
||||
if offset_by_rank and dist.is_initialized():
|
||||
rank = dist.get_rank(process_group)
|
||||
offset += rank
|
||||
base_seed += rank
|
||||
|
||||
seed += offset
|
||||
return Randomizer(seed=seed)
|
||||
if offset_by_index:
|
||||
# check if the randomizer index is synchronized
|
||||
is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
|
||||
assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes."
|
||||
"This is not allowed when we want to create a randomizer with offset by index."
|
||||
"Please call Randomizer.synchronize_index() first.")
|
||||
|
||||
base_seed += Randomizer.index()
|
||||
Randomizer.increment_index()
|
||||
|
||||
return Randomizer(seed=base_seed)
|
||||
|
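The effective seed produced by create_randomizer_with_offset is therefore just the base seed plus the rank (if offset_by_rank) plus the current global randomizer index (if offset_by_index), with the index incremented afterwards so the next layer gets a fresh stream. A toy, non-distributed sketch of that bookkeeping (illustrative names, not the library class):

```python
class _ToyRandomizerFactory:
    _index = 0    # process-wide counter shared by all randomizers, like Randomizer._INDEX

    @classmethod
    def create(cls, seed: int, rank: int = 0, offset_by_rank: bool = True, offset_by_index: bool = True) -> int:
        base_seed = seed
        if offset_by_rank:
            base_seed += rank           # different ranks draw different streams
        if offset_by_index:
            base_seed += cls._index     # different layers draw different streams
            cls._index += 1             # index consumed, bump it for the next layer
        return base_seed                # the real helper wraps this seed in a Randomizer

# two dropout layers created on rank 1 with base seed 42 get seeds 43 and 44
assert _ToyRandomizerFactory.create(42, rank=1) == 43
assert _ToyRandomizerFactory.create(42, rank=1) == 44
```

The synchronization check above exists because this index is process-local state: if ranks created different numbers of randomizers, the offsets would silently diverge.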
|
|
@ -78,6 +78,18 @@ _POLICY_LIST = {
|
|||
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
|
||||
|
||||
# Bloom
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
|
||||
}
|
||||
|
||||
|
||||
|
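With these entries registered, a policy can be resolved from a model instance by building the fully qualified class name and importing the policy module lazily. A rough sketch of that lookup (it assumes the _POLICY_LIST table above and the policies package layout implied by file_name; the helper name is hypothetical):

```python
import importlib

def resolve_policy(model):
    # the key matches the "module.ClassName" strings used in _POLICY_LIST above
    cls = model.__class__
    location = _POLICY_LIST[f"{cls.__module__}.{cls.__name__}"]
    policy_module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(policy_module, location.class_name)()
```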
|
|
@ -52,6 +52,7 @@ class ModulePolicyDescription:
|
|||
attribute_replacement: Dict[str, Any]
|
||||
param_replacement: List[Callable]
|
||||
sub_module_replacement: List[SubModuleReplacementDescription]
|
||||
method_replacement: List[Callable] = None
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
|
|
|
@ -61,7 +61,7 @@ class BertPolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
|
@ -69,7 +69,7 @@ class BertPolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
|
@ -81,7 +81,7 @@ class BertPolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
BertEmbeddings:
|
||||
|
@ -94,7 +94,7 @@ class BertPolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
@ -258,7 +258,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
@ -281,7 +281,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
@ -311,7 +311,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
|
|
@ -0,0 +1,214 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
`softmax(l+a) = softmax(l)`. Based on
|
||||
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
|
||||
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
|
||||
|
||||
Args:
|
||||
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
|
||||
attention_mask (`torch.Tensor`):
|
||||
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
|
||||
num_heads (`int`, *required*):
|
||||
number of heads
|
||||
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
||||
dtype of the output tensor
|
||||
"""
|
||||
import math
|
||||
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
num_heads = num_heads * world_size
|
||||
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
# Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
|
||||
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
||||
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
||||
# => the query_length dimension will then be broadcasted correctly
|
||||
# This is more or less identical to T5's relative position bias:
|
||||
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
|
||||
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
||||
alibi = slopes[..., None] * arange_tensor
|
||||
if dist.is_initialized():
|
||||
num_heads_per_rank = int(num_heads / dist.get_world_size())
|
||||
offset = dist.get_rank() * num_heads_per_rank
|
||||
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
|
||||
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
|
||||
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
|
||||
else:
|
||||
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
||||
|
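Under tensor parallelism the function above builds the alibi slopes for the full head count and then keeps only the contiguous block of heads owned by the current rank. The slicing arithmetic can be checked in isolation with made-up sizes (a toy sketch, no process group involved):

```python
import torch

num_heads, world_size, batch_size, seq_length = 8, 2, 3, 5
alibi = torch.randn(batch_size, num_heads, 1, seq_length)    # bias for all heads

num_heads_per_rank = num_heads // world_size                 # 4 heads per rank
for rank in range(world_size):
    offset = rank * num_heads_per_rank
    local = alibi[:, offset:offset + num_heads_per_rank, :, :]
    # flatten to the (batch_size * num_heads_per_rank, 1, seq_length) layout the
    # sharded BloomAttention consumes, as in the docstring above
    local = local.reshape(batch_size * num_heads_per_rank, 1, seq_length)
    assert local.shape == (batch_size * num_heads_per_rank, 1, seq_length)
```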
||||
|
||||
class BloomPolicy(Policy):
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
# TODO:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
return self.model
|
||||
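The padding rule in preprocess rounds the vocabulary up to the next multiple of the tensor-parallel world size, which is what allows VocabParallelEmbedding1D to split the embedding table evenly. A small check of that arithmetic:

```python
def padded_vocab_size(vocab_size: int, world_size: int) -> int:
    # round up to the next multiple of world_size, mirroring preprocess() above
    if vocab_size % world_size != 0:
        vocab_size = vocab_size + world_size - vocab_size % world_size
    return vocab_size

assert padded_vocab_size(250880, 4) == 250880     # Bloom's vocabulary is already divisible
assert padded_vocab_size(50257, 4) == 50260       # a GPT-2-like vocabulary gains 3 padding rows
```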
|
||||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
|
||||
|
||||
return {
|
||||
BloomBlock:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
# 1. shard hidden size
|
||||
"self_attention.hidden_size":
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.split_size":
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
# 2. shard number of heads
|
||||
"self_attention.num_heads":
|
||||
self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
# kwargs={'n_fused': 3}
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
]),
|
||||
BloomModel:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
# BloomModel
|
||||
class BloomModelPolicy(BloomPolicy):
|
||||
pass
|
||||
|
||||
|
||||
class BloomForCausalLMPolicy(BloomPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
|
||||
policy = super().module_policy()
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
BloomForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
|
||||
class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
|
||||
policy = super().module_policy()
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
BloomForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="score",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
|
||||
class BloomForTokenClassificationPolicy(BloomPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
|
||||
policy = super().module_policy()
|
||||
# add a new item for token classification
|
||||
new_item = {
|
||||
BloomForTokenClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="classifier",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
|
||||
class BloomForQuestionAnsweringPolicy(BloomPolicy):
|
||||
# No head sharding as the number of output features is only 2
|
||||
pass
|
|
@ -42,37 +42,37 @@ class GPT2Policy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_attn",
|
||||
target_module=col_nn.LinearConv1D_Col,
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_proj",
|
||||
target_module=col_nn.LinearConv1D_Row,
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.LinearConv1D_Col,
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 1,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_proj",
|
||||
target_module=col_nn.LinearConv1D_Row,
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.resid_dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dropout",
|
||||
target_module=col_nn.Dropout1D,
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
])
|
||||
}
|
||||
|
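The n_fused kwarg records how many projections share one Conv1D: attn.c_attn packs query, key and value into a single weight (n_fused=3) while mlp.c_fc holds one projection (n_fused=1), and the column-parallel shard keeps a slice of each fused part. A quick shape check with Hugging Face's Conv1D, using GPT-2 small sizes (for illustration only):

```python
from transformers.pytorch_utils import Conv1D

hidden = 768
c_attn = Conv1D(3 * hidden, hidden)   # fused QKV: weight stored as (in_features, out_features)
c_fc = Conv1D(4 * hidden, hidden)     # MLP up-projection
assert c_attn.weight.shape == (hidden, 3 * hidden)
assert c_fc.weight.shape == (hidden, 4 * hidden)

# with tensor parallel size 2, each rank keeps half of every fused projection:
# c_attn -> (hidden, 3 * hidden / 2) arranged as [Q_i, K_i, V_i], c_fc -> (hidden, 2 * hidden)
```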
|
|
@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
|
|||
T5Stack,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@ -38,7 +38,7 @@ class T5ModelPolicy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5LayerSelfAttention:
|
||||
|
@ -47,7 +47,7 @@ class T5ModelPolicy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5LayerCrossAttention:
|
||||
|
@ -56,7 +56,7 @@ class T5ModelPolicy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5Attention:
|
||||
|
@ -97,7 +97,7 @@ class T5ModelPolicy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5DenseGatedActDense:
|
||||
|
@ -117,7 +117,7 @@ class T5ModelPolicy(Policy):
|
|||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5DenseActDense:
|
||||
|
@ -134,7 +134,7 @@ class T5ModelPolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
|
||||
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class ViTPolicy(Policy):
|
||||
|
||||
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
|
@ -20,77 +20,68 @@ class ViTPolicy(Policy):
|
|||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
return {
|
||||
return {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement{},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
)
|
||||
]
|
||||
),
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForReplicatedInput,
|
||||
)
|
||||
]),
|
||||
ViTLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement{
|
||||
"attention.attention.num_attention_heads":
|
||||
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
|
||||
"attention.attention.all_head_size":
|
||||
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
]
|
||||
),
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"attention.attention.num_attention_heads":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"attention.attention.all_head_size":
|
||||
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
}
|
||||
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -95,8 +95,9 @@ class ModelSharder(object):
|
|||
attr_replacement = module_description[1].attribute_replacement
|
||||
param_replacement = module_description[1].param_replacement
|
||||
sub_module_replacement = module_description[1].sub_module_replacement
|
||||
method_replacement = module_description[1].method_replacement
|
||||
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
|
||||
sub_module_replacement)
|
||||
method_replacement, sub_module_replacement)
|
||||
|
||||
def _recursive_replace_layer(
|
||||
self,
|
||||
|
@ -104,6 +105,7 @@ class ModelSharder(object):
|
|||
origin_cls: nn.Module,
|
||||
attr_replacement: Dict[str, Any],
|
||||
param_replacement: List[Callable],
|
||||
method_replacement: Dict[str, Callable],
|
||||
sub_module_replacement: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
|
@ -119,9 +121,11 @@ class ModelSharder(object):
|
|||
if module.__class__ == origin_cls:
|
||||
self._replace_attr(module, attr_replacement)
|
||||
self._replace_param(module, param_replacement)
|
||||
self._replace_method(module, method_replacement)
|
||||
self._replace_sub_module(module, sub_module_replacement)
|
||||
|
||||
for name, child in module.named_children():
|
||||
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
|
||||
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
|
||||
sub_module_replacement)
|
||||
|
||||
def _replace_attr(
|
||||
|
@ -154,6 +158,14 @@ class ModelSharder(object):
|
|||
# TODO: support parameter shard
|
||||
pass
|
||||
|
||||
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
|
||||
if method_replacement is None:
|
||||
return
|
||||
|
||||
for method_name, new_method in method_replacement.items():
|
||||
# bind the new method to the module
|
||||
setattr(module, method_name, new_method.__get__(module, module.__class__))
|
||||
|
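__get__ is what turns the plain replacement function into a bound method of the existing instance, so calls like model.build_alibi_tensor(...) receive the module as self without the class itself being modified. A minimal illustration of the same binding trick on an ordinary object:

```python
class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    # replacement behaviour; `self` is supplied because the function is bound below
    return self.name.upper()

g = Greeter("bloom")
setattr(g, "greet", shout.__get__(g, Greeter))   # bind only on this instance
assert g.greet() == "BLOOM"
assert not hasattr(Greeter, "greet")             # the class is left untouched
```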
||||
def _replace_sub_module(
|
||||
self,
|
||||
org_layer: nn.Module,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from .albert import *
|
||||
from .bert import *
|
||||
from .bloom import *
|
||||
from .gpt import *
|
||||
from .llama import *
|
||||
from .opt import *
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
import transformers
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ===============================
|
||||
# Register Bloom
|
||||
# ===============================
|
||||
|
||||
|
||||
def data_gen():
|
||||
# Generated from following code snippet
|
||||
#
|
||||
# from transformers import BloomTokenizerFast
# tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
# input = 'Hello, my dog is cute'
|
||||
# tokenized_input = tokenizer(input, return_tensors='pt')
|
||||
# input_ids = tokenized_input['input_ids']
|
||||
# attention_mask = tokenized_input['attention_mask']
|
||||
input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def data_gen_for_lm():
|
||||
# LM data gen
|
||||
# for LM, `labels` are the output tokens; since there is no padding, `input_ids` is reused as `labels`
|
||||
data = data_gen()
|
||||
data['labels'] = data['input_ids'].clone()
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_token_classification():
|
||||
# token classification data gen
|
||||
# for token classification, `labels` holds the class id of each token (0 or 1), not the token id
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_sequence_classification():
|
||||
# sequence classification data gen
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([0], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_question_answering():
|
||||
# obtained with the following code
|
||||
#
|
||||
# from transformers import AutoTokenizer
|
||||
# tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
|
||||
# question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
# inputs = tokenizer(question, text, return_tensors="pt")
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn_for_causal_lm = lambda x: x.loss
|
||||
loss_fn_for_classification = lambda x: x.logits.mean()
|
||||
loss_fn_for_question_answering = lambda x: x.end_logits.mean()
|
||||
|
||||
config = transformers.BloomConfig(n_layer=1,
|
||||
n_head=4,
|
||||
vocab_size=250880,
|
||||
hidden_dropout=0,
|
||||
attention_dropout=0,
|
||||
hidden_size=64)
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_bloom',
|
||||
model_fn=lambda: transformers.BloomModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_bloom_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_causal_lm',
|
||||
model_fn=lambda: transformers.BloomForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_causal_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_sequence_classification',
|
||||
model_fn=lambda: transformers.BloomForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_token_classification',
|
||||
model_fn=lambda: transformers.BloomForTokenClassification(config),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_question_answering',
|
||||
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_question_answering,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
|
@ -6,8 +6,6 @@ from ..registry import ModelAttribute, model_zoo
|
|||
# ===============================
|
||||
# Register single-sentence GPT
|
||||
# ===============================
|
||||
BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
|
||||
SEQ_LENGTH = 16
|
||||
|
||||
|
||||
def data_gen():
|
||||
|
|
|
@ -3,13 +3,13 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import Dropout1D
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_dropout():
|
||||
def check_dropout_parallel_input():
|
||||
dropout = nn.Dropout().cuda()
|
||||
dropout_1d = Dropout1D.from_native_module(dropout, process_group=None)
|
||||
dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None)
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 128).cuda()
|
||||
|
@ -39,9 +39,26 @@ def check_dropout():
|
|||
assert_not_equal(out_1d_all[i], out_1d_all[0])
|
||||
|
||||
|
||||
def check_dropout_replicated_input():
|
||||
dropout = nn.Dropout().cuda()
|
||||
dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None)
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 128).cuda()
|
||||
out_1d = dropout_replica(x)
|
||||
|
||||
# ensure out_1d is different across ranks
|
||||
world_size = dist.get_world_size()
|
||||
out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)]
|
||||
dist.all_gather(out_1d_all, out_1d)
|
||||
for i in range(1, world_size):
|
||||
assert_equal(out_1d_all[i], out_1d_all[0])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_dropout()
|
||||
check_dropout_parallel_input()
|
||||
check_dropout_replicated_input()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
|||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
|
||||
from colossalai.shardformer.layer.linear_conv import split_fused_qkv
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -52,7 +52,10 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
|||
|
||||
def check_linear_conv_1d_col():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
||||
process_group=None,
|
||||
gather_output=True,
|
||||
n_fused=3)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear.bias.shape == torch.Size([192])
|
||||
|
@ -73,13 +76,13 @@ def check_linear_conv_1d_col():
|
|||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
target_grad = split_fused_qkv(linear.weight.grad, 3, None)
|
||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
|
||||
|
||||
def check_linear_conv_1d_row():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
|
@ -102,6 +105,8 @@ def check_linear_conv_1d_row():
|
|||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# test for linear conv
|
||||
check_linear_conv_1d_col()
|
||||
check_linear_conv_1d_row()
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
|
||||
|
||||
# do backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
# check grad equality
|
||||
if org_model.__class__.__name__ == 'BloomModel':
|
||||
org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
|
||||
shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
|
||||
else:
|
||||
org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
|
||||
shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
|
||||
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
def check_bloom(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(world_size, model_fn)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom():
|
||||
spawn(check_bloom, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bloom()
|