diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index fee4cce7a..da80a7276 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
   - [ ] GPT Neo
   - [ ] GPT-J
 - [ ] CV
-  - [ ] ViT
+  - [x] ViT
   - [ ] BEiT
   - [ ] SwinTransformer
   - [ ] SwinTransformer V2
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 7e97bee01..c025daaec 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
 
 
 def reduce_backward(input_, process_group):
-    return _ReduceBackward.apply(input_, process_group)
+    return _ReduceBackward.apply(input_, process_group)
\ No newline at end of file
diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py
index 83854239c..6103380fe 100644
--- a/colossalai/shardformer/layer/layernorm.py
+++ b/colossalai/shardformer/layer/layernorm.py
@@ -61,4 +61,4 @@ class FusedLayerNorm():
             # copy weight and bias
             layernorm.weight.copy_(module.weight)
             layernorm.bias.copy_(module.bias)
-        return layernorm
+        return layernorm
\ No newline at end of file
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 7b0eaa5d8..fb70cdff8 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -316,4 +316,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
                                         ])
         }
         module_policy.update(addon_module)
-        return module_policy
+        return module_policy
\ No newline at end of file
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 30433f751..9a1b63e46 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -167,4 +167,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 
 
 class T5EncoderPolicy(T5ModelPolicy):
-    pass
+    pass
\ No newline at end of file
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
new file mode 100644
index 000000000..4a2b72057
--- /dev/null
+++ b/colossalai/shardformer/policies/vit.py
@@ -0,0 +1,96 @@
+from typing import Dict, Union
+
+import torch.nn as nn
+
+from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
+
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
+
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+class ViTPolicy(Policy):
+
+    def preprocess(self):
+        # Resize embedding
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        return {
+            ViTEmbeddings:
+                ModulePolicyDescription(
+                    attribute_replacement={},
+                    param_replacement=[],
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=Dropout1D,
+                        )
+                    ]
+                ),
+            ViTLayer:
+                ModulePolicyDescription(
+                    attribute_replacement={
+                        "attention.attention.num_attention_heads":
+                            self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
+                        "attention.attention.all_head_size":
+                            self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
+                    },
+                    param_replacement=[],
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="attention.attention.query",
+                            target_module=Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.attention.key",
+                            target_module=Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.attention.value",
+                            target_module=Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.attention.dropout",
+                            target_module=Dropout1D,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dense",
+                            target_module=Linear1D_Row,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dropout",
+                            target_module=Dropout1D,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="intermediate.dense",
+                            target_module=Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="output.dense",
+                            target_module=Linear1D_Row,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="output.dropout",
+                            target_module=Dropout1D,
+                        ),
+                    ]
+                ),
+        }
+
+    def new_model_class(self):
+        return None
+
+    def postprocess(self):
+        return self.model
+
+
+
+
+
diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py
index 590d6966b..1f8db99c9 100644
--- a/tests/test_device/test_device_mesh.py
+++ b/tests/test_device/test_device_mesh.py
@@ -86,4 +86,4 @@ def test_device_mesh_from_process_group():
 
 if __name__ == '__main__':
     test_device_mesh()
-    test_device_mesh_from_process_group()
+    test_device_mesh_from_process_group()
\ No newline at end of file
diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py
index a11784554..080fae034 100644
--- a/tests/test_shardformer/test_layer/test_layernorm.py
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -41,4 +41,4 @@ def test_layernorm():
 
 
 if __name__ == '__main__':
-    test_layernorm_1d()
+    test_layernorm_1d()
\ No newline at end of file
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 2698d7675..6074a902e 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -56,4 +56,4 @@ def test_t5():
 
 
 if __name__ == "__main__":
-    test_t5()
+    test_t5()
\ No newline at end of file
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
new file mode 100644
index 000000000..d5d71d9e2
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -0,0 +1,55 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output)
+
+    # do backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    # check grad
+    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
+
+
+def check_vit(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_vit():
+    spawn(check_vit, 4)
+
+
+if __name__ == "__main__":
+    test_vit()
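
Note on the arithmetic used by the new ViTPolicy above: preprocess pads the vocabulary up to the next multiple of the tensor-parallel world size, and the ViTLayer attribute replacement splits num_attention_heads and all_head_size evenly across ranks. The standalone Python sketch below reproduces that arithmetic with assumed example values (vocab_size=30522, hidden_size=768, num_attention_heads=12, tensor_parallel_size=4); the numbers are illustrative only and the snippet does not depend on ColossalAI or transformers.

    # Standalone sketch of the ViTPolicy sharding arithmetic (illustrative values, not part of the patch).
    tensor_parallel_size = 4      # assumed tensor-parallel world size
    vocab_size = 30522            # assumed config.vocab_size
    hidden_size = 768             # assumed config.hidden_size
    num_attention_heads = 12      # assumed config.num_attention_heads

    # preprocess(): round the vocabulary up so every rank owns an equally sized embedding slice.
    if vocab_size % tensor_parallel_size != 0:
        vocab_size = vocab_size + tensor_parallel_size - vocab_size % tensor_parallel_size
    assert vocab_size % tensor_parallel_size == 0
    print(vocab_size)             # 30524 for the values above

    # ViTLayer attribute replacement: each rank keeps 1/tp of the heads and of the projected width.
    heads_per_rank = num_attention_heads // tensor_parallel_size
    all_head_size_per_rank = hidden_size // tensor_parallel_size
    print(heads_per_rank, all_head_size_per_rank)    # 3 192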