From dd2bf026797fb94d5120b481145f37a9661c4a6c Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:56:59 +0800 Subject: [PATCH] [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code --- colossalai/shardformer/_utils.py | 44 +++- colossalai/shardformer/layer/__init__.py | 5 +- .../shardformer/layer/qkv_fused_linear.py | 175 ++++++++++++++- colossalai/shardformer/modeling/sam.py | 41 ++++ .../shardformer/policies/auto_policy.py | 4 + colossalai/shardformer/policies/sam.py | 209 ++++++++++++++++++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/sam.py | 52 +++++ .../test_gpt2_qkv_fused_linear_1d.py | 120 ++++++++++ .../test_model/test_shard_sam.py | 92 ++++++++ 10 files changed, 733 insertions(+), 10 deletions(-) create mode 100644 colossalai/shardformer/modeling/sam.py create mode 100644 colossalai/shardformer/policies/sam.py create mode 100644 tests/kit/model_zoo/transformers/sam.py create mode 100644 tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py create mode 100644 tests/test_shardformer/test_model/test_shard_sam.py diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py index 4ad877e72..c553080de 100644 --- a/colossalai/shardformer/_utils.py +++ b/colossalai/shardformer/_utils.py @@ -1,25 +1,57 @@ import re -def get_obj_list_element(obj, a): +def get_obj_list_element(obj, attr: str): r""" Get the element of the list in the object + + If the attr is a normal attribute, return the attribute of the object. + If the attr is a index type, return the element of the index in the list, like `layers[0]`. 
+ + Args: + obj (Object): The object to get + attr (str): The suffix of the attribute to get + """ re_pattern = r'\[\d+\]' prog = re.compile(re_pattern) - result = prog.search(a) + result = prog.search(attr) if result: matched_brackets = result.group() matched_index = matched_brackets.replace('[', '') matched_index = matched_index.replace(']', '') - a_ = a.replace(matched_brackets, '') - container_obj = getattr(obj, a_) + attr_ = attr.replace(matched_brackets, '') + container_obj = getattr(obj, attr_) obj = container_obj[int(matched_index)] else: - obj = getattr(obj, a) + obj = getattr(obj, attr) return obj +def set_obj_list_element(obj, attr: str, value): + r""" + Set the element to value of a list object + + It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value + + Args: + obj (object): The object to set + attr (str): the string including a list index like `layers[0]` + """ + re_pattern = r'\[\d+\]' + prog = re.compile(re_pattern) + result = prog.search(attr) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace('[', '') + matched_index = matched_index.replace(']', '') + attr_ = attr.replace(matched_brackets, '') + container_obj = getattr(obj, attr_) + container_obj[int(matched_index)] = value + else: + setattr(obj, attr, value) + + def hasattr_(obj, attr: str): r""" Check whether the object has the multi sublevel attr @@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): if ignore: return raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") - setattr(obj, attrs[-1], value) + set_obj_list_element(obj, attrs[-1], value) def getattr_(obj, attr: str, ignore: bool = False): diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7cdcfc318..0c44e6621 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,11 +3,10 @@ from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm -from .parallel_module import ParallelModule -from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule' + 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col' ] diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 3c47c0b11..1e4b6ecb6 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,6 +25,7 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, + linear_with_async_comm, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -33,7 +34,7 @@ from ._operation import ( from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row'] +__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row'] # ==================================== # For GPT 
Only @@ -490,3 +491,175 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): return output else: return output, self.bias + + +# ==================================== +# For Fused torch.nn.Linear +# ==================================== + + +class FusedLinear1D_Col(ParallelModule): + r"""Fused Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.n_fused = n_fused + self.process_group = process_group + self.async_communication = async_communication + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Parameters. + # Initialize weight. 
+ factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False) + + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) + self.weight = customized_distributed_tensor_to_param(sharded_weight) + + if bias: + bias = torch.empty(self.out_features, **factory_kwargs) + + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) + self.bias = customized_distributed_tensor_to_param(sharded_bias) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a fused `torch.nn.linear` layer to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = FusedLinear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=False) + linear_1d.weight.data.copy_(sharded_weight.data) + + if bias: + sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=False) + linear_1d.bias.data.copy_(sharded_bias.data) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_backward(input_, self.process_group) + input_parallel = input_ + + # Matrix multiply. 
+ bias = self.bias if not self.skip_bias_add else None + + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py new file mode 100644 index 000000000..00e2d744e --- /dev/null +++ b/colossalai/shardformer/modeling/sam.py @@ -0,0 +1,41 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def forward_fn(): + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, + -1).permute(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w, + (height, width), (height, width)) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + # replace dropout process with added DropoutForParallelInput layer + # origin code: + # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = self.dropout_layer(attn_weights) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index d00a03c92..63ec8398f 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -104,6 +104,10 @@ _POLICY_LIST = { PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), + + # Sam + "transformers.models.sam.modeling_sam.SamModel": + PolicyLocation(file_name="sam", class_name="SamModelPolicy"), } diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py new file mode 100644 index 000000000..e75d63946 --- /dev/null +++ b/colossalai/shardformer/policies/sam.py @@ -0,0 +1,209 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from ..modeling.sam import forward_fn +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['SamPolicy', 'SamModelPolicy'] + + +class SamPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + return self.model + + def module_policy(self): + from 
transformers.models.sam.modeling_sam import ( + SamFeedForward, + SamTwoWayAttentionBlock, + SamTwoWayTransformer, + SamVisionAttention, + SamVisionLayer, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={ + "attn.num_attention_heads": + self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ) + ]) + policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_attention_heads": + self.model.config.mask_decoder_config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ]) + policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={ + "final_attn_token_to_image.num_attention_heads": + self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ) + ]) + + # add 
`DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` + policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle SamVisionLayer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamVisionLayer) + + # Handle SamTwoWayAttentionBlock + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayAttentionBlock) + + # Handle SamTwoWayTransformer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer) + + return policy + + def postprocess(self): + return self.model + + +# SamModel +class SamModelPolicy(SamPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a298767d1..a1bcb78dd 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -4,5 +4,6 @@ from .bloom import * from .gpt import * from .llama import * from .opt import * +from .sam import * from .t5 import * from .vit import * diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py new file mode 100644 index 000000000..d850623f3 --- /dev/null +++ b/tests/kit/model_zoo/transformers/sam.py @@ -0,0 +1,52 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-image SAM +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from PIL import Image + # import requests + # from transformers import SamModel, SamProcessor + # + # model = SamModel.from_pretrained("facebook/sam-vit-base") + # processor = SamProcessor.from_pretrained("facebook/sam-vit-base") + # + # img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + # input_points = [[[450, 600]]] # 2D localization of a window + # inputs = processor(raw_image, input_points=input_points, return_tensors="pt") + + pixel_values = torch.rand(1, 3, 1024, 1024, dtype=torch.float32) + original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64) + reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64) + input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64) + return dict(pixel_values=pixel_values, + 
original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + input_points=input_points) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: x.iou_scores.mean() + +config = transformers.SamConfig() +config.vision_config.num_hidden_layers = 2 + +# register the BERT variants +model_zoo.register(name='transformers_sam', + model_fn=lambda: transformers.SamModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py new file mode 100644 index 000000000..9eeda93af --- /dev/null +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -0,0 +1,120 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +def check_gpt2_linear_conv_1d_col(): + linear = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + n_fused=3) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) + assert linear_conv_col.bias.shape == torch.Size([96]) + + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_conv_col(x) + assert_close(rearrange(out, 1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + assert_close(target_grad, linear_conv_col.weight.grad) + + +def check_gpt2_linear_conv_1d_row(): + linear = Conv1D(192, 
48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) + assert linear_row.bias.shape == torch.Size([192]) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # test for linear conv + check_gpt2_linear_conv_1d_col() + check_gpt2_linear_conv_1d_row() + + +@rerun_if_address_is_in_use() +def test_gpt2_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_gpt2_linearconv() diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py new file mode 100644 index 000000000..1d047d8e0 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['pred_masks']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + sam = org_model + sharded_sam = sharded_model + + # compare mask decoder grad + + org_grad = sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad + shard_grad = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad + shard_weight = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # compare vision_encoder grad + org_grad = sam.vision_encoder.layers[0].mlp.lin1.weight.grad + shard_grad = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight.grad + shard_weight = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = 
[torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_sam_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_sam') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_sam(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_sam_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_sam(): + spawn(check_sam, 2) + + +if __name__ == "__main__": + test_sam()
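For reference, a minimal sketch of how the FusedLinear1D_Col layer introduced above could be exercised on a plain fused-QKV nn.Linear (such as SAM's attn.qkv). This is an illustration only, not part of the patch: the helper name check_fused_linear_1d_col and the 64/192 sizes are hypothetical, and it assumes a 2-rank tensor-parallel group already initialized via colossalai.launch, mirroring the Conv1D column test added above.

import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedLinear1D_Col


def check_fused_linear_1d_col():
    # a fused QKV projection: Q, K and V packed into one nn.Linear (n_fused = 3)
    linear = nn.Linear(64, 192).cuda()
    linear_col = FusedLinear1D_Col.from_native_module(linear,
                                                      process_group=None,
                                                      n_fused=3,
                                                      gather_output=True)

    # with a tensor-parallel size of 2, each rank holds half of the fused output dim
    assert linear.weight.shape == torch.Size([192, 64])
    assert linear_col.weight.shape == torch.Size([96, 64])
    assert linear_col.bias.shape == torch.Size([96])

    x = torch.rand(4, 64).cuda()
    out = linear(x)
    gather_out = linear_col(x)
    # gather_out corresponds to `out` only up to the fused-QKV chunk reordering done by
    # split_fused_qkv_in_gpt2_style (compare the `rearrange` helper in the Conv1D test
    # above), so a direct assert_close(out, gather_out) is not expected to pass.

    # The updated setattr_ in _utils.py also accepts indexed suffixes such as "layers[0]",
    # which is what lets a policy swap a single module inside an nn.ModuleList, e.g.
    # (illustrative only):
    # setattr_(model, 'vision_encoder.layers[0].attn.qkv', linear_col)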