[shardformer] supported bloom model (#4098)

pull/4157/head
Frank Lee 2023-06-28 15:04:35 +08:00
parent 8af29ee47a
commit b1c2901530
20 changed files with 724 additions and 154 deletions

View File

@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer:
- [x] BERT
- [x] T5
- [x] LlaMa
- [ ] GPT2
- [ ] BLOOM
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] SwinTransformer
- [ ] SwinTransformer V2
- [ ] Audio
- [ ] To be added
- [ ] Whisper
- [ ] Multi-modal
- [ ] To be added

View File

@ -1,11 +1,12 @@
from .dropout import Dropout1D
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row
from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
from .loss import cross_entropy_1d
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
"Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm'
]

View File

@ -7,10 +7,10 @@ from torch.distributed import ProcessGroup
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ['Dropout1D']
__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput']
class Dropout1D(ParallelModule, nn.Dropout):
class DropoutForParallelInput(ParallelModule, nn.Dropout):
"""
The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
randomness on different ranks of the given process group. This avoids the case where the same dropout mask is generated
@ -32,13 +32,50 @@ class Dropout1D(ParallelModule, nn.Dropout):
@staticmethod
def from_native_module(module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput":
"""
Create a Dropout1D layer from a native dropout layer.
Create a DropoutForParallelInput layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return Dropout1D(p=p, inplace=inplace, process_group=process_group)
return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group)
def forward(self, input):
with self.randomizer.fork_rng():
input = super().forward(input)
return input
class DropoutForReplicatedInput(ParallelModule, nn.Dropout):
"""
The Dropout layer applies a dropout mask to the input tensor. The mask is generated with the same
randomness on every rank of the given process group, so a replicated (non-sharded) input receives an
identical mask on all ranks and the replicas stay consistent.
Args:
p (float): probability of an element to be zeroed. Defaults to 0.5.
inplace (bool): If set to True, will do this operation in-place. Defaults to False.
process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
"""
def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
# init with nn.Dropout
super(nn.Dropout, self).__init__(p=p, inplace=inplace)
# offset the seed with randomizer index only
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False)
@staticmethod
def from_native_module(
module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput":
"""
Create a DropoutForReplicatedInput layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group)
def forward(self, input):
with self.randomizer.fork_rng():
input = super().forward(input)
return input
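# --- Illustrative sketch (not part of this diff): how the two dropout variants
# differ. A rank-offset seed (DropoutForParallelInput) yields a different mask on
# each rank, while a shared seed (DropoutForReplicatedInput) reproduces the same
# mask everywhere. `rank` below is a plain int standing in for
# dist.get_rank(process_group); everything here is a single-process toy.
import torch

def make_mask(seed: int, shape, p: float = 0.5) -> torch.Tensor:
    # draw a Bernoulli keep-mask from a deterministically seeded generator
    gen = torch.Generator().manual_seed(seed)
    return (torch.rand(shape, generator=gen) > p).float()

base_seed, rank = 1234, 1
mask_parallel = make_mask(base_seed + rank, (2, 4))    # differs across ranks
mask_replicated = make_mask(base_seed, (2, 4))         # identical on every rank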

View File

@ -277,6 +277,7 @@ class Linear1D_Row(ParallelModule):
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
@torch.no_grad()
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@ -289,9 +290,10 @@ class Linear1D_Row(ParallelModule):
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
bias = self.bias.cuda()
dist.broadcast(bias, src=src_rank, group=self.process_group)
bias = bias.to(origin_device)
self.bias.copy_(bias)
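# --- Side note as a runnable sketch (not part of this diff): why the broadcast
# above goes through a temporary tensor and copy_ instead of re-assigning
# self.bias. Assigning a plain tensor to an attribute that is already a
# registered nn.Parameter raises a TypeError, while an in-place copy keeps the
# parameter registered with the module.
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
new_bias = torch.zeros(4)
lin.bias.data.copy_(new_bias)    # fine: the Parameter object is preserved
# lin.bias = new_bias            # would raise: cannot assign Tensor as parameter 'bias'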
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.

View File

@ -31,12 +31,25 @@ from ._operation import (
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
# ====================================
# For GPT Only
# ====================================
def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
n_fused: int,
process_group: ProcessGroup,
is_transposed: bool = False):
"""
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number of items fused together; defaults to 3 (query, key and value).
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): the tensor is generally of shape (out_features, in_features); set this to True if the tensor is instead of shape (in_features, out_features).
"""
# get the number of slice for the fused qkv
rank = dist.get_rank(group=process_group)
@ -48,7 +61,10 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
# [Q, K, V]
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
else:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
# rearrange the slice into the final order
# from
@ -56,13 +72,26 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
# to
# [Q1, K1, V1], [Q2, K2, V2]
weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
if is_transposed:
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
else:
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0)
return weight_of_current_rank
def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
n_fused: int,
process_group: ProcessGroup,
is_transposed: bool = False):
"""
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number of items fused together; defaults to 3 (query, key and value).
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): the tensor is generally of shape (out_features, in_features); set this to True if the tensor is instead of shape (in_features, out_features).
"""
world_size = dist.get_world_size(group=process_group)
@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
qkv = qkv.cuda()
gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
dist.all_gather(gather_list, qkv, group=process_group)
if is_transposed:
gather_weight = torch.cat(gather_list, dim=-1)
else:
gather_weight = torch.cat(gather_list, dim=0)
gather_weight = gather_weight.to(origin_device)
qkv = qkv.to(origin_device)
@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
# [Q1, K1, V1, Q2, K2, V2]
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
else:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
reordered_chunk_list = []
for i in range(n_fused):
reordered_chunk_list.extend(weight_chunks[i::n_fused])
if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
else:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0)
return reordered_gather_weight
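# --- Single-process sketch of the GPT-2 style split documented above (not the
# library function). With the transposed (in_features, out_features) layout used
# by HuggingFace's Conv1D, rank r keeps every world_size-th chunk starting at r,
# i.e. [Qr, Kr, Vr]. `rank` and `world_size` are plain ints here.
import torch

def split_fused_qkv_sketch(qkv: torch.Tensor, n_fused: int, rank: int, world_size: int) -> torch.Tensor:
    # fused layout along the last dim for world_size=2: [Q1, Q2, K1, K2, V1, V2]
    chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
    kept = [chunks[i] for i in range(rank, world_size * n_fused, world_size)]
    return torch.cat(kept, dim=-1)

qkv = torch.arange(24.).reshape(2, 12)                                  # (in=2, 3 * out=12)
shard0 = split_fused_qkv_sketch(qkv, n_fused=3, rank=0, world_size=2)   # keeps columns 0-1, 4-5, 8-9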
class LinearConv1D_Col(ParallelModule):
class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@ -154,10 +195,10 @@ class LinearConv1D_Col(ParallelModule):
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
def shard_fn(tensor):
return split_fused_qkv(tensor, self.n_fused, self.process_group)
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv(tensor, 3, self.process_group)
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
@ -202,7 +243,7 @@ class LinearConv1D_Col(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = LinearConv1D_Col(in_features=in_features,
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
@ -212,11 +253,17 @@ class LinearConv1D_Col(ParallelModule):
# TODO: copy the sharded weights
with torch.no_grad():
sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
@ -254,7 +301,7 @@ class LinearConv1D_Col(ParallelModule):
return output
class LinearConv1D_Row(ParallelModule):
class GPT2FusedLinearConv1D_Row(ParallelModule):
r""" Linear layer with row parallelism.
This layer is used to fit the `Conv1D` layer (fused QKV) in HuggingFace's GPT-2.
@ -345,7 +392,7 @@ class LinearConv1D_Row(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = LinearConv1D_Row(in_features=in_features,
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,

View File

@ -3,6 +3,7 @@ from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_global_rank
class Randomizer:
@ -112,27 +113,90 @@ class Randomizer:
"""
idx = Randomizer._INDEX
Randomizer._INDEX += 1
return idx
@staticmethod
def increment_index():
"""
Increment the index of the randomizer by one.
"""
Randomizer._INDEX += 1
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
@staticmethod
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
"""
Return whether the randomizer index is synchronized across processes.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# make sure all the gathered index are the same
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] != gathered_index[0]:
return False
return True
@staticmethod
def synchronize_index(process_group: ProcessGroup = None):
"""
All gather the index and pick the largest value.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# pick the largest index
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] > index_tensor:
index_tensor = gathered_index[i]
# set the index
Randomizer._INDEX = index_tensor.item()
def create_randomizer_with_offset(seed: int,
process_group: ProcessGroup = None,
offset_by_rank: bool = True,
offset_by_index: bool = True):
"""
Create a randomizer with an offset. The offset is derived from the rank of the process and/or the index of the randomizer, as controlled by the flags below.
Args:
seed (int): The base random seed to set.
process_group (ProcessGroup): the process group to get the rank from.
offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.
Returns:
Randomizer: the randomizer with offset.
"""
offset = Randomizer.index()
base_seed = seed
if dist.is_initialized():
if offset_by_rank and dist.is_initialized():
rank = dist.get_rank(process_group)
offset += rank
base_seed += rank
seed += offset
return Randomizer(seed=seed)
if offset_by_index:
# check if the randomizer index is synchronized
is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes."
"This is not allowed when we want to create a randomizer with offset by index."
"Please call Randomizer.synchronize_index() first.")
base_seed += Randomizer.index()
Randomizer.increment_index()
return Randomizer(seed=base_seed)
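# --- Toy stand-in for the bookkeeping above (an assumption-labelled sketch, not
# the library's Randomizer). Each creation consumes one shared index, so ranks
# that create their randomizers in the same order end up with matching offsets.
import torch

class _ToyRandomizer:
    _INDEX = 0

    def __init__(self, seed: int):
        self.generator = torch.Generator().manual_seed(seed)

def create_with_offset(seed: int, rank: int = 0,
                       offset_by_rank: bool = True,
                       offset_by_index: bool = True) -> _ToyRandomizer:
    if offset_by_rank:
        seed += rank                     # per-rank streams (e.g. parallel-input dropout)
    if offset_by_index:
        seed += _ToyRandomizer._INDEX    # per-layer streams
        _ToyRandomizer._INDEX += 1       # every creation advances the shared index
    return _ToyRandomizer(seed)

r0 = create_with_offset(1024, rank=0)    # seed 1024 + 0 + 0
r1 = create_with_offset(1024, rank=0)    # seed 1024 + 0 + 1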

View File

@ -78,6 +78,18 @@ _POLICY_LIST = {
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel":
PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForCausalLM":
PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification":
PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification":
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
}
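# --- Hedged sketch of how a registry like _POLICY_LIST could be consumed; the
# helper name and `package` default below are assumptions for illustration, not
# the project's actual loader. PolicyLocation fields follow the entries above.
import importlib

def resolve_policy(model_cls, policy_table, package="colossalai.shardformer.policies"):
    # key is the model's fully qualified class name, e.g.
    # "transformers.models.bloom.modeling_bloom.BloomModel"
    key = f"{model_cls.__module__}.{model_cls.__qualname__}"
    location = policy_table[key]    # e.g. PolicyLocation(file_name="bloom", class_name="BloomModelPolicy")
    module = importlib.import_module(f"{package}.{location.file_name}")
    return getattr(module, location.class_name)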

View File

@ -52,6 +52,7 @@ class ModulePolicyDescription:
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
method_replacement: Dict[str, Callable] = None
class Policy(ABC):

View File

@ -61,7 +61,7 @@ class BertPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
@ -69,7 +69,7 @@ class BertPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
@ -81,7 +81,7 @@ class BertPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
)
]),
BertEmbeddings:
@ -94,7 +94,7 @@ class BertPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
)
])
}
@ -258,7 +258,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
)
])
}
@ -281,7 +281,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
)
])
}
@ -311,7 +311,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
)
])
}

View File

@ -0,0 +1,214 @@
import torch
import torch.distributed as dist
import colossalai.shardformer.layer as col_nn
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor is not causal, as the original paper notes; it
relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
Returns:
A tensor of shape (batch_size * num_heads, 1, max_seq_len).
"""
import math
if dist.is_initialized():
world_size = dist.get_world_size()
num_heads = num_heads * world_size
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
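# --- Small numeric check of the slope rule above (single process, no sharding),
# for num_heads = 4, a power of two, so the "extra slopes" branch is not needed.
import math
import torch

num_heads = 4
closest_power_of_2 = 2**math.floor(math.log2(num_heads))                                     # 4
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)    # 0.25
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)    # tensor([0.2500, 0.0625, 0.0156, 0.0039])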
class BloomPolicy(Policy):
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the embedding layer to make the vocabulary size divisible by world_size.
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
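# --- Quick arithmetic check of the resize rule above (example values only):
# pad vocab_size up to the next multiple of the tensor-parallel world size.
vocab_size, world_size = 250880, 3
if vocab_size % world_size != 0:
    vocab_size = vocab_size + world_size - vocab_size % world_size
assert vocab_size == 250881 and vocab_size % world_size == 0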
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
return {
BloomBlock:
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"self_attention.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
# kwargs={'n_fused': 3}
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
),
]),
BloomModel:
ModulePolicyDescription(attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
}
def new_model_class(self):
# do nothing
return self.model
def postprocess(self):
return self.model
# BloomModel
class BloomModelPolicy(BloomPolicy):
pass
class BloomForCausalLMPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
policy = super().module_policy()
# add a new item for causal lm
new_item = {
BloomForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
class BloomForSequenceClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
# add a new item for sequence classification
new_item = {
BloomForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
class BloomForTokenClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
# add a new item for token classification
new_item = {
BloomForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
}
policy.update(new_item)
return policy
class BloomForQuestionAnsweringPolicy(BloomPolicy):
# No head sharding as the output feature size is only 2
pass

View File

@ -42,37 +42,37 @@ class GPT2Policy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.LinearConv1D_Col,
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.LinearConv1D_Row,
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.LinearConv1D_Col,
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.LinearConv1D_Row,
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.Dropout1D,
target_module=col_nn.DropoutForParallelInput,
),
])
}

View File

@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
T5Stack,
)
from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -38,7 +38,7 @@ class T5ModelPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
)
]),
T5LayerSelfAttention:
@ -47,7 +47,7 @@ class T5ModelPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
),
]),
T5LayerCrossAttention:
@ -56,7 +56,7 @@ class T5ModelPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
)
]),
T5Attention:
@ -97,7 +97,7 @@ class T5ModelPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
),
]),
T5DenseGatedActDense:
@ -117,7 +117,7 @@ class T5ModelPolicy(Policy):
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
)
]),
T5DenseActDense:
@ -134,7 +134,7 @@ class T5ModelPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
)
])
}

View File

@ -1,13 +1,13 @@
from typing import Dict, Union
import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ViTPolicy(Policy):
def preprocess(self):
@ -24,19 +24,16 @@ class ViTPolicy(Policy):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
ViTEmbeddings:
ModulePolicyDescription(
attribute_replacement={},
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
target_module=DropoutForReplicatedInput,
)
]
),
]),
ViTLayer:
ModulePolicyDescription(
attribute_replacement={
ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
@ -58,7 +55,7 @@ class ViTPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
@ -66,7 +63,7 @@ class ViTPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
@ -78,10 +75,9 @@ class ViTPolicy(Policy):
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=Dropout1D,
),
]
target_module=DropoutForParallelInput,
),
]),
}
def new_model_class(self):
@ -89,8 +85,3 @@ class ViTPolicy(Policy):
def postprocess(self):
return self.model

View File

@ -95,8 +95,9 @@ class ModelSharder(object):
attr_replacement = module_description[1].attribute_replacement
param_replacement = module_description[1].param_replacement
sub_module_replacement = module_description[1].sub_module_replacement
method_replacement = module_description[1].method_replacement
self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
sub_module_replacement)
method_replacement, sub_module_replacement)
def _recursive_replace_layer(
self,
@ -104,6 +105,7 @@ class ModelSharder(object):
origin_cls: nn.Module,
attr_replacement: Dict[str, Any],
param_replacement: List[Callable],
method_replacement: Dict[str, Callable],
sub_module_replacement: List[Callable],
) -> None:
r"""
@ -119,9 +121,11 @@ class ModelSharder(object):
if module.__class__ == origin_cls:
self._replace_attr(module, attr_replacement)
self._replace_param(module, param_replacement)
self._replace_method(module, method_replacement)
self._replace_sub_module(module, sub_module_replacement)
for name, child in module.named_children():
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
sub_module_replacement)
def _replace_attr(
@ -154,6 +158,14 @@ class ModelSharder(object):
# TODO: support parameter shard
pass
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
if method_replacement is None:
return
for method_name, new_method in method_replacement.items():
# bind the new method to the module
setattr(module, method_name, new_method.__get__(module, module.__class__))
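# --- Minimal sketch of the binding trick above: a free function becomes an
# instance method via the descriptor protocol, so a replaced method such as
# build_alibi_tensor receives the module as `self`. The toy class below is only
# for illustration.
class _Widget:
    def build(self):
        return "original"

def _patched_build(self):
    return f"patched on {self.__class__.__name__}"

_w = _Widget()
setattr(_w, "build", _patched_build.__get__(_w, _Widget))   # bind to this instance
assert _w.build() == "patched on _Widget"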
def _replace_sub_module(
self,
org_layer: nn.Module,

View File

@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .bloom import *
from .gpt import *
from .llama import *
from .opt import *

View File

@ -0,0 +1,107 @@
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register Bloom
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# for LM, `labels` are the output tokens; since there is no padding, `input_ids` is reused as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
return data
def data_gen_for_token_classification():
# token classification data gen
# for token classification, `labels` are class indices (0 or 1), not token ids
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data['labels'] = torch.tensor([0], dtype=torch.int64)
return data
def data_gen_for_question_answering():
# obtained with the following code
#
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
# question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.logits.mean()
loss_fn_for_question_answering = lambda x: x.end_logits.mean()
config = transformers.BloomConfig(n_layer=1,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64)
# register the following models
model_zoo.register(name='transformers_bloom',
model_fn=lambda: transformers.BloomModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bloom_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_causal_lm',
model_fn=lambda: transformers.BloomForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_sequence_classification',
model_fn=lambda: transformers.BloomForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_token_classification',
model_fn=lambda: transformers.BloomForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_question_answering',
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_question_answering,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@ -6,8 +6,6 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence GPT
# ===============================
BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
SEQ_LENGTH = 16
def data_gen():

View File

@ -3,13 +3,13 @@ import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.shardformer.layer import Dropout1D
from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput
from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
def check_dropout():
def check_dropout_parallel_input():
dropout = nn.Dropout().cuda()
dropout_1d = Dropout1D.from_native_module(dropout, process_group=None)
dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None)
# check computation correctness
x = torch.rand(4, 128).cuda()
@ -39,9 +39,26 @@ def check_dropout():
assert_not_equal(out_1d_all[i], out_1d_all[0])
def check_dropout_replicated_input():
dropout = nn.Dropout().cuda()
dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None)
# check computation correctness
x = torch.rand(4, 128).cuda()
out_1d = dropout_replica(x)
# ensure out_1d is the same across ranks
world_size = dist.get_world_size()
out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)]
dist.all_gather(out_1d_all, out_1d)
for i in range(1, world_size):
assert_equal(out_1d_all[i], out_1d_all[0])
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_dropout()
check_dropout_parallel_input()
check_dropout_replicated_input()
@rerun_if_address_is_in_use()

View File

@ -4,8 +4,8 @@ import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
from colossalai.shardformer.layer.linear_conv import split_fused_qkv
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -52,7 +52,10 @@ def rearrange(tensor: torch.Tensor, dim: int):
def check_linear_conv_1d_col():
linear = Conv1D(192, 48).cuda()
linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
process_group=None,
gather_output=True,
n_fused=3)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
@ -73,13 +76,13 @@ def check_linear_conv_1d_col():
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv(linear.weight.grad, 3, None)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
def check_linear_conv_1d_row():
linear = Conv1D(192, 48).cuda()
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
@ -102,6 +105,8 @@ def check_linear_conv_1d_row():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()

View File

@ -0,0 +1,59 @@
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
# do backward
org_loss.backward()
shard_loss.backward()
# check grad equality
if org_model.__class__.__name__ == 'BloomModel':
org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
else:
org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
def check_bloom(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
if __name__ == "__main__":
test_bloom()