diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index da80a7276..8a8ed0f79 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer: - [x] BERT - [x] T5 - [x] LlaMa - - [ ] GPT2 - - [ ] BLOOM + - [x] GPT2 + - [x] OPT + - [x] BLOOM + - [ ] GLM - [ ] RoBERTa - [ ] ALBERT - [ ] ERNIE @@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer: - [ ] SwinTransformer - [ ] SwinTransformer V2 - [ ] Audio - - [ ] To be added + - [ ] Whisper - [ ] Multi-modal - [ ] To be added diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 3ece25831..2826a8429 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,11 +1,12 @@ -from .dropout import Dropout1D +from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .layernorm import FusedLayerNorm from .linear import Linear1D_Col, Linear1D_Row -from .linear_conv import LinearConv1D_Col, LinearConv1D_Row from .loss import cross_entropy_1d +from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", - "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm' + "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", + 'FusedLayerNorm' ] diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 2c49b49fa..2625fe978 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -7,10 +7,10 @@ from torch.distributed import ProcessGroup from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Dropout1D'] +__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput'] -class Dropout1D(ParallelModule, nn.Dropout): +class DropoutForParallelInput(ParallelModule, nn.Dropout): """ The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with randomness on different ranks of the given process group. This can avoid the same dropout mask is generated @@ -32,13 +32,50 @@ class Dropout1D(ParallelModule, nn.Dropout): @staticmethod def from_native_module(module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D": + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput": """ - Create a Dropout1D layer from a native dropout layer. + Create a DropoutForParallelInput layer from a native dropout layer. """ p = module.p inplace = module.inplace - return Dropout1D(p=p, inplace=inplace, process_group=process_group) + return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input + + +class DropoutForReplicatedInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. 
This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index only + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False) + + @staticmethod + def from_native_module( + module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput": + """ + Create a Dropout1D layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group) def forward(self, input): with self.randomizer.fork_rng(): diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d952d5eec..26ba5883c 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -277,6 +277,7 @@ class Linear1D_Row(ParallelModule): def chunk_weight(self): self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + @torch.no_grad() def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) @@ -289,9 +290,10 @@ class Linear1D_Row(ParallelModule): src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/qkv_fused_linear.py similarity index 79% rename from colossalai/shardformer/layer/linear_conv.py rename to colossalai/shardformer/layer/qkv_fused_linear.py index e856abc14..9d51670c6 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -31,12 +31,25 @@ from ._operation import ( from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] +__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row'] + +# ==================================== +# For GPT Only +# ==================================== -def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): +def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): """ The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). 
+ process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ # get the number of slice for the fused qkv rank = dist.get_rank(group=process_group) @@ -48,7 +61,10 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup # [Q, K, V] # to # [Q1, Q2, K1, K2, V1, V2] - weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + if is_transposed: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) # rearrange the slice into the final order # from @@ -56,13 +72,26 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup # to # [Q1, K1, V1], [Q2, K2, V2] weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] - weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + + if is_transposed: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + else: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0) return weight_of_current_rank -def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): +def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): """ The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ world_size = dist.get_world_size(group=process_group) @@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou qkv = qkv.cuda() gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] dist.all_gather(gather_list, qkv, group=process_group) - gather_weight = torch.cat(gather_list, dim=-1) + + if is_transposed: + gather_weight = torch.cat(gather_list, dim=-1) + else: + gather_weight = torch.cat(gather_list, dim=0) gather_weight = gather_weight.to(origin_device) qkv = qkv.to(origin_device) @@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou # [Q1, K1, V1, Q2, K2, V2] # to # [Q1, Q2, K1, K2, V1, V2] - weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + if is_transposed: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) + reordered_chunk_list = [] for i in range(n_fused): reordered_chunk_list.extend(weight_chunks[i::n_fused]) - reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + + if is_transposed: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + else: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0) return reordered_gather_weight -class LinearConv1D_Col(ParallelModule): +class GPT2FusedLinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. The linear layer is defined as :math:`Y = XA + b`. 
A is parallelized along @@ -154,10 +195,10 @@ class LinearConv1D_Col(ParallelModule): weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) def shard_fn(tensor): - return split_fused_qkv(tensor, self.n_fused, self.process_group) + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv(tensor, 3, self.process_group) + return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) with torch.no_grad(): sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) @@ -202,21 +243,27 @@ class LinearConv1D_Col(ParallelModule): f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] - linear_1d = LinearConv1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) + linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) # TODO: copy the sharded weights with torch.no_grad(): - sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group) + sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) linear_1d.weight.data.copy_(sharded_weight.data) if bias: - sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group) + sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) linear_1d.bias.data.copy_(sharded_bias.data) return linear_1d @@ -254,7 +301,7 @@ class LinearConv1D_Col(ParallelModule): return output -class LinearConv1D_Row(ParallelModule): +class GPT2FusedLinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. @@ -345,13 +392,13 @@ class LinearConv1D_Row(ParallelModule): f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] - linear_1d = LinearConv1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) + linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) # TODO: copy the sharded weights with torch.no_grad(): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c3d6ab57e..f2ac6563c 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -3,6 +3,7 @@ from contextlib import contextmanager import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_global_rank class Randomizer: @@ -112,27 +113,90 @@ class Randomizer: """ idx = Randomizer._INDEX - Randomizer._INDEX += 1 return idx + @staticmethod + def increment_index(): + """ + Increment the index of the randomizer by one. + """ + Randomizer._INDEX += 1 -def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None): + @staticmethod + def is_randomizer_index_synchronized(process_group: ProcessGroup = None): + """ + Return whether the randomizer index is synchronized across processes. 
+ """ + index = Randomizer.index() + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # make sure all the gathered index are the same + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] != gathered_index[0]: + return False + + return True + + @staticmethod + def synchronize_index(process_group: ProcessGroup = None): + """ + All gather the index and pick the largest value. + """ + index = Randomizer.index() + + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # pick the largest index + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] > index_tensor: + index_tensor = gathered_index[i] + + # set the index + Randomizer._INDEX = index_tensor.item() + + +def create_randomizer_with_offset(seed: int, + process_group: ProcessGroup = None, + offset_by_rank: bool = True, + offset_by_index: bool = True): """ Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. Args: seed (int): The base random seed to set. - enable_cpu (bool): fork the CPU RNG state as well. process_group (ProcessGroup): the process group to get the rank from. + offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True. + offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True. Returns: Randomizer: the randomizer with offset. """ - offset = Randomizer.index() + base_seed = seed - if dist.is_initialized(): + if offset_by_rank and dist.is_initialized(): rank = dist.get_rank(process_group) - offset += rank + base_seed += rank - seed += offset - return Randomizer(seed=seed) + if offset_by_index: + # check if the randomizer index is synchronized + is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) + assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." 
+ "Please call Randomizer.synchronize_index() first.") + + base_seed += Randomizer.index() + Randomizer.increment_index() + + return Randomizer(seed=base_seed) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 9cc583d58..17c063c8d 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -78,6 +78,18 @@ _POLICY_LIST = { PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), + + # Bloom + "transformers.models.bloom.modeling_bloom.BloomModel": + PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": + PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": + PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": + PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": + PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), } diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index b5d9cdbd7..7e9bcf209 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -52,6 +52,7 @@ class ModulePolicyDescription: attribute_replacement: Dict[str, Any] param_replacement: List[Callable] sub_module_replacement: List[SubModuleReplacementDescription] + method_replacement: List[Callable] = None class Policy(ABC): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fb70cdff8..49ef53259 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -61,7 +61,7 @@ class BertPolicy(Policy): ), SubModuleReplacementDescription( suffix="attention.self.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="attention.output.dense", @@ -69,7 +69,7 @@ class BertPolicy(Policy): ), SubModuleReplacementDescription( suffix="attention.output.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="intermediate.dense", @@ -81,7 +81,7 @@ class BertPolicy(Policy): ), SubModuleReplacementDescription( suffix="output.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]), BertEmbeddings: @@ -94,7 +94,7 @@ class BertPolicy(Policy): ), SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -258,7 +258,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -281,7 +281,7 @@ class BertForTokenClassificationPolicy(BertPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -311,7 +311,7 @@ 
class BertForMultipleChoicePolicy(BertPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py new file mode 100644 index 000000000..d196bdbd6 --- /dev/null +++ b/colossalai/shardformer/policies/bloom.py @@ -0,0 +1,214 @@ +import torch +import torch.distributed as dist + +import colossalai.shardformer.layer as col_nn + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size() + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size()) + offset = dist.get_rank() * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * 
num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + +class BloomPolicy(Policy): + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel + + return { + BloomBlock: + ModulePolicyDescription( + attribute_replacement={ + # 1. shard hidden size + "self_attention.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + # 2. shard number of heads + "self_attention.num_heads": + self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + # kwargs={'n_fused': 3} + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]), + BloomModel: + ModulePolicyDescription(attribute_replacement={ + "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + } + + def new_model_class(self): + # do nothing + return self.model + + def postprocess(self): + return self.model + + +# BertModel +class BloomModelPolicy(BloomPolicy): + pass + + +class BloomForCausalLMPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForCausalLM + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForCausalLM: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + + +class BloomForSequenceClassificationPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForSequenceClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="score", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + + +class BloomForTokenClassificationPolicy(BloomPolicy): + + 
def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForTokenClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) + } + policy.update(new_item) + return policy + + +class BloomForQuestionAnsweringPolicy(BloomPolicy): + # No head sharding as the output features is only 2 + pass diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 9d5d7d36a..ebfaf8a8e 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -42,37 +42,37 @@ class GPT2Policy(Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.c_attn", - target_module=col_nn.LinearConv1D_Col, + target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 3, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", - target_module=col_nn.LinearConv1D_Row, + target_module=col_nn.GPT2FusedLinearConv1D_Row, ), SubModuleReplacementDescription( suffix="mlp.c_fc", - target_module=col_nn.LinearConv1D_Col, + target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 1, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", - target_module=col_nn.LinearConv1D_Row, + target_module=col_nn.GPT2FusedLinearConv1D_Row, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="attn.resid_dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="mlp.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), ]) } diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 9a1b63e46..8d8abc9f7 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import ( T5Stack, ) -from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -38,7 +38,7 @@ class T5ModelPolicy(Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5LayerSelfAttention: @@ -47,7 +47,7 @@ class T5ModelPolicy(Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ), ]), T5LayerCrossAttention: @@ -56,7 +56,7 @@ class T5ModelPolicy(Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5Attention: @@ -97,7 +97,7 @@ class T5ModelPolicy(Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ), ]), T5DenseGatedActDense: 
@@ -117,7 +117,7 @@ class T5ModelPolicy(Policy): kwargs=dict(gather_output=True)), SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5DenseActDense: @@ -134,7 +134,7 @@ class T5ModelPolicy(Policy): ), SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]) } diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 4a2b72057..550f8f997 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,15 +1,15 @@ from typing import Dict, Union import torch.nn as nn +from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel -from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention - -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D +from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + class ViTPolicy(Policy): - + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -20,77 +20,68 @@ class ViTPolicy(Policy): self.model.resize_token_embeddings(new_vocab_size) return self.model - + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + return { ViTEmbeddings: - ModulePolicyDescription( - attribute_replacement{}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=Dropout1D, - ) - ] - ), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ]), ViTLayer: - ModulePolicyDescription( - attribute_replacement{ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size//self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=Dropout1D, - ), - ] - ), + ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + 
sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=DropoutForParallelInput, + ), + ]), } - + def new_model_class(self): return None def postprocess(self): return self.model - - - - - diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 22f5f1c12..c2444e1f7 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -95,8 +95,9 @@ class ModelSharder(object): attr_replacement = module_description[1].attribute_replacement param_replacement = module_description[1].param_replacement sub_module_replacement = module_description[1].sub_module_replacement + method_replacement = module_description[1].method_replacement self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement, - sub_module_replacement) + method_replacement, sub_module_replacement) def _recursive_replace_layer( self, @@ -104,6 +105,7 @@ class ModelSharder(object): origin_cls: nn.Module, attr_replacement: Dict[str, Any], param_replacement: List[Callable], + method_replacement: Dict[str, Callable], sub_module_replacement: List[Callable], ) -> None: r""" @@ -119,9 +121,11 @@ class ModelSharder(object): if module.__class__ == origin_cls: self._replace_attr(module, attr_replacement) self._replace_param(module, param_replacement) + self._replace_method(module, method_replacement) self._replace_sub_module(module, sub_module_replacement) + for name, child in module.named_children(): - self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, + self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, sub_module_replacement) def _replace_attr( @@ -154,6 +158,14 @@ class ModelSharder(object): # TODO: support parameter shard pass + def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): + if method_replacement is None: + return + + for method_name, new_method in method_replacement.items(): + # bind the new method to the module + setattr(module, method_name, new_method.__get__(module, module.__class__)) + def _replace_sub_module( self, org_layer: nn.Module, diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index ffaf4c566..4aa01abe1 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,6 @@ from .albert import * from .bert import * +from .bloom import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/bloom.py 
b/tests/kit/model_zoo/transformers/bloom.py new file mode 100644 index 000000000..71146c0b9 --- /dev/null +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -0,0 +1,107 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Bloom +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import BloomTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + # obtained with the following code + # + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + # question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + # inputs = tokenizer(question, text, return_tensors="pt") + + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.logits.mean() +loss_fn_for_question_answering = lambda x: x.end_logits.mean() + +config = transformers.BloomConfig(n_layer=1, + n_head=4, + vocab_size=250880, + hidden_dropout=0, + attention_dropout=0, + hidden_size=64) + +# register the following models +model_zoo.register(name='transformers_bloom', + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_causal_lm', + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_sequence_classification', + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + 
model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_token_classification', + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_question_answering', + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index c598fa8f4..b9e031078 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -6,8 +6,6 @@ from ..registry import ModelAttribute, model_zoo # =============================== # Register single-sentence GPT # =============================== -BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. -SEQ_LENGTH = 16 def data_gen(): diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index c62d25d94..332e37711 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -3,13 +3,13 @@ import torch.distributed as dist import torch.nn as nn import colossalai -from colossalai.shardformer.layer import Dropout1D +from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn -def check_dropout(): +def check_dropout_parallel_input(): dropout = nn.Dropout().cuda() - dropout_1d = Dropout1D.from_native_module(dropout, process_group=None) + dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None) # check computation correctness x = torch.rand(4, 128).cuda() @@ -39,9 +39,26 @@ def check_dropout(): assert_not_equal(out_1d_all[i], out_1d_all[0]) +def check_dropout_replicated_input(): + dropout = nn.Dropout().cuda() + dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out_1d = dropout_replica(x) + + # ensure out_1d is different across ranks + world_size = dist.get_world_size() + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_equal(out_1d_all[i], out_1d_all[0]) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_dropout() + check_dropout_parallel_input() + check_dropout_replicated_input() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py similarity index 81% rename from tests/test_shardformer/test_layer/test_linearconv_1d.py rename to tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 774e6340e..681c4f6dd 100644 --- a/tests/test_shardformer/test_layer/test_linearconv_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -4,8 +4,8 @@ import torch.nn as nn from torch.testing import assert_close 
import colossalai -from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row -from colossalai.shardformer.layer.linear_conv import split_fused_qkv +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -52,7 +52,10 @@ def rearrange(tensor: torch.Tensor, dim: int): def check_linear_conv_1d_col(): linear = Conv1D(192, 48).cuda() - linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -73,13 +76,13 @@ def check_linear_conv_1d_col(): out.sum().backward() gather_out.sum().backward() - target_grad = split_fused_qkv(linear.weight.grad, 3, None) + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) assert_close(target_grad, linear_conv_col.weight.grad) def check_linear_conv_1d_row(): linear = Conv1D(192, 48).cuda() - linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -102,6 +105,8 @@ def check_linear_conv_1d_row(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # test for linear conv check_linear_conv_1d_col() check_linear_conv_1d_row() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py new file mode 100644 index 000000000..7e2e3dfa8 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -0,0 +1,59 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + # check grad equality + if org_model.__class__.__name__ == 'BloomModel': + org_grad = org_model.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad + else: + org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to 
original model loss\n{org_loss}\n{shard_loss}" +    assert torch.allclose(org_grad, all_shard_grad, +                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}" + + +def check_bloom(rank, world_size, port): +    disable_existing_loggers() +    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + +    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') +    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): +        org_model, sharded_model = build_model(world_size, model_fn) +        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + +    torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): +    spawn(check_bloom, 2) + + +if __name__ == "__main__": +    test_bloom()
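The patch splits the old `Dropout1D` into `DropoutForParallelInput` (a different mask on every rank, for inputs that are already sharded) and `DropoutForReplicatedInput` (the same mask on every rank, for replicated inputs); the practical difference is only whether the RNG seed is offset by the process rank. A minimal sketch of that seed arithmetic, with hypothetical helper names rather than the ColossalAI implementation (the real classes fork and restore the global RNG state through `create_randomizer_with_offset`):

```python
import torch
import torch.distributed as dist


def make_dropout_generator(base_seed: int, offset_by_rank: bool, device: str = 'cuda') -> torch.Generator:
    """Build an RNG generator whose seed optionally includes the rank.

    offset_by_rank=True  -> every rank draws a different dropout mask
                            (the DropoutForParallelInput case).
    offset_by_rank=False -> all ranks draw an identical mask
                            (the DropoutForReplicatedInput case).
    """
    seed = base_seed
    if offset_by_rank and dist.is_initialized():
        seed += dist.get_rank()
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return gen


def dropout_with_generator(x: torch.Tensor, p: float, gen: torch.Generator) -> torch.Tensor:
    # draw the mask from the dedicated generator instead of the global RNG,
    # so other random ops in the model are unaffected (assumes p < 1)
    mask = (torch.rand(x.shape, device=x.device, generator=gen) >= p).to(x.dtype)
    return x * mask / (1.0 - p)
```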
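`split_fused_qkv_in_gpt2_style` chunks GPT-2's fused `Conv1D` weight, laid out as `[Q, K, V]`, into `world_size * n_fused` pieces `[Q1, Q2, K1, K2, V1, V2]` and hands rank `r` the pieces `[Qr, Kr, Vr]`; `gather_fused_qkv_in_gpt2_style` undoes that interleaving after an all-gather. The index bookkeeping can be checked on a single process; the functions below mirror the idea rather than the exact API (no process group, and the split dimension stands in for the `is_transposed` flag):

```python
import torch


def split_in_gpt2_style(qkv: torch.Tensor, n_fused: int, rank: int, world_size: int,
                        dim: int = -1) -> torch.Tensor:
    # [Q, K, V] -> world_size * n_fused chunks [Q1, Q2, K1, K2, V1, V2]
    chunks = torch.chunk(qkv, world_size * n_fused, dim=dim)
    order = list(range(world_size * n_fused))
    # rank r keeps every world_size-th chunk starting at r: [Qr, Kr, Vr]
    mine = [chunks[i] for i in order[rank::world_size]]
    return torch.cat(mine, dim=dim)


def gather_in_gpt2_style(shards: list, n_fused: int, dim: int = -1) -> torch.Tensor:
    # concatenating the per-rank shards gives [Q1, K1, V1, Q2, K2, V2]
    gathered = torch.cat(shards, dim=dim)
    world_size = len(shards)
    chunks = torch.chunk(gathered, world_size * n_fused, dim=dim)
    # re-interleave back to the fused layout [Q1, Q2, K1, K2, V1, V2]
    reordered = []
    for i in range(n_fused):
        reordered.extend(chunks[i::n_fused])
    return torch.cat(reordered, dim=dim)


# round-trip check on a tiny fused weight: [Q|K|V], each of width 4, 2 "ranks"
w = torch.arange(12.0).reshape(1, 12)
shards = [split_in_gpt2_style(w, n_fused=3, rank=r, world_size=2) for r in range(2)]
assert torch.equal(gather_in_gpt2_style(shards, n_fused=3), w)
```

The round trip is exactly what the split and gather helpers have to preserve for both the `(out_features, in_features)` layout and the transposed Conv1D layout selected by `is_transposed`.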
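The new `method_replacement` field on `ModulePolicyDescription` lets a policy swap a method on an already-built module instance; the Bloom policy uses it to substitute a tensor-parallel-aware `build_alibi_tensor`, and `ModelSharder._replace_method` performs the swap by binding a plain function with the descriptor protocol. A toy illustration of that binding trick (the class and function here are made up, not shardformer code):

```python
import torch.nn as nn


class ToyModel(nn.Module):

    def scale(self, x):
        return x * 1.0


def sharded_scale(self, x):
    # 'self' is the module instance the function gets bound to
    return x * 2.0


model = ToyModel()
method_replacement = {"scale": sharded_scale}

for name, new_method in method_replacement.items():
    # __get__ turns the free function into a method bound to this instance,
    # mirroring setattr(module, method_name, new_method.__get__(module, module.__class__))
    # in ModelSharder._replace_method; the instance attribute shadows the class method
    setattr(model, name, new_method.__get__(model, model.__class__))

assert model.scale(3) == 6.0
```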
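Inside `build_bloom_alibi_tensor`, the slopes are rebuilt for the global head count (the module's `num_heads` attribute has already been divided by the tensor-parallel size), and each rank then keeps only its contiguous block of heads. The slicing step in isolation, as a sketch with assumed shapes:

```python
import torch


def slice_alibi_for_rank(alibi: torch.Tensor, batch_size: int, num_heads: int,
                         seq_length: int, rank: int, world_size: int) -> torch.Tensor:
    """Keep the contiguous block of heads owned by `rank`.

    `alibi` is the full-head tensor of shape (batch_size * num_heads, 1, seq_length),
    matching what the unsharded HuggingFace implementation returns.
    """
    num_heads_per_rank = num_heads // world_size
    offset = rank * num_heads_per_rank
    alibi = alibi.view(batch_size, num_heads, 1, seq_length)
    alibi = alibi[:, offset:offset + num_heads_per_rank, :, :]
    return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length)


# e.g. 8 global heads split across 2 ranks -> each rank sees 4 heads
full = torch.randn(2 * 8, 1, 5)
part = slice_alibi_for_rank(full, batch_size=2, num_heads=8, seq_length=5, rank=1, world_size=2)
assert part.shape == (2 * 4, 1, 5)
```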
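`Randomizer.is_randomizer_index_synchronized` all-gathers each process's integer index and verifies they match, and `synchronize_index` resolves a mismatch by adopting the largest value; `create_randomizer_with_offset` asserts the former before adding the index to the seed. The same pattern applies to any small piece of per-process state. The sketch below assumes an initialized NCCL process group and uses an `all_reduce` with `MAX` where the patch loops over gathered tensors:

```python
import torch
import torch.distributed as dist


def all_ranks_agree(value: int, group=None) -> bool:
    """All-gather a scalar and check that every rank reported the same value."""
    if not dist.is_initialized():
        return True
    t = torch.tensor(value, dtype=torch.int32, device='cuda')
    gathered = [torch.zeros_like(t) for _ in range(dist.get_world_size(group))]
    dist.all_gather(gathered, t, group=group)
    return all(int(g) == int(gathered[0]) for g in gathered)


def agree_on_max(value: int, group=None) -> int:
    """Resolve divergence by having every rank adopt the largest value."""
    if not dist.is_initialized():
        return value
    t = torch.tensor(value, dtype=torch.int32, device='cuda')
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
    return int(t.item())
```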