@ -55,30 +55,37 @@ If you wanna parallel the model in a custom way, just overwrite the policy class
You should do:
You should do:
1. Inherit Policy class
1. Inherit Policy class
2. Overwrite argument_policy method
2. Overwrite `argument_policy` method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers.
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified.
3. Overwrite inject_policy method (Optional)
- `attr_dict` is dict contains all the attributes need to be modified in this layer.
- If you need to modify the forward or backward progress.
- `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer.
4. Overwrite or add the param recording functions
3. Overwrite `inject_policy` method (Optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
- These functions use a suffix to record the path of weight or bias for the layer.
5. Overwrite binding
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.
More details can be found in shardformer/policies/basepolicy.py
More details can be found in shardformer/policies/basepolicy.py
``` python
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument