- [System Performance](#system-performance)
- [Convergence](#convergence)
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes mainstream models from libraries such as HuggingFace Transformers and TIMM. It aims to make parallelization hassle-free for users who do not come from a systems background.

A sample of the API usage is given below:

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForMaskedLM

# launch the distributed environment
colossalai.launch_from_torch(config={})

# create a HuggingFace model as usual
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# shard the model with a sharding config and a policy
# (`my_policy` is a Policy instance; see the Policy section below)
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
shard_former.optimize(model, my_policy)
```
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:

## 💡 API Design

Please refer to the code for more details.
### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert a native PyTorch module to its corresponding distributed module.

````python
class ParallelModule(torch.nn.Module):

    @abstractmethod
    def from_native_module(module: torch.nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Examples:

        ```python
        my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
        ```
        """
````
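
For example, a native `nn.Linear` can be converted to a column-parallel linear in one call. The snippet below is a minimal sketch; it assumes the distributed environment is already initialized and that `process_group` is a tensor-parallel `ProcessGroup`:

```python
import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col

# assumption: `process_group` is an initialized tensor-parallel ProcessGroup
my_linear = nn.Linear(in_features=768, out_features=768)
my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
```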
### Shard Config
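
`ShardConfig` is the configuration object that controls how a model is sharded; it is passed to `ShardFormer` as shown in the usage example above. A minimal sketch of constructing one is given below; the exact field set varies between versions, and `enable_fused_normalization` is an assumption here:

```python
from colossalai.shardformer import ShardConfig

# assumption: this field exists in the installed version; check
# colossalai/shardformer/shard/shard_config.py for the full field list
shard_config = ShardConfig(enable_fused_normalization=True)
```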

### Policy

We abstract the policy into the following stages:

1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding.
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a set of `ModulePolicyDescription` objects that tell `ModelSharder` how the submodules' attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.

```python
@dataclass
class ModulePolicyDescription:
    r"""
    Describe how the attributes, parameters, submodules, and methods of a
    module will be substituted during sharding.
    """
    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    sub_module_replacement: List[SubModuleReplacementDescription] = None
    method_replacement: Dict[str, Callable] = None


class Policy(ABC):
    # defines the preprocess / module_policy / postprocess hooks
    ...
```
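
For instance, a description that shards a BERT-style attention block might look like the sketch below. The module suffixes and head-count arithmetic are illustrative, `tp_size` is a hypothetical variable, and the `basepolicy` import path is an assumption:

```python
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
# the exact import path of these descriptions may differ by version
from colossalai.shardformer.policies.basepolicy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

tp_size = 2  # hypothetical tensor-parallel world size

# illustrative: shard the query projection column-wise, shard the attention
# output projection row-wise, and scale down the per-device head count
description = ModulePolicyDescription(
    attribute_replacement={"attention.self.num_attention_heads": 12 // tp_size},
    sub_module_replacement=[
        SubModuleReplacementDescription(suffix="attention.self.query", target_module=Linear1D_Col),
        SubModuleReplacementDescription(suffix="attention.output.dense", target_module=Linear1D_Row),
    ],
)
```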
### Model Sharder
`ModelSharder` is the class in charge of sharding the model based on the given policy.
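
A rough sketch of how it might be driven is given below; the import path and constructor arguments are assumptions based on the description above rather than a verified API:

```python
# hypothetical illustration of how ShardFormer uses ModelSharder internally;
# the import path and argument names are assumptions
from colossalai.shardformer.shard import ModelSharder

sharder = ModelSharder(model=model, policy=my_policy, shard_config=shard_config)
sharder.shard()
```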

- Step 1. Write your own model policy

You can create a new file in the `colossalai/shardformer/policies` folder and name it after your model. Please follow these protocols when writing your policy:

- You have to make a clear decision about what exactly you want to replace in the original PyTorch module:
    - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes.
    - Use `ModulePolicyDescription.param_replacement` to replace the module parameters.
    - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement `from_native_module` for the replacement.
    - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in `shardformer/modeling/<model-name>.py`**.
- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/<model-name>.py` file. Primitive modules are modules that are not composed of other modules; for example, the `torch.nn.Linear` module is a primitive module, while the `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not contain nested inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert a native PyTorch module to a `ParallelModule`, and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should only be implemented for modules whose weights are sharded. If you want to make your module compatible with `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting from `ParallelModule`, as done in `colossalai/shardformer/layer/normalization.py`.
- **Do not import any file in `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import errors**. For example, if a file in these folders accidentally imports the `transformers` library at the top of the file, then the user will have to install the `transformers` library even if they do not use this file. Any file in the `modeling` folder should only be imported by the policy file. A policy implementation should only be imported dynamically via the autopolicy or manually via the `ShardFormer` module.
- Try to keep import statements for third-party libraries such as `transformers` within the function body instead of the header section of the file, because we do not want to import the library when the policy is not used; a sketch follows this list.
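
A hypothetical skeleton illustrating the lazy-import protocol is shown below; the class name is illustrative and the `basepolicy` import path is an assumption:

```python
# the exact import path of Policy may differ by version
from colossalai.shardformer.policies.basepolicy import Policy


class MyBertPolicy(Policy):

    def module_policy(self):
        # import transformers inside the method body so that merely importing
        # this policy file does not require transformers to be installed
        from transformers.models.bert.modeling_bert import BertLayer
        ...
```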
- Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.

For example, to register a policy for the BERT model, we just add a key-value pair to the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it via `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly, because the policy file may import libraries (such as `transformers`) that we do not want to load when the policy is not used.

```python
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
    # ... entries for other models omitted ...
}
```

Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model.
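
A registration might look roughly like the following sketch; the `model_zoo.register` keyword names and the helper functions here are assumptions:

```python
# hypothetical registration in tests/kit/model_zoo; argument names may differ
model_zoo.register(
    name="transformers_mymodel",
    model_fn=lambda: MyModel(my_config),        # builds a small model instance
    data_gen_fn=data_gen_fn,                    # returns a dict of dummy inputs
    output_transform_fn=lambda output: output,  # normalizes outputs for comparison
    loss_fn=loss_fn,                            # computes a scalar loss for backward
)
```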
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
- Step 3. Execute your test
When you run tests locally, you should run both your newly added test file and the whole `shardformer` module test suite.