mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection
* skip tests for incompatible models
* fix test for gemini and low_level_zero plugin
parent 8120eca0c0
commit 879301d0da
@@ -104,7 +104,8 @@ We will follow this roadmap to develop Shardformer:
- [ ] Audio
  - [x] Whisper
- [ ] Multi-modal
  - [ ] To be added
  - [x] SAM
  - [x] BLIP-2

## 💡 API Design

@@ -0,0 +1,60 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def forward_fn():

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from original code, which is:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        # to:
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward

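The closure above returns an unbound function whose `self` is only supplied when it is attached to a module. A minimal sketch (an assumption for illustration, not ColossalAI's actual replacement machinery) of how the policy entry `method_replacement={"forward": forward_fn()}` could be bound onto a `Blip2Attention` instance; `bind_forward` and `attn_module` are hypothetical names:

import types

def bind_forward(attn_module, new_forward):
    # Bind the bare function as an instance method so that `self` inside
    # `forward` resolves to the existing Blip2Attention module.
    attn_module.forward = types.MethodType(new_forward, attn_module)

# usage (illustrative): bind_forward(blip2_attention, forward_fn())
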
@@ -1,6 +1,4 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def forward_fn():

@@ -116,6 +116,12 @@ _POLICY_LIST = {
    # Sam
    "transformers.models.sam.modeling_sam.SamModel":
        PolicyLocation(file_name="sam", class_name="SamModelPolicy"),

    # Blip2
    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
        PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
}

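For context, a minimal sketch (an assumption, not the actual ColossalAI resolver) of how such a string-keyed table can be turned into a policy instance at runtime; `resolve_policy` and the package path `colossalai.shardformer.policies` are illustrative:

import importlib

def resolve_policy(model):
    # Build the fully qualified class name of the wrapped model, e.g.
    # "transformers.models.blip_2.modeling_blip_2.Blip2Model".
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[qualname]
    # Import e.g. colossalai.shardformer.policies.blip2 and instantiate Blip2ModelPolicy.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
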
@@ -0,0 +1,304 @@
import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from ..modeling.blip2 import forward_fn
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['BlipPolicy', 'Blip2ModelPolicy', 'Blip2ForConditionalGenerationPolicy']


class BlipPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        r"""
        Reshape the embedding layer to make the vocabulary size divisible by world_size.
        """
        # TODO:
        vocab_size = self.model.config.qformer_config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

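    # Worked example (illustrative numbers): with a Q-Former vocab_size of 30522 and
    # tensor_parallel_size of 4, 30522 % 4 == 2, so the vocabulary is padded to
    # 30522 + 4 - 2 = 30524 and each rank holds 30524 / 4 = 7631 embedding rows.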
    def module_policy(self):
        from transformers.models.blip_2.modeling_blip_2 import (
            Blip2Attention,
            Blip2EncoderLayer,
            Blip2QFormerLayer,
            Blip2QFormerModel,
            Blip2VisionModel,
        )
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[Blip2EncoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.num_heads":
                        self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attn.embed_dim":
                        self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.qkv",
                        target_module=col_nn.FusedLinear1D_Col,
                        kwargs={
                            "n_fused": 3,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.projection",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=col_nn.DropoutForParallelInput,
                ),
            ])

            policy[Blip2QFormerLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.attention.num_attention_heads":
                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "attention.attention.all_head_size":
                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
                    "crossattention.attention.num_attention_heads":
                        self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "crossattention.attention.all_head_size":
                        self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.attention.query",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.key",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.query",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.key",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.attention.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate_query.dense",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ])

            policy[OPTDecoderLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attn.embed_dim":
                        self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attn.num_heads":
                        self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.out_proj",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc1",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="fc2",
                        target_module=col_nn.Linear1D_Row,
                    ),
                ])

            policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="model.decoder.embed_tokens",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="lm_head",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={"gather_output": True},
                ),
            ])

            policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle Blip2EncoderLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2EncoderLayer)

            # handle Blip2VisionModel layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="post_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2VisionModel)

            # handle Blip2QFormerModel layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2QFormerModel)

            # handle Blip2QFormerLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="attention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="crossattention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output_query.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=Blip2QFormerLayer)

            # handle OPTForCausalLM layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="model.decoder.final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=OPTForCausalLM)

            # handle OPTDecoderLayer layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="self_attn_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=OPTDecoderLayer)

        return policy

    def postprocess(self):
        binding_map = {
            'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
        }

        for k, v in binding_map.items():
            src_mod = getattr_(self.model, k)
            dst_mod = getattr_(self.model, v)
            dst_mod.weight = src_mod.weight
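            # Re-tie the (possibly sharded) token-embedding weight to lm_head so the
            # two modules keep sharing a single parameter after module replacement.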

        return self.model


# Blip2Model
class Blip2ModelPolicy(BlipPolicy):

    def __init__(self) -> None:
        super().__init__()


# Blip2ForConditionalGeneration
class Blip2ForConditionalGenerationPolicy(BlipPolicy):

    def __init__(self) -> None:
        super().__init__()

@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .gpt import *

@@ -0,0 +1,61 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-image BLIP2
# ===============================


# define data gen function
def data_gen():
    # Generated from the following code snippet
    #
    # from PIL import Image
    # import requests
    # from transformers import Blip2Processor, Blip2Model
    # import torch

    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # image = Image.open(requests.get(url, stream=True).raw)

    # prompt = "Question: how many cats are there? Answer:"
    # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

    pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32)
    input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    labels = torch.tensor([[34, 56]], dtype=torch.int64)
    return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_blip2_model = lambda x: x.loss

config = transformers.Blip2Config()
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1
config.qformer_config.attention_probs_dropout_prob = 0
config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0

# register the blip2 variants
model_zoo.register(name='transformers_blip2',
                   model_fn=lambda: transformers.Blip2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_blip2_conditional_gerneration',
                   model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

@@ -0,0 +1,107 @@
import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

    # do backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

    # check grad

    blip2 = org_model
    sharded_blip2 = sharded_model

    # compare vision_model grad

    org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
    shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # compare qformer grad
    org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
    shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # compare language_model grad
    org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
    shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()


def check_blip2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_blip2_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_blip2():
    spawn(check_blip2, 2)


if __name__ == "__main__":
    test_blip2()