From 9a4842c571cd63e6a660182a234bc6ff60991ba0 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 17 Jul 2023 17:30:57 +0800
Subject: [PATCH] revise shardformer readme (#4246)

---
 colossalai/shardformer/README.md | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 6ae32e4fb..bf4215c52 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -22,7 +22,6 @@
     - [System Performance](#system-performance)
     - [Convergence](#convergence)
 
-
 ## 🔗 Introduction
 
 **Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background.
@@ -33,7 +32,7 @@
 
 The sample API usage is given below:
 
-``` python
+```python
 from colossalai.shardformer import ShardConfig, Shard
 from transformers import BertForMaskedLM
 
@@ -74,6 +73,7 @@ shard_former.optimize(model, my_policy)
 ```
+
 ## 🗺 Roadmap
 
 We will follow this roadmap to develop Shardformer:
@@ -117,15 +117,13 @@ Please refer to the code for more details.

-
-
 ### Distributed Modules
 
 `ShardFormer` replaces the original PyTorch module with a distributed module. The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
 Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
 
-```python
+````python
 class ParallelModule(torch.nn.Module):
 
     @abstractmethod
@@ -140,7 +138,7 @@ class ParallelModule(torch.nn.Module):
         my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
         ```
         """
-```
+````
 
 ### Shard Config
 
@@ -169,7 +167,7 @@ We abstract the policy into four stages:
 2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
 3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
 
-``` python
+```python
 @dataclass
 class ModulePolicyDescription:
     r"""
@@ -238,7 +236,6 @@ class Policy(ABC):
         ...
 ```
 
-
 ### Model Sharder
 
 `ModelSharder` is the class in charge of sharding the model based on the given policy.
@@ -324,21 +321,20 @@ You can create a new file in the `colossalai/shardformer/policies` folder and na
 Please follow the following protocols when writing your policy:
 
 - You have to make a clear decision what you want to replace exactly in the original PyTorch module
-    - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
-    - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
-    - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the .
-    - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
+  - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
+  - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
+  - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the replacement.
+  - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
 - You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module while modules such as `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not nested inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
 - `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert native PyTorch module to the `ParallelModule` and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting the `ParallelModule` like `colossalai/shardformer/layer/normalization.py`.
 - **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import error**. For example, a file in these folders accidentally imports `transformers` library at the top of the file, then the user will have to install `transformers` library even if they do not use this file. Any file in the `modeling` folder should be only imported by the policy file. A policy implementation should be only imported dynamically via the autopolicy or manually via the `ShardFormer` module.
 - Try to keep your import statement on third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy.
 
-
 - Step 2. Register your policy to the autopolicy
 
 Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
 
-For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.__class__.__qualname__). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
+For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.\_\_class\_\_.\_\_qualname\_\_). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
 
 ```python
 _POLICY_LIST = {
@@ -360,7 +356,6 @@ Add your model to the `tests/kit/model_zoo` file. This allows you to define test
 
 Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
 
-
 - Step 3. Execute your test
 
 When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests.