mirror of https://github.com/hpcaitech/ColossalAI
update README (#3909)
parent 21a3915c98
commit 6370a935f6
@@ -55,30 +55,37 @@ If you wanna parallel the model in a custom way, just overwrite the policy class
You should do:

1. Inherit Policy class
2. Overwrite argument_policy method
    - In this method, you need to list which layer classes you want to modify, along with the attributes and parameters in those layers.
3. Overwrite inject_policy method (Optional)
    - If you need to modify the forward or backward process.
4. Overwrite or add the param recording functions
2. Overwrite `argument_policy` method
    - In this method, you need to list which layer classes you want to modify, along with the attributes and parameters in those layers. Shardformer will replace every layer belonging to the classes you specify.
    - `attr_dict` is a dict containing all the attributes that need to be modified in this layer.
    - `param_funcs` is a list of functions that return the paths of the weight and bias in the layer.
3. Overwrite `inject_policy` method (Optional)
    - Shardformer will inject the model according to this method. If you need to modify the forward or backward process (like the distributed cross-entropy loss in Bert), you need to overwrite this method.
4. Overwrite or add the param functions
    - These functions use a suffix to record the path of the weight or bias for the layer.
5. Overwrite binding
    - The return value is a list of `Col_Layer` or `Row_Layer` objects, which mean slicing along columns and rows respectively.
5. Overwrite `binding_policy` (Optional)
    - Overwrite this to specify which weights Shardformer will bind between layers, such as the embedding and unembedding layers.
    - This function returns a dict whose keys and values are the suffixes of the weights that need to be bound.
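
The binding step described in the list above can be pictured with a small, library-free sketch. It assumes that "binding" means making two dotted weight paths refer to the same object. The `Module` class and the helpers `get_by_suffix`, `set_by_suffix` and `bind_weights` are hypothetical stand-ins for illustration, not Shardformer APIs.

```python
from functools import reduce


class Module:
    """Tiny stand-in for a layer container (hypothetical, not torch.nn.Module)."""


def get_by_suffix(root, path):
    """Follow a dotted path such as 'embeddings.weight' starting from root."""
    return reduce(getattr, path.split("."), root)


def set_by_suffix(root, path, value):
    """Assign value to the attribute named by the last component of the path."""
    *parents, leaf = path.split(".")
    setattr(reduce(getattr, parents, root), leaf, value)


def bind_weights(model, binding):
    """For each key/value pair, make the value path share the key path's object."""
    for source_path, target_path in binding.items():
        set_by_suffix(model, target_path, get_by_suffix(model, source_path))


# Toy model: an embedding and a decoder that start with separate weights.
model = Module()
model.embeddings = Module()
model.embeddings.weight = [[0.1, 0.2], [0.3, 0.4]]
model.decoder = Module()
model.decoder.weight = [[9.0, 9.0], [9.0, 9.0]]

# Same shape as the dict a binding policy returns: suffix -> suffix.
bind_weights(model, {"embeddings.weight": "decoder.weight"})
```

After the call, `model.decoder.weight` and `model.embeddings.weight` are the same object, so an update through one path is visible through the other.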

More details can be found in shardformer/policies/basepolicy.py
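
To make "slice along col and row" concrete, here is a minimal sketch on plain nested lists. It only illustrates the general idea of giving each rank one contiguous block of a 2D weight. The function names `shard_rows` and `shard_cols` are hypothetical, and which tensor axis `Col_Layer` and `Row_Layer` actually split in Shardformer is not stated here.

```python
def shard_rows(matrix, rank, world_size):
    """Return the contiguous block of rows owned by `rank`."""
    rows = len(matrix)
    assert rows % world_size == 0, "row count must divide evenly across ranks"
    step = rows // world_size
    return matrix[rank * step:(rank + 1) * step]


def shard_cols(matrix, rank, world_size):
    """Return the contiguous block of columns owned by `rank`."""
    cols = len(matrix[0])
    assert cols % world_size == 0, "column count must divide evenly across ranks"
    step = cols // world_size
    return [row[rank * step:(rank + 1) * step] for row in matrix]


# A 2 x 4 toy weight split across two ranks.
weight = [[1, 2, 3, 4],
          [5, 6, 7, 8]]

row_shards = [shard_rows(weight, rank, 2) for rank in range(2)]
col_shards = [shard_cols(weight, rank, 2) for rank in range(2)]
```

With two ranks, row slicing gives each rank one full row, while column slicing gives each rank both rows but only half of the columns.
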
``` python
# Needed for the type hints used below
from typing import Dict, Tuple

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument

class CustomPolicy(Policy):
    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
        """
        Return a dict: the key is the layer to be modified and the value is the Argument class with the param settings and param functions
    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        r"""
        Return the dict for the modify policy: the key is the original layer class and the value is the
        argument for the modified layer

        Args:
            model_config: The config of transformer model
            shard_setting: The config of distributed model
            model_config (:class:`transformer.Config`): The config of transformer model
            shard_config (:class:`ShardConfig`): The config for sharding model

        Return:
            Dict for the modify policy,
            ::
                {
                    origin layer class1 (nn.Module): Argument(
                        attr_dict = {

@@ -112,18 +119,29 @@ CustomPolicy(Policy):

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        r"""
        Return the (original, replacement) model pair used for injection

        Return:
            The injected model: the first item is the original model and the second is the new shard model
            ::
                (OriginModel, CustomModel)
            in `CustomModel`, we can overwrite the forward and backward process
        """
        return ()

    @staticmethod
    def binding_policy() -> Dict:
        """
        r"""
        Return the dict for the binding model

        Return:
            This method should return the binding relationship for layers that share a weight or bias;
            the key and value are the suffixes of the weight or bias in the model
            ::
                return {
                    "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
                }
        """
        raise NotImplementedError
