[shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip tests for incompatible models

* fix test for gemini and low_level_zero plugin
pull/4445/head
FoolPlayer 2023-07-25 14:29:10 +08:00 committed by Hongxin Liu
parent 8120eca0c0
commit 879301d0da
8 changed files with 541 additions and 3 deletions


@@ -104,7 +104,8 @@ We will follow this roadmap to develop Shardformer:
- [ ] Audio
  - [x] Whisper
- [ ] Multi-modal
  - [x] SAM
  - [x] BLIP-2
## 💡 API Design


@@ -0,0 +1,60 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def forward_fn():

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from original code, which is:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        # to:
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward
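
The reshape change above is what makes this forward safe under tensor parallelism: the policy divides the module's num_heads and embed_dim attributes by the parallel size and column-shards the qkv projection, so on each rank the qkv output is narrower than the embed_dim read from the incoming hidden states, and embed_dim // self.num_heads would no longer equal the real head size. A small standalone sketch of the shape arithmetic (illustrative numbers only, not part of the patch):

import torch

bsz, tgt_len, hidden = 2, 5, 16        # activations keep the full hidden size
num_heads, head_dim, tp_size = 4, 4, 2
local_heads = num_heads // tp_size     # the policy rewrites num_heads to 2 per rank

# column-sharded qkv output: 3 * local_heads * head_dim = 24 channels per rank
qkv_local = torch.randn(bsz, tgt_len, 3 * local_heads * head_dim)

# hidden // local_heads == 8, which is NOT the head size; -1 infers the correct 4
reshaped = qkv_local.reshape(bsz, tgt_len, 3, local_heads, -1).permute(2, 0, 3, 1, 4)
assert reshaped.shape == (3, bsz, local_heads, tgt_len, head_dim)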


@@ -1,6 +1,4 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
def forward_fn():


@@ -116,6 +116,12 @@ _POLICY_LIST = {
    # Sam
    "transformers.models.sam.modeling_sam.SamModel":
        PolicyLocation(file_name="sam", class_name="SamModelPolicy"),

    # Blip2
    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
        PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
}
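
For orientation, the sketch below shows roughly how an entry like this is consumed: the wrapped model's fully qualified class name is the lookup key, and the PolicyLocation names the policies/ file and class to load. The helper name resolve_policy is invented for illustration; the real lookup code in the auto-policy module may be organized differently.

import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


_POLICY_LIST = {
    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
}


def resolve_policy(model):
    # Build "module.ClassName" for the model instance and look it up in the table.
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[key]
    # Import colossalai/shardformer/policies/<file_name>.py and instantiate the policy class.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()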


@@ -0,0 +1,304 @@
import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from ..modeling.blip2 import forward_fn
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['BlipPolicy', 'Blip2ModelPolicy', 'Blip2ForConditionalGenerationPolicy']


class BlipPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        r"""
        Reshape the embedding layer to make the vocabulary size divisible by world_size.
        """
        # TODO:
        vocab_size = self.model.config.qformer_config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model
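    # Worked example of the padding arithmetic above (illustrative only): with a
    # Q-Former vocab_size of 30522 (the BERT-style default) and
    # tensor_parallel_size = 4, 30522 % 4 == 2, so the vocabulary is resized to
    # 30522 + 4 - 2 = 30524, which splits evenly into 7631 rows per rank.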

    def module_policy(self):
        from transformers.models.blip_2.modeling_blip_2 import (
            Blip2Attention,
            Blip2EncoderLayer,
            Blip2QFormerLayer,
            Blip2QFormerModel,
            Blip2VisionModel,
        )
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[Blip2EncoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.num_heads":
                        self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attn.embed_dim":
                        self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.qkv",
                        target_module=col_nn.FusedLinear1D_Col,
                        kwargs={
                            "n_fused": 3,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.projection",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=col_nn.DropoutForParallelInput,
                ),
            ])

            policy[Blip2QFormerLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.attention.num_attention_heads":
                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "attention.attention.all_head_size":
                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
                    "crossattention.attention.num_attention_heads":
                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "crossattention.attention.all_head_size":
                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.attention.query",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.key",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.query",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.key",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate_query.dense",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ])

            policy[OPTDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.embed_dim":
                        self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.out_proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="model.decoder.embed_tokens",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={"gather_output": True},
                ),
            ])

            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
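            # (Explanatory note: the method replacement above installs the forward from
            # ..modeling.blip2, which infers the per-head size with -1 instead of
            # embed_dim // num_heads, so it stays correct after the attribute
            # replacement on Blip2EncoderLayer divides num_heads and embed_dim by the
            # tensor parallel size.)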

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle Blip2EncoderLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2EncoderLayer)

            # handle Blip2VisionModel layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="post_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2VisionModel)

            # handle Blip2QFormerModel layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2QFormerModel)

            # handle Blip2QFormerLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="attention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2QFormerLayer)

            # handle OPTForCausalLM layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="model.decoder.final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=OPTForCausalLM)

            # handle OPTDecoderLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="self_attn_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=OPTDecoderLayer)

        return policy

    def postprocess(self):
        binding_map = {
            'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
        }

        for k, v in binding_map.items():
            src_mod = getattr_(self.model, k)
            dst_mod = getattr_(self.model, v)
            dst_mod.weight = src_mod.weight

        return self.model
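    # Note (explanatory, not part of the original patch): re-tying
    # language_model.lm_head.weight to the input embedding restores weight sharing
    # after the two modules have been replaced by VocabParallelEmbedding1D and a
    # gather-output Linear1D_Col, so they keep using one parameter tensor.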

# Blip2Model
class Blip2ModelPolicy(BlipPolicy):

    def __init__(self) -> None:
        super().__init__()


# Blip2ForConditionalGeneration
class Blip2ForConditionalGenerationPolicy(BlipPolicy):

    def __init__(self) -> None:
        super().__init__()
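
For completeness, a rough end-to-end sketch of how this policy gets exercised; it assumes a script launched with torchrun on two GPUs and the same ShardConfig toggles the tests below use, and the exact return value of ShardFormer.optimize() is version-dependent, so treat it as an outline rather than canonical usage:

import colossalai
import transformers
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={})

model = transformers.Blip2ForConditionalGeneration(transformers.Blip2Config())
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# The policy is resolved automatically from _POLICY_LIST, so nothing has to be
# passed explicitly; Blip2ForConditionalGenerationPolicy is selected for this model.
sharded_model = shard_former.optimize(model)    # may return (model, shared_params) in some releases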


@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .gpt import *


@@ -0,0 +1,61 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-image BLIP-2
# ===============================


# define data gen function
def data_gen():
    # Generated from the following code snippet:
    #
    # from PIL import Image
    # import requests
    # from transformers import Blip2Processor, Blip2Model
    # import torch
    #
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # image = Image.open(requests.get(url, stream=True).raw)
    # prompt = "Question: how many cats are there? Answer:"
    # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

    pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32)
    input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    labels = torch.tensor([[34, 56]], dtype=torch.int64)
    return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_blip2_model = lambda x: x.loss

config = transformers.Blip2Config()
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1
config.qformer_config.attention_probs_dropout_prob = 0
config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0

# register the BLIP-2 variants
model_zoo.register(name='transformers_blip2',
                   model_fn=lambda: transformers.Blip2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_blip2_conditional_gerneration',
                   model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
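
A quick smoke check for the tiny configuration above (a standalone sketch, not part of the registered file): instantiating the one-layer model and feeding it the generated batch should yield an output whose loss field is exactly what loss_fn_blip2_model extracts.

if __name__ == "__main__":
    # Uses the `config` and `data_gen` defined above.
    model = transformers.Blip2Model(config)
    outputs = model(**data_gen())
    print(outputs.loss)    # scalar loss consumed by loss_fn_blip2_model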


@@ -0,0 +1,107 @@
import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

    # do backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"

    # check grad
    blip2 = org_model
    sharded_blip2 = sharded_model

    # compare vision_model grad
    org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

    # compare qformer grad
    org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

    # compare language_model grad
    org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
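

# Note on the gather blocks above (explanatory, not test logic): the column-parallel
# linear layers split their weights, and therefore their gradients, along dim 0
# across the ranks, so each rank's gradient shard is all-gathered and concatenated
# along dim 0 before being compared with the unsharded model's gradient. The
# hard-coded world size of 2 matches spawn(check_blip2, 2) at the bottom of the file.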


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()


def check_blip2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_blip2_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_blip2():
    spawn(check_blip2, 2)


if __name__ == "__main__":
    test_blip2()