mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] supported bloom model (#4098)

parent 8af29ee47a
commit b1c2901530
@@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer:
 - [x] BERT
 - [x] T5
 - [x] LlaMa
-- [ ] GPT2
-- [ ] BLOOM
+- [x] GPT2
+- [x] OPT
+- [x] BLOOM
+- [ ] GLM
 - [ ] RoBERTa
 - [ ] ALBERT
 - [ ] ERNIE
@@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer:
 - [ ] SwinTransformer
 - [ ] SwinTransformer V2
 - [ ] Audio
-- [ ] To be added
+- [ ] Whisper
 - [ ] Multi-modal
 - [ ] To be added

@@ -1,11 +1,12 @@
-from .dropout import Dropout1D
+from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
-from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
+from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

 __all__ = [
-    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
+    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
+    'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
+    'FusedLayerNorm'
 ]

@@ -7,10 +7,10 @@ from torch.distributed import ProcessGroup
 from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ['Dropout1D']
+__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput']


-class Dropout1D(ParallelModule, nn.Dropout):
+class DropoutForParallelInput(ParallelModule, nn.Dropout):
     """
     The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
     randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
@@ -32,13 +32,50 @@ class Dropout1D(ParallelModule, nn.Dropout):

     @staticmethod
     def from_native_module(module: nn.Dropout,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput":
         """
-        Create a Dropout1D layer from a native dropout layer.
+        Create a DropoutForParallelInput layer from a native dropout layer.
         """
         p = module.p
         inplace = module.inplace
-        return Dropout1D(p=p, inplace=inplace, process_group=process_group)
+        return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group)
+
+    def forward(self, input):
+        with self.randomizer.fork_rng():
+            input = super().forward(input)
+        return input
+
+
+class DropoutForReplicatedInput(ParallelModule, nn.Dropout):
+    """
+    The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
+    randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
+    and applied on the same position of different ranks, leading to poor convergence performance.
+
+    Args:
+        p (float): probability of an element to be zeroed. Defaults to 0.5.
+        inplace (bool): If set to True, will do this operation in-place. Defaults to False.
+        process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
+    """
+
+    def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
+        # init with nn.Dropout
+        super(nn.Dropout, self).__init__(p=p, inplace=inplace)
+
+        # offset the seed with randomizer index only
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False)
+
+    @staticmethod
+    def from_native_module(
+            module: nn.Dropout,
+            process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput":
+        """
+        Create a Dropout1D layer from a native dropout layer.
+        """
+        p = module.p
+        inplace = module.inplace
+        return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group)

     def forward(self, input):
         with self.randomizer.fork_rng():

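Note: the two wrappers above differ only in how their random seed is offset. A minimal usage sketch, assuming torch.distributed is already initialized (e.g. two ranks) and ColossalAI is importable; `process_group=None` falls back to the default group:

    import torch.nn as nn
    from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput

    native = nn.Dropout(p=0.1)

    # sharded activations: every rank should draw a different mask,
    # so the seed is offset by both the process rank and the randomizer index
    parallel_dropout = DropoutForParallelInput.from_native_module(native, process_group=None)

    # replicated activations: every rank must draw the same mask,
    # so the seed is offset by the randomizer index only (offset_by_rank=False)
    replicated_dropout = DropoutForReplicatedInput.from_native_module(native, process_group=None)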
@@ -277,6 +277,7 @@ class Linear1D_Row(ParallelModule):
     def chunk_weight(self):
         self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

+    @torch.no_grad()
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@@ -289,9 +290,10 @@ class Linear1D_Row(ParallelModule):
             src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

             origin_device = self.bias.device
-            self.bias = self.bias.cuda()
-            dist.broadcast(self.bias, src=src_rank, group=self.process_group)
-            self.bias = self.bias.to(origin_device)
+            bias = self.bias.cuda()
+            dist.broadcast(bias, src=src_rank, group=self.process_group)
+            bias = bias.to(origin_device)
+            self.bias.copy_(bias)

     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.

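Note: the bias broadcast above now writes into a temporary and copies the result back with `copy_`, instead of rebinding `self.bias` to the broadcast tensor. A standalone sketch of the pattern with plain tensors (no process group involved):

    import torch

    bias = torch.nn.Parameter(torch.zeros(4))
    received = torch.ones(4)    # stand-in for the tensor filled by dist.broadcast

    # rebinding (`self.bias = received`) would replace the registered Parameter;
    # copying keeps the original Parameter object and only updates its values
    with torch.no_grad():
        bias.copy_(received)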
@@ -31,12 +31,25 @@ from ._operation import (
 from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
+__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
+
+# ====================================
+# For GPT Only
+# ====================================


-def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
+def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
+                                  n_fused: int,
+                                  process_group: ProcessGroup,
+                                  is_transposed: bool = False):
     """
     The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
+
+    Args:
+        qkv (torch.Tensor): The fused qkv tensor.
+        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        process_group (ProcessGroup): The process group for distributed communication.
+        is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     # get the number of slice for the fused qkv
     rank = dist.get_rank(group=process_group)
@@ -48,7 +61,10 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
     # [Q, K, V]
     # to
     # [Q1, Q2, K1, K2, V1, V2]
+    if is_transposed:
         weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
+    else:
+        weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)

     # rearrange the slice into the final order
     # from
@@ -56,13 +72,26 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
     # to
     # [Q1, K1, V1], [Q2, K2, V2]
     weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
+
+    if is_transposed:
         weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
+    else:
+        weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0)
     return weight_of_current_rank


-def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
+def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
+                                   n_fused: int,
+                                   process_group: ProcessGroup,
+                                   is_transposed: bool = False):
     """
     The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
+
+    Args:
+        qkv (torch.Tensor): The fused qkv tensor.
+        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        process_group (ProcessGroup): The process group for distributed communication.
+        is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     world_size = dist.get_world_size(group=process_group)

@@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
     qkv = qkv.cuda()
     gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
     dist.all_gather(gather_list, qkv, group=process_group)
+
+    if is_transposed:
         gather_weight = torch.cat(gather_list, dim=-1)
+    else:
+        gather_weight = torch.cat(gather_list, dim=0)
     gather_weight = gather_weight.to(origin_device)
     qkv = qkv.to(origin_device)

@@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
     # [Q1, K1, V1, Q2, K2, V2]
     # to
     # [Q1, Q2, K1, K2, V1, V2]
+    if is_transposed:
         weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
+    else:
+        weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
+
     reordered_chunk_list = []
     for i in range(n_fused):
         reordered_chunk_list.extend(weight_chunks[i::n_fused])
+
+    if is_transposed:
         reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
+    else:
+        reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0)
     return reordered_gather_weight


-class LinearConv1D_Col(ParallelModule):
+class GPT2FusedLinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -154,10 +195,10 @@ class LinearConv1D_Col(ParallelModule):
         weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)

         def shard_fn(tensor):
-            return split_fused_qkv(tensor, self.n_fused, self.process_group)
+            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)

         def gather_fn(tensor):
-            return gather_fused_qkv(tensor, 3, self.process_group)
+            return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)

         with torch.no_grad():
             sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
@@ -202,7 +243,7 @@ class LinearConv1D_Col(ParallelModule):
             f'Expected only one process group, got {len(process_group)}.'
         process_group = process_group[0]

-        linear_1d = LinearConv1D_Col(in_features=in_features,
+        linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
                                      out_features=out_features,
                                      bias=bias,
                                      device=device,
@@ -212,11 +253,17 @@ class LinearConv1D_Col(ParallelModule):

         # TODO: copy the sharded weights
         with torch.no_grad():
-            sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
+            sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+                                                           n_fused=n_fused,
+                                                           process_group=process_group,
+                                                           is_transposed=True)
             linear_1d.weight.data.copy_(sharded_weight.data)

             if bias:
-                sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
+                sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+                                                             n_fused=n_fused,
+                                                             process_group=process_group,
+                                                             is_transposed=True)
                 linear_1d.bias.data.copy_(sharded_bias.data)

         return linear_1d
@@ -254,7 +301,7 @@ class LinearConv1D_Col(ParallelModule):
         return output


-class LinearConv1D_Row(ParallelModule):
+class GPT2FusedLinearConv1D_Row(ParallelModule):
     r""" Linear layer with row parallelism.
     This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.

@@ -345,7 +392,7 @@ class LinearConv1D_Row(ParallelModule):
             f'Expected only one process group, got {len(process_group)}.'
         process_group = process_group[0]

-        linear_1d = LinearConv1D_Row(in_features=in_features,
+        linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
                                      out_features=out_features,
                                      bias=bias,
                                      device=device,

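Note: the split/gather helpers above only shuffle chunks of the fused QKV weight so that each tensor-parallel rank keeps a contiguous [Q_i, K_i, V_i] block. A pure-torch sketch of that index arithmetic for two ranks (it does not call the distributed helpers themselves):

    import torch

    world_size, n_fused = 2, 3          # two tensor-parallel ranks, fused Q/K/V
    qkv = torch.arange(6)               # stands for the six blocks [Q1, Q2, K1, K2, V1, V2]

    chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
    for rank in range(world_size):
        kept = torch.cat(chunks[rank::world_size], dim=0)
        print(rank, kept)
    # rank 0 -> tensor([0, 2, 4])  i.e. [Q1, K1, V1]
    # rank 1 -> tensor([1, 3, 5])  i.e. [Q2, K2, V2]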
@@ -3,6 +3,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_global_rank


 class Randomizer:
@@ -112,27 +113,90 @@ class Randomizer:

         """
         idx = Randomizer._INDEX
-        Randomizer._INDEX += 1
         return idx

+    @staticmethod
+    def increment_index():
+        """
+        Increment the index of the randomizer by one.
+        """
+        Randomizer._INDEX += 1
+
-def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
+    @staticmethod
+    def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
+        """
+        Return whether the randomizer index is synchronized across processes.
+        """
+        index = Randomizer.index()
+        if dist.is_initialized():
+            # convert the index to tensor
+            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+
+            # all gather the index
+            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
+            dist.all_gather(gathered_index, index_tensor, process_group)
+
+            # make sure all the gathered index are the same
+            for i in range(1, dist.get_world_size(process_group)):
+                if gathered_index[i] != gathered_index[0]:
+                    return False
+
+        return True
+
+    @staticmethod
+    def synchronize_index(process_group: ProcessGroup = None):
+        """
+        All gather the index and pick the largest value.
+        """
+        index = Randomizer.index()
+
+        if dist.is_initialized():
+            # convert the index to tensor
+            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+
+            # all gather the index
+            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
+            dist.all_gather(gathered_index, index_tensor, process_group)
+
+            # pick the largest index
+            for i in range(1, dist.get_world_size(process_group)):
+                if gathered_index[i] > index_tensor:
+                    index_tensor = gathered_index[i]
+
+            # set the index
+            Randomizer._INDEX = index_tensor.item()
+
+
+def create_randomizer_with_offset(seed: int,
+                                  process_group: ProcessGroup = None,
+                                  offset_by_rank: bool = True,
+                                  offset_by_index: bool = True):
     """
     Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.

     Args:
         seed (int): The base random seed to set.
-        enable_cpu (bool): fork the CPU RNG state as well.
         process_group (ProcessGroup): the process group to get the rank from.
+        offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
+        offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.

     Returns:
         Randomizer: the randomizer with offset.
     """
-    offset = Randomizer.index()
+    base_seed = seed

-    if dist.is_initialized():
+    if offset_by_rank and dist.is_initialized():
         rank = dist.get_rank(process_group)
-        offset += rank
+        base_seed += rank

-    seed += offset
-    return Randomizer(seed=seed)
+    if offset_by_index:
+        # check if the randomizer index is synchronized
+        is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
+        assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes."
+                                 "This is not allowed when we want to create a randomizer with offset by index."
+                                 "Please call Randomizer.synchronize_index() first.")
+
+        base_seed += Randomizer.index()
+        Randomizer.increment_index()
+
+    return Randomizer(seed=base_seed)

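Note: the seed layout produced by `create_randomizer_with_offset` can be summarized with plain arithmetic. A sketch assuming two ranks and a shared base seed:

    base_seed = 1234

    for index in range(2):          # Randomizer.index(), incremented once per created randomizer
        for rank in range(2):       # dist.get_rank(process_group)
            seed_parallel = base_seed + rank + index    # offset_by_rank=True: differs across ranks
            seed_replicated = base_seed + index         # offset_by_rank=False: identical on every rank
            print(index, rank, seed_parallel, seed_replicated)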
@@ -78,6 +78,18 @@ _POLICY_LIST = {
         PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
     "transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
         PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
+
+    # Bloom
+    "transformers.models.bloom.modeling_bloom.BloomModel":
+        PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"),
+    "transformers.models.bloom.modeling_bloom.BloomForCausalLM":
+        PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"),
+    "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification":
+        PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"),
+    "transformers.models.bloom.modeling_bloom.BloomForTokenClassification":
+        PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
+    "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
+        PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
 }

@@ -52,6 +52,7 @@ class ModulePolicyDescription:
     attribute_replacement: Dict[str, Any]
     param_replacement: List[Callable]
     sub_module_replacement: List[SubModuleReplacementDescription]
+    method_replacement: List[Callable] = None


 class Policy(ABC):

@@ -61,7 +61,7 @@ class BertPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.self.dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
@@ -69,7 +69,7 @@ class BertPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="intermediate.dense",
@@ -81,7 +81,7 @@ class BertPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 )
             ]),
         BertEmbeddings:
@@ -94,7 +94,7 @@ class BertPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 )
             ])
         }
@@ -258,7 +258,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 )
             ])
         }
@@ -281,7 +281,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 )
             ])
         }
@@ -311,7 +311,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 )
             ])
         }

@@ -0,0 +1,214 @@
+import torch
+import torch.distributed as dist
+
+import colossalai.shardformer.layer as col_nn
+
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+
+def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+    """
+    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+    `softmax(l+a) = softmax(l)`. Based on
+    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+    Args:
+    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        num_heads (`int`, *required*):
+            number of heads
+        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+            dtype of the output tensor
+    """
+    import math
+
+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+        num_heads = num_heads * world_size
+
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2**math.floor(math.log2(num_heads))
+    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+                        device=attention_mask.device,
+                        dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+                                  device=attention_mask.device,
+                                  dtype=torch.float32)
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi.view(batch_size, num_heads, 1, seq_length)
+        alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
+        return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
+    else:
+        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
+
+class BloomPolicy(Policy):
+
+    def preprocess(self):
+        # reshape the embedding layer
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        # TODO:
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self):
+        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
+
+        return {
+            BloomBlock:
+                ModulePolicyDescription(
+                    attribute_replacement={
+                        # 1. shard hidden size
+                        "self_attention.hidden_size":
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                        "self_attention.split_size":
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                        # 2. shard number of heads
+                        "self_attention.num_heads":
+                            self.model.config.n_head // self.shard_config.tensor_parallel_size,
+                    },
+                    param_replacement=[],
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="self_attention.query_key_value",
+                            target_module=col_nn.Linear1D_Col,
+                            # kwargs={'n_fused': 3}
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="self_attention.dense",
+                            target_module=col_nn.Linear1D_Row,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="self_attention.attention_dropout",
+                            target_module=col_nn.DropoutForParallelInput,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="mlp.dense_h_to_4h",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="mlp.dense_4h_to_h",
+                            target_module=col_nn.Linear1D_Row,
+                        ),
+                    ]),
+            BloomModel:
+                ModulePolicyDescription(attribute_replacement={
+                    "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+                },
+                                        param_replacement=[],
+                                        method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="word_embeddings",
+                                                target_module=col_nn.VocabParallelEmbedding1D,
+                                            )
+                                        ])
+        }
+
+    def new_model_class(self):
+        # do nothing
+        return self.model
+
+    def postprocess(self):
+        return self.model
+
+
+# BertModel
+class BloomModelPolicy(BloomPolicy):
+    pass
+
+
+class BloomForCausalLMPolicy(BloomPolicy):
+
+    def module_policy(self):
+        from transformers.models.bloom.modeling_bloom import BloomForCausalLM
+        policy = super().module_policy()
+        # add a new item for casual lm
+        new_item = {
+            BloomForCausalLM:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
+        }
+        policy.update(new_item)
+        return policy
+
+
+class BloomForSequenceClassificationPolicy(BloomPolicy):
+
+    def module_policy(self):
+        from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
+        policy = super().module_policy()
+        # add a new item for casual lm
+        new_item = {
+            BloomForSequenceClassification:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="score",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
+        }
+        policy.update(new_item)
+        return policy
+
+
+class BloomForTokenClassificationPolicy(BloomPolicy):
+
+    def module_policy(self):
+        from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
+        policy = super().module_policy()
+        # add a new item for casual lm
+        new_item = {
+            BloomForTokenClassification:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="classifier",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True)),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=col_nn.DropoutForReplicatedInput,
+                                            ),
+                                        ])
+        }
+        policy.update(new_item)
+        return policy
+
+
+class BloomForQuestionAnsweringPolicy(BloomPolicy):
+    # No head sharding as the output features is only 2
+    pass

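Note: in the policy above, `BloomModel.num_heads` is replaced by the per-rank head count, and the patched `build_alibi_tensor` first re-expands it by the world size and then slices out this rank's heads. A small shape check under the assumption of two ranks and four heads per rank:

    world_size = 2
    heads_per_rank = 4                              # BloomModel.num_heads after attribute_replacement
    global_heads = heads_per_rank * world_size      # the helper multiplies back to 8 before slicing
    batch, seq = 1, 6

    for rank in range(world_size):
        offset = rank * heads_per_rank
        # rank keeps alibi rows for heads [offset, offset + heads_per_rank)
        # and returns a tensor of shape (batch * heads_per_rank, 1, seq) == (4, 1, 6)
        print(rank, (batch * heads_per_rank, 1, seq), f"heads {offset}:{offset + heads_per_rank}")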
@@ -42,37 +42,37 @@ class GPT2Policy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="attn.c_attn",
-                    target_module=col_nn.LinearConv1D_Col,
+                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
                         "n_fused": 3,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.c_proj",
-                    target_module=col_nn.LinearConv1D_Row,
+                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.c_fc",
-                    target_module=col_nn.LinearConv1D_Col,
+                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
                         "n_fused": 1,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.c_proj",
-                    target_module=col_nn.LinearConv1D_Row,
+                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.attn_dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.resid_dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.dropout",
-                    target_module=col_nn.Dropout1D,
+                    target_module=col_nn.DropoutForParallelInput,
                 ),
             ])
         }

@@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )

-from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -38,7 +38,7 @@ class T5ModelPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 )
             ]),
         T5LayerSelfAttention:
@@ -47,7 +47,7 @@ class T5ModelPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 ),
             ]),
         T5LayerCrossAttention:
@@ -56,7 +56,7 @@ class T5ModelPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 )
             ]),
         T5Attention:
@@ -97,7 +97,7 @@ class T5ModelPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 ),
             ]),
         T5DenseGatedActDense:
@@ -117,7 +117,7 @@ class T5ModelPolicy(Policy):
                     kwargs=dict(gather_output=True)),
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 )
             ]),
         T5DenseActDense:
@@ -134,7 +134,7 @@ class T5ModelPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=Dropout1D,
+                    target_module=DropoutForParallelInput,
                 )
             ])
         }

@@ -1,13 +1,13 @@
 from typing import Dict, Union

 import torch.nn as nn
+from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel

-from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
+from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row

-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class ViTPolicy(Policy):

     def preprocess(self):
@@ -24,19 +24,16 @@ class ViTPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         return {
             ViTEmbeddings:
-                ModulePolicyDescription(
-                    attribute_replacement{},
+                ModulePolicyDescription(attribute_replacement={},
                     param_replacement=[],
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="dropout",
-                            target_module=Dropout1D,
+                            target_module=DropoutForReplicatedInput,
                         )
-                    ]
-                ),
+                    ]),
             ViTLayer:
-                ModulePolicyDescription(
-                    attribute_replacement{
+                ModulePolicyDescription(attribute_replacement={
                     "attention.attention.num_attention_heads":
                         self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     "attention.attention.all_head_size":
@@ -58,7 +55,7 @@ class ViTPolicy(Policy):
                         ),
                         SubModuleReplacementDescription(
                             suffix="attention.attention.dropout",
-                            target_module=Dropout1D,
+                            target_module=DropoutForParallelInput,
                         ),
                         SubModuleReplacementDescription(
                             suffix="attention.output.dense",
@@ -66,7 +63,7 @@ class ViTPolicy(Policy):
                         ),
                         SubModuleReplacementDescription(
                             suffix="attention.output.dropout",
-                            target_module=Dropout1D,
+                            target_module=DropoutForParallelInput,
                         ),
                         SubModuleReplacementDescription(
                             suffix="intermediate.dense",
@@ -78,10 +75,9 @@ class ViTPolicy(Policy):
                         ),
                         SubModuleReplacementDescription(
                             suffix="output.dropout",
-                            target_module=Dropout1D,
-                        ),
-                    ]
+                            target_module=DropoutForParallelInput,
                         ),
+                    ]),
         }

     def new_model_class(self):
@@ -89,8 +85,3 @@ class ViTPolicy(Policy):

     def postprocess(self):
         return self.model
-
-
-
-
-

@@ -95,8 +95,9 @@ class ModelSharder(object):
             attr_replacement = module_description[1].attribute_replacement
             param_replacement = module_description[1].param_replacement
             sub_module_replacement = module_description[1].sub_module_replacement
+            method_replacement = module_description[1].method_replacement
             self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
-                                          sub_module_replacement)
+                                          method_replacement, sub_module_replacement)

     def _recursive_replace_layer(
             self,
@@ -104,6 +105,7 @@ class ModelSharder(object):
             origin_cls: nn.Module,
             attr_replacement: Dict[str, Any],
             param_replacement: List[Callable],
+            method_replacement: Dict[str, Callable],
             sub_module_replacement: List[Callable],
     ) -> None:
         r"""
@@ -119,9 +121,11 @@ class ModelSharder(object):
         if module.__class__ == origin_cls:
             self._replace_attr(module, attr_replacement)
             self._replace_param(module, param_replacement)
+            self._replace_method(module, method_replacement)
             self._replace_sub_module(module, sub_module_replacement)

         for name, child in module.named_children():
-            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
+            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
                                           sub_module_replacement)

     def _replace_attr(
@@ -154,6 +158,14 @@ class ModelSharder(object):
         # TODO: support parameter shard
         pass

+    def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
+        if method_replacement is None:
+            return
+
+        for method_name, new_method in method_replacement.items():
+            # bind the new method to the module
+            setattr(module, method_name, new_method.__get__(module, module.__class__))
+
     def _replace_sub_module(
             self,
             org_layer: nn.Module,

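Note: `_replace_method` binds a plain function to a module instance through the descriptor protocol, which is how `build_bloom_alibi_tensor` ends up receiving the model as `self`. A self-contained illustration of the same binding trick:

    class Greeter:
        def hello(self):
            return "original"

    def new_hello(self):
        return f"patched on {type(self).__name__}"

    g = Greeter()
    # bind the free function to this instance, as _replace_method does with setattr + __get__
    setattr(g, "hello", new_hello.__get__(g, Greeter))
    print(g.hello())    # patched on Greeter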
@@ -1,5 +1,6 @@
 from .albert import *
 from .bert import *
+from .bloom import *
 from .gpt import *
 from .llama import *
 from .opt import *

@@ -0,0 +1,107 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register Bloom
+# ===============================
+
+
+def data_gen():
+    # Generated from following code snippet
+    #
+    # from transformers import BloomTokenizer
+    # input = 'Hello, my dog is cute'
+    # tokenized_input = tokenizer(input, return_tensors='pt')
+    # input_ids = tokenized_input['input_ids']
+    # attention_mask = tokenized_input['attention_mask']
+    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+    # LM data gen
+    # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
+    data = data_gen()
+    data['labels'] = data['input_ids'].clone()
+    return data
+
+
+def data_gen_for_token_classification():
+    # token classification data gen
+    # `labels` is the type not the token id for token classification, 0 or 1
+    data = data_gen()
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    data = data_gen()
+    data['labels'] = torch.tensor([0], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_question_answering():
+    # obtained with the following code
+    #
+    # from transformers import AutoTokenizer
+    # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    # question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+    # inputs = tokenizer(question, text, return_tensors="pt")
+
+    input_ids = torch.tensor(
+        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_causal_lm = lambda x: x.loss
+loss_fn_for_classification = lambda x: x.logits.mean()
+loss_fn_for_question_answering = lambda x: x.end_logits.mean()
+
+config = transformers.BloomConfig(n_layer=1,
+                                  n_head=4,
+                                  vocab_size=250880,
+                                  hidden_dropout=0,
+                                  attention_dropout=0,
+                                  hidden_size=64)
+
+# register the following models
+model_zoo.register(name='transformers_bloom',
+                   model_fn=lambda: transformers.BloomModel(config),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_bloom_model,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bloom_for_causal_lm',
+                   model_fn=lambda: transformers.BloomForCausalLM(config),
+                   data_gen_fn=data_gen_for_lm,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_causal_lm,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bloom_for_sequence_classification',
+                   model_fn=lambda: transformers.BloomForSequenceClassification(config),
+                   data_gen_fn=data_gen_for_sequence_classification,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_classification,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bloom_for_token_classification',
+                   model_fn=lambda: transformers.BloomForTokenClassification(config),
+                   data_gen_fn=data_gen_for_token_classification,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_classification,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bloom_for_question_answering',
+                   model_fn=lambda: transformers.BloomForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_question_answering,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_question_answering,
+                   model_attribute=ModelAttribute(has_control_flow=True))

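Note: a rough sketch of exercising one of the registry entries above on its own, without sharding; the unpacking of the registry tuple follows the test file further below, and the exact registry API is assumed from that test:

    from tests.kit.model_zoo import model_zoo

    sub_zoo = model_zoo.get_sub_registry('transformers_bloom')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
        model = model_fn()
        output = output_transform_fn(model(**data_gen_fn()))
        print(name, float(loss_fn(output)))
        break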
@@ -6,8 +6,6 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence GPT
 # ===============================
-BATCH_SIZE = 1    # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
-SEQ_LENGTH = 16


 def data_gen():

@@ -3,13 +3,13 @@ import torch.distributed as dist
 import torch.nn as nn

 import colossalai
-from colossalai.shardformer.layer import Dropout1D
+from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput
 from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn


-def check_dropout():
+def check_dropout_parallel_input():
     dropout = nn.Dropout().cuda()
-    dropout_1d = Dropout1D.from_native_module(dropout, process_group=None)
+    dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None)

     # check computation correctness
     x = torch.rand(4, 128).cuda()
@@ -39,9 +39,26 @@ def check_dropout():
         assert_not_equal(out_1d_all[i], out_1d_all[0])


+def check_dropout_replicated_input():
+    dropout = nn.Dropout().cuda()
+    dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None)
+
+    # check computation correctness
+    x = torch.rand(4, 128).cuda()
+    out_1d = dropout_replica(x)
+
+    # ensure out_1d is different across ranks
+    world_size = dist.get_world_size()
+    out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)]
+    dist.all_gather(out_1d_all, out_1d)
+    for i in range(1, world_size):
+        assert_equal(out_1d_all[i], out_1d_all[0])
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_dropout()
+    check_dropout_parallel_input()
+    check_dropout_replicated_input()


 @rerun_if_address_is_in_use()

@@ -4,8 +4,8 @@ import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
-from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
-from colossalai.shardformer.layer.linear_conv import split_fused_qkv
+from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import rerun_if_address_is_in_use, spawn


@@ -52,7 +52,10 @@ def rearrange(tensor: torch.Tensor, dim: int):

 def check_linear_conv_1d_col():
     linear = Conv1D(192, 48).cuda()
-    linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
+    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
+                                                                   process_group=None,
+                                                                   gather_output=True,
+                                                                   n_fused=3)

     assert linear.weight.shape == torch.Size([48, 192])
     assert linear.bias.shape == torch.Size([192])
@@ -73,13 +76,13 @@ def check_linear_conv_1d_col():
     out.sum().backward()
     gather_out.sum().backward()

-    target_grad = split_fused_qkv(linear.weight.grad, 3, None)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
     assert_close(target_grad, linear_conv_col.weight.grad)


 def check_linear_conv_1d_row():
     linear = Conv1D(192, 48).cuda()
-    linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_row.weight.shape == torch.Size([24, 192])
@@ -102,6 +105,8 @@ def check_linear_conv_1d_row():

 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # test for linear conv
     check_linear_conv_1d_col()
     check_linear_conv_1d_row()

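Note: the `is_transposed=True` arguments in this test exist because HuggingFace's `Conv1D` stores its weight as (in_features, out_features), the transpose of `nn.Linear`. A quick check (the `Conv1D` import path may differ across transformers versions):

    import torch
    from transformers.pytorch_utils import Conv1D

    conv = Conv1D(192, 48)              # nf=192 output features, nx=48 input features
    linear = torch.nn.Linear(48, 192)

    print(conv.weight.shape)            # torch.Size([48, 192])  -> (in_features, out_features)
    print(linear.weight.shape)          # torch.Size([192, 48])  -> (out_features, in_features)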
@@ -0,0 +1,59 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+
+    # do backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    # check grad equality
+    if org_model.__class__.__name__ == 'BloomModel':
+        org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
+        shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
+    else:
+        org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
+        shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+
+
+def check_bloom(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom():
+    spawn(check_bloom, 2)
+
+
+if __name__ == "__main__":
+    test_bloom()