From 879301d0dad26a056e0e90d8c9c9d6cc4a662c9a Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Tue, 25 Jul 2023 14:29:10 +0800
Subject: [PATCH] [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip tests for incompatible models

* fix test for gemini and low_level_zero plugin
---
 colossalai/shardformer/README.md              |   3 +-
 colossalai/shardformer/modeling/blip2.py      |  60 ++++
 colossalai/shardformer/modeling/sam.py        |   2 -
 .../shardformer/policies/auto_policy.py       |   6 +
 colossalai/shardformer/policies/blip2.py      | 304 ++++++++++++++++++
 tests/kit/model_zoo/transformers/__init__.py  |   1 +
 tests/kit/model_zoo/transformers/blip2.py     |  61 ++++
 .../test_model/test_shard_blip2.py            | 107 ++++++
 8 files changed, 541 insertions(+), 3 deletions(-)
 create mode 100644 colossalai/shardformer/modeling/blip2.py
 create mode 100644 colossalai/shardformer/policies/blip2.py
 create mode 100644 tests/kit/model_zoo/transformers/blip2.py
 create mode 100644 tests/test_shardformer/test_model/test_shard_blip2.py

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 3c322aabf..5489f97e4 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -104,7 +104,8 @@ We will follow this roadmap to develop Shardformer:
   - [ ] Audio
     - [x] Whisper
   - [ ] Multi-modal
-    - [ ] To be added
+    - [x] SAM
+    - [x] BLIP-2
 
 ## 💡 API Design
 
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
new file mode 100644
index 000000000..b7945423a
--- /dev/null
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -0,0 +1,60 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def forward_fn():
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        mixed_qkv = self.qkv(hidden_states)
+
+        # modified from the original code, which is:
+        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
+        #     2, 0, 3, 1, 4
+        # )
+        # to the following, so that the per-head dimension is inferred and stays correct
+        # after num_heads has been divided by the tensor parallel size:
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        query_states, key_states, value_states = (
+            mixed_qkv[0],
+            mixed_qkv[1],
+            mixed_qkv[2],
+        )
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+
+        attention_scores = attention_scores * self.scale
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
+
+        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
+        context_layer = context_layer.reshape(new_context_layer_shape)
+
+        output = self.projection(context_layer)
+
+        outputs = (output, attention_probs) if output_attentions else (output, None)
+
+        return outputs
+
+    return forward
diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
index 00e2d744e..63ebfe89d 100644
--- a/colossalai/shardformer/modeling/sam.py
+++ b/colossalai/shardformer/modeling/sam.py
@@ -1,6 +1,4 @@
 import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
 
 def forward_fn():
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 90347a984..2a041af19 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -116,6 +116,12 @@ _POLICY_LIST = {
     # Sam
     "transformers.models.sam.modeling_sam.SamModel":
         PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
+
+    # Blip2
+    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
+        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
+    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
+        PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
 }
 
 
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
new file mode 100644
index 000000000..43aa1adc1
--- /dev/null
+++ b/colossalai/shardformer/policies/blip2.py
@@ -0,0 +1,304 @@
+import torch.nn as nn
+
+import colossalai.shardformer.layer as col_nn
+
+from .._utils import getattr_, setattr_
+from ..modeling.blip2 import forward_fn
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ['BlipPolicy', 'Blip2ModelPolicy', 'Blip2ForConditionalGenerationPolicy']
+
+
+class BlipPolicy(Policy):
+
+    def config_sanity_check(self):
+        pass
+
+    def preprocess(self):
+        r"""
+        Reshape the embedding layer to make the embedding dimension divisible by world_size.
+        """
+        vocab_size = self.model.config.qformer_config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self):
+        from transformers.models.blip_2.modeling_blip_2 import (
+            Blip2Attention,
+            Blip2EncoderLayer,
+            Blip2QFormerLayer,
+            Blip2QFormerModel,
+            Blip2VisionModel,
+        )
+        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
+
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            # vision encoder layers
+            policy[Blip2EncoderLayer] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attn.num_heads":
+                        self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    "self_attn.embed_dim":
+                        self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.qkv",
+                        target_module=col_nn.FusedLinear1D_Col,
+                        kwargs={
+                            "n_fused": 3,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.projection",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.fc1",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.fc2",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                ],
+            )
+
+            policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=col_nn.DropoutForParallelInput,
+                ),
+            ])
+
+            # Q-Former layers (self-attention and cross-attention)
+            policy[Blip2QFormerLayer] = ModulePolicyDescription(
+                attribute_replacement={
+                    "attention.attention.num_attention_heads":
+                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    "attention.attention.all_head_size":
+                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "crossattention.attention.num_attention_heads":
+                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    "crossattention.attention.all_head_size":
+                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.query",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.attention.query",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.attention.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.attention.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.attention.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate_query.dense",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output_query.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output_query.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ],
+            )
+
+            # OPT language model layers
+            policy[OPTDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attn.embed_dim":
+                        self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "self_attn.num_heads":
+                        self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.out_proj",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="fc1",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="fc2",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                ],
+            )
+
+            policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="model.decoder.embed_tokens",
+                    target_module=col_nn.VocabParallelEmbedding1D,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs={"gather_output": True},
+                ),
+            ])
+
+            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            # handle Blip2EncoderLayer layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="layer_norm1",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="layer_norm2",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=Blip2EncoderLayer,
+            )
+
+            # handle Blip2VisionModel layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="post_layernorm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=Blip2VisionModel,
+            )
+
+            # handle Blip2QFormerModel layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="layernorm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=Blip2QFormerModel,
+            )
+
+            # handle Blip2QFormerLayer layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.LayerNorm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="crossattention.output.LayerNorm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output_query.LayerNorm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=Blip2QFormerLayer,
+            )
+
+            # handle OPTForCausalLM layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="model.decoder.final_layer_norm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=OPTForCausalLM,
+            )
+
+            # handle OPTDecoderLayer layer
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn_layer_norm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="final_layer_norm",
+                        target_module=col_nn.FusedLayerNorm,
+                    ),
+                ],
+                policy=policy,
+                target_key=OPTDecoderLayer,
+            )
+
+        return policy
+
+    def postprocess(self):
+        # tie the language model's input embedding and output projection weights
+        binding_map = {
+            'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
+        }
+
+        for k, v in binding_map.items():
+            src_mod = getattr_(self.model, k)
+            dst_mod = getattr_(self.model, v)
+            dst_mod.weight = src_mod.weight
+
+        return self.model
+
+
+# Blip2Model
+class Blip2ModelPolicy(BlipPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# Blip2ForConditionalGeneration
+class Blip2ForConditionalGenerationPolicy(BlipPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 08a118e57..823ca032f 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -1,5 +1,6 @@
 from .albert import *
 from .bert import *
+from .blip2 import *
 from .bloom import *
 from .chatglm import *
 from .gpt import *
diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py
new file mode 100644
index 000000000..7338f740b
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/blip2.py
@@ -0,0 +1,61 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register BLIP-2
+# ===============================
+
+
+# define data gen function
+def data_gen():
+    # Generated from the following code snippet:
+    #
+    # from PIL import Image
+    # import requests
+    # from transformers import Blip2Processor, Blip2Model
+    # import torch
+
+    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    # image = Image.open(requests.get(url, stream=True).raw)
+
+    # prompt = "Question: how many cats are there? Answer:"
+    # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
+
+    pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32)
+    input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    labels = torch.tensor([[34, 56]], dtype=torch.int64)
+    return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss function
+loss_fn_blip2_model = lambda x: x.loss
+
+# tiny one-layer configuration so the test stays lightweight
+config = transformers.Blip2Config()
+config.text_config.num_hidden_layers = 1
+config.qformer_config.num_hidden_layers = 1
+config.vision_config.num_hidden_layers = 1
+config.qformer_config.attention_probs_dropout_prob = 0
+config.qformer_config.hidden_dropout_prob = 0
+config.text_config.dropout = 0
+
+# register the BLIP-2 variants
+model_zoo.register(name='transformers_blip2',
+                   model_fn=lambda: transformers.Blip2Model(config),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_blip2_model,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='transformers_blip2_conditional_generation',
+                   model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_blip2_model,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py
new file mode 100644
index 000000000..f96299e55
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_blip2.py
@@ -0,0 +1,107 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+
+    # do backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
+
+    # check grad
+    blip2 = org_model
+    sharded_blip2 = sharded_model
+
+    # compare vision_model grad
+    org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
+    shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
+    shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        # the test is spawned with world_size = 2
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+    # compare qformer grad
+    org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
+    shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+    # compare language_model grad
+    org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
+    shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
+    shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
+def check_blip2(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_blip2_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_blip2():
+    spawn(check_blip2, 2)
+
+
+if __name__ == "__main__":
+    test_blip2()
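
A quick way to sanity-check the new model-zoo entries without launching the distributed test above is sketched below. It is not part of the patch: it only relies on the registry API already used in test_shard_blip2.py (model_zoo.get_sub_registry and the 'transformers_blip2*' entries registered in tests/kit/model_zoo/transformers/blip2.py) and runs the tiny one-layer BLIP-2 configs forward and backward on a single device, with no ShardFormer policy applied.

from tests.kit.model_zoo import model_zoo

# pull every BLIP-2 variant registered by tests/kit/model_zoo/transformers/blip2.py
sub_zoo = model_zoo.get_sub_registry('transformers_blip2')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()      # tiny 1-layer Blip2Model / Blip2ForConditionalGeneration
    data = data_gen_fn()    # dummy pixel_values / input_ids / attention_mask / labels
    loss = loss_fn(output_transform_fn(model(**data)))
    loss.backward()         # plain single-device run, no tensor parallelism involved
    print(f"{name}: loss = {loss.item():.4f}")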