mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support vision transformer (#4096)
* first v of vit shardformer * keep vit * update * vit shard add vitattention vitlayer * update num head shard para * finish test for vit * add new_model_class & postprocess * add vit readme * delete old files & fix the conflict * fix sthpull/4157/head
parent
ac80937138
commit
8af29ee47a
|
@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
|
||||||
- [ ] GPT Neo
|
- [ ] GPT Neo
|
||||||
- [ ] GPT-J
|
- [ ] GPT-J
|
||||||
- [ ] CV
|
- [ ] CV
|
||||||
- [ ] ViT
|
- [x] ViT
|
||||||
- [ ] BEiT
|
- [ ] BEiT
|
||||||
- [ ] SwinTransformer
|
- [ ] SwinTransformer
|
||||||
- [ ] SwinTransformer V2
|
- [ ] SwinTransformer V2
|
||||||
|
|
|
@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
|
||||||
|
|
||||||
|
|
||||||
def reduce_backward(input_, process_group):
|
def reduce_backward(input_, process_group):
|
||||||
return _ReduceBackward.apply(input_, process_group)
|
return _ReduceBackward.apply(input_, process_group)
|
|
@ -61,4 +61,4 @@ class FusedLayerNorm():
|
||||||
# copy weight and bias
|
# copy weight and bias
|
||||||
layernorm.weight.copy_(module.weight)
|
layernorm.weight.copy_(module.weight)
|
||||||
layernorm.bias.copy_(module.bias)
|
layernorm.bias.copy_(module.bias)
|
||||||
return layernorm
|
return layernorm
|
|
@ -316,4 +316,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
module_policy.update(addon_module)
|
module_policy.update(addon_module)
|
||||||
return module_policy
|
return module_policy
|
|
@ -167,4 +167,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
||||||
|
|
||||||
|
|
||||||
class T5EncoderPolicy(T5ModelPolicy):
|
class T5EncoderPolicy(T5ModelPolicy):
|
||||||
pass
|
pass
|
|
@ -0,0 +1,96 @@
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
|
||||||
|
|
||||||
|
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
|
||||||
|
|
||||||
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
class ViTPolicy(Policy):
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
# Resize embedding
|
||||||
|
vocab_size = self.model.config.vocab_size
|
||||||
|
world_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
if vocab_size % world_size != 0:
|
||||||
|
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||||
|
self.model.resize_token_embeddings(new_vocab_size)
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||||
|
return {
|
||||||
|
ViTEmbeddings:
|
||||||
|
ModulePolicyDescription(
|
||||||
|
attribute_replacement{},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="dropout",
|
||||||
|
target_module=Dropout1D,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
ViTLayer:
|
||||||
|
ModulePolicyDescription(
|
||||||
|
attribute_replacement{
|
||||||
|
"attention.attention.num_attention_heads":
|
||||||
|
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
|
||||||
|
"attention.attention.all_head_size":
|
||||||
|
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
|
||||||
|
},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.attention.query",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.attention.key",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.attention.value",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.attention.dropout",
|
||||||
|
target_module=Dropout1D,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.output.dense",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="attention.output.dropout",
|
||||||
|
target_module=Dropout1D,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="intermediate.dense",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="output.dense",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="output.dropout",
|
||||||
|
target_module=Dropout1D,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def new_model_class(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -86,4 +86,4 @@ def test_device_mesh_from_process_group():
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_device_mesh()
|
test_device_mesh()
|
||||||
test_device_mesh_from_process_group()
|
test_device_mesh_from_process_group()
|
|
@ -41,4 +41,4 @@ def test_layernorm():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_layernorm_1d()
|
test_layernorm_1d()
|
|
@ -56,4 +56,4 @@ def test_t5():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_t5()
|
test_t5()
|
|
@ -0,0 +1,55 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||||
|
# check forward
|
||||||
|
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||||
|
output_transform_fn, loss_fn)
|
||||||
|
assert_hf_output_close(org_output, shard_output)
|
||||||
|
|
||||||
|
# do backward
|
||||||
|
org_loss.backward()
|
||||||
|
shard_loss.backward()
|
||||||
|
|
||||||
|
# check grad
|
||||||
|
org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
|
||||||
|
shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
|
||||||
|
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_vit(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
org_model, sharded_model = build_model(world_size, model_fn)
|
||||||
|
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def test_vit():
|
||||||
|
spawn(check_vit, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_vit()
|
Loading…
Reference in New Issue