revise shardformer readme (#4246)

pull/4257/head
Jianghai 2023-07-17 17:30:57 +08:00 committed by GitHub
parent 4e9b09c222
commit 9a4842c571

- [System Performance](#system-performance)
- [Convergence](#convergence)
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes mainstream models from libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not come from a systems background.
The sample API usage is given below:
```python
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForMaskedLM

# create a HuggingFace model as usual
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# shard the model with a default ShardConfig
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model)

# if you have a custom policy, you can pass it in explicitly
shard_former.optimize(model, my_policy)
```
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
## 💡 API Design

We discuss the main components of the design below. Please refer to the code for more details.
### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
````python
from abc import abstractmethod
from typing import Tuple, Union

import torch
from torch.distributed import ProcessGroup


class ParallelModule(torch.nn.Module):

    @abstractmethod
    def from_native_module(module: torch.nn.Module,
                           process_group: Union[ProcessGroup, Tuple[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Example:

        ```
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
        ...
````
### Shard Config
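At a high level, a shard config is a small dataclass of switches that `ShardFormer` consults while sharding. The sketch below is illustrative only; the field names are assumptions for exposition, not the exact `ShardConfig` definition:

```python
from dataclasses import dataclass
from typing import Optional

from torch.distributed import ProcessGroup


@dataclass
class IllustrativeShardConfig:
    # field names are assumptions for illustration, not the real ShardConfig
    tensor_parallel_process_group: Optional[ProcessGroup] = None  # group used for tensor parallelism
    enable_fused_normalization: bool = False  # swap layer norm for a fused kernel
    enable_all_optimization: bool = False     # turn on every available optimization
```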
### Policy

We abstract the policy into four stages:

1. Preprocessing: call `Policy.preprocess` to do some preparatory work before sharding, for example, resizing the embedding
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a collection of `ModulePolicyDescription` objects that tell `ModelSharder` how the submodules' attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
```python
@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes, parameters, submodules, and methods of a
    module will be substituted during sharding.
    """
    ...


class Policy(ABC):
    ...
```
### Model Sharder
`ModelSharder` is the class in charge of sharding the model based on the given policy.
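In spirit, the sharder walks the module tree and applies the policy stages described above. The outline below is illustrative pseudocode: the `Policy` method names follow the stages listed earlier, and `set_model` and `apply_description` are hypothetical names, not the actual implementation:

```python
import torch


def apply_description(module: torch.nn.Module, description) -> None:
    """Hypothetical helper: apply the attribute, parameter, submodule, and
    method replacements recorded in one ModulePolicyDescription to one module."""
    ...


def shard(model: torch.nn.Module, policy) -> torch.nn.Module:
    # illustrative outline of the sharding flow, not the real ModelSharder code
    policy.set_model(model)             # hand the model to the policy (assumed setter)
    model = policy.preprocess()         # stage 1: preparatory work, e.g. resizing the embedding
    for target_cls, description in policy.module_policy().items():
        for submodule in model.modules():
            if isinstance(submodule, target_cls):
                apply_description(submodule, description)
    return policy.postprocess()         # final stage: e.g. binding tied weights
```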
## 🔨 How to Add a New Model

- Step 1. Write your own model policy

You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name.
Please follow these protocols when writing your policy; a sketch illustrating the replacement descriptions is given after this list:
- You have to make a clear decision about what exactly you want to replace in the original PyTorch module:
  - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
  - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
  - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement `from_native_module` for the replacement.
  - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in `shardformer/modeling/<model-name>.py`**.
- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/<model-name>.py` file. Primitive modules are modules that are not composed of other modules; for example, `torch.nn.Linear` is a primitive module, while the `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not nest inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert a native PyTorch module to a `ParallelModule`, and `ParallelModule(...)` to instantiate the module directly, just like a normal PyTorch module. `ParallelModule` should only be implemented for modules whose weights are sharded. If you want to make your module compatible with `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting from `ParallelModule`, as done in `colossalai/shardformer/layer/normalization.py`.
- **Do not import any file in `colossalai/shardformer/policies` or `colossalai/shardformer/modeling` eagerly, to avoid unwanted import errors.** For example, if a file in these folders imports the `transformers` library at the top of the file, then users will have to install `transformers` even if they never use that file. Any file in the `modeling` folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy or manually via the `ShardFormer` module.
- Try to keep import statements for third-party libraries such as `transformers` within the function body instead of the header section of the file, because we do not want to import the library when the policy is not used.
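For instance, a policy that only uses `sub_module_replacement` to swap two of `BertLayer`'s linear projections for tensor-parallel ones might look roughly like this; the import path and constructor arguments are written from the descriptions above and may not match the actual API exactly:

```python
# rough sketch of a custom policy; import paths and signatures are assumptions
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.basepolicy import (  # assumed module path
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)


class MyBertPolicy(Policy):

    def module_policy(self):
        # keep third-party imports inside the function body (see the protocol above)
        from transformers.models.bert.modeling_bert import BertLayer

        return {
            BertLayer: ModulePolicyDescription(
                sub_module_replacement=[
                    # shard the query projection column-wise ...
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=Linear1D_Col,
                    ),
                    # ... and the attention output projection row-wise
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=Linear1D_Row,
                    ),
                ],
            )
        }
```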
- Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
For example, to register the policy for the BERT model, we just add a key-value pair to the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it via `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly, because the policy file may import libraries (such as `transformers`) that we do not want to load when the policy is not used.
```python
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertForMaskedLM":
        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
    # ... other models
}
```
### Test Your Policy

- Step 1. Add your model to the model zoo

Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model.

- Step 2. Write your unit test
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
- Step 3. Execute your test
When you run tests locally, you should run both your newly added test file and the whole `shardformer` test suite.
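As a starting point, distributed Shardformer tests typically spawn several worker processes, initialize the global context, and then shard and check the model. The sketch below assumes the `spawn` and `rerun_if_address_is_in_use` helpers from `colossalai.testing` and is illustrative rather than a ready-made test:

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_shard_check(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    # keep third-party imports inside the function body
    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    sharded_model = ShardFormer(shard_config=ShardConfig()).optimize(model)
    # compare outputs/gradients of sharded_model against the original model here
    ...


@rerun_if_address_is_in_use()
def test_my_model():
    spawn(run_shard_check, nprocs=2)
```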