[shardformer] support vision transformer (#4096)

* first v of vit shardformer

* keep vit

* update

* vit shard add vitattention vitlayer

* update num head shard para

* finish test for vit

* add new_model_class & postprocess

* add vit readme

* delete old files & fix the conflict

* fix sth
pull/4157/head
Kun Lin 2023-06-28 13:28:18 +08:00 committed by Frank Lee
parent ac80937138
commit 8af29ee47a
10 changed files with 159 additions and 8 deletions

View File

@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] GPT Neo
- [ ] GPT-J
- [ ] CV
- [ ] ViT
- [x] ViT
- [ ] BEiT
- [ ] SwinTransformer
- [ ] SwinTransformer V2

View File

@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
return _ReduceBackward.apply(input_, process_group)

View File

@ -61,4 +61,4 @@ class FusedLayerNorm():
# copy weight and bias
layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias)
return layernorm
return layernorm

View File

@ -316,4 +316,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
])
}
module_policy.update(addon_module)
return module_policy
return module_policy

View File

@ -167,4 +167,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
class T5EncoderPolicy(T5ModelPolicy):
pass
pass

View File

@ -0,0 +1,96 @@
from typing import Dict, Union
import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ViTPolicy(Policy):
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
ViTEmbeddings:
ModulePolicyDescription(
attribute_replacement{},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
]
),
ViTLayer:
ModulePolicyDescription(
attribute_replacement{
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=Dropout1D,
),
]
),
}
def new_model_class(self):
return None
def postprocess(self):
return self.model

View File

@ -86,4 +86,4 @@ def test_device_mesh_from_process_group():
if __name__ == '__main__':
test_device_mesh()
test_device_mesh_from_process_group()
test_device_mesh_from_process_group()

View File

@ -41,4 +41,4 @@ def test_layernorm():
if __name__ == '__main__':
test_layernorm_1d()
test_layernorm_1d()

View File

@ -56,4 +56,4 @@ def test_t5():
if __name__ == "__main__":
test_t5()
test_t5()

View File

@ -0,0 +1,55 @@
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)
# do backward
org_loss.backward()
shard_loss.backward()
# check grad
org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
if __name__ == "__main__":
test_vit()