mirror of https://github.com/hpcaitech/ColossalAI
* init shardformer code structure
* add implement of sharder (inject and replace)
* add implement of replace layer to colossal layer
* separate different layer policy, add some notion
* implement 1d and 2d slicer, can tell col or row
* fix bug when slicing and inject model
* fix some bug; add inference test example
* add share weight and train example
* add train
* add docstring and readme
* add docstring for other files
* pre-commit
FoolPlayer authored 2 years ago, committed by Frank Lee
14 changed files with 612 additions and 339 deletions
@@ -0,0 +1,177 @@
## ShardFormer

### Intro

Make any model from huggingface.co parallelizable and usable with ColossalAI according to a custom policy.

### Quick start

1. Usage

- Use
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from transformers import BertForMaskedLM

# create a huggingface model as usual
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# parallelize the huggingface model via ShardModel
# auto policy:
shardmodel = ShardModel(model).model

# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model

# do anything as normal
...
```
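
Note that ShardModel slices parameters across devices, so the snippet above is meant to run inside an initialized distributed context, e.g. launched with `colossalai run` as shown at the end of this README.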

- Policy

If you want to parallelize the model in a custom way, just override the policy class for the huggingface model.

You should:

1. Inherit the Policy class
2. Override the argument_policy method
   - In this method you list which layer classes you want to modify, along with the attributes and parameters of those layers.
3. Override the inject_policy method [Optional]
   - Do this if you need to modify the forward or backward pass.
4. Override or add the parameter recording functions
   - These functions use a suffix to record the path to the weight or bias of each layer.
5. Override the binding_policy method

More details can be found in shardformer/policies/basepolicy.py. The skeleton below lists the methods to implement; a concrete sketch follows it.

``` python
from typing import Dict, List, Tuple

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument


class CustomPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        """
        Return a dict whose keys are the layer classes to be modified and whose
        values are Argument objects holding the attribute settings and parameter
        functions for those layers

        Args:
            model_config: The config of the transformer model
            shard_config: The config of the distributed model

        Return:
            Dict describing the modification policy,
            {
                origin layer class1 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                origin layer class2 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                ...
            }
        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        Return the mapping used to inject a new model class

        Return:
            A tuple whose first element is the original layer class and whose
            second element is the new sharded layer class
        """
        raise NotImplementedError

    @staticmethod
    def binding_policy() -> Dict:
        """
        Return the dict describing weight bindings (shared weights)
        """
        raise NotImplementedError

    @staticmethod
    def attn_in() -> List:
        """
        Attention qkv layer

        Returns:
            List[Layer]: List of layer objects, one per sliced parameter
        """
        raise NotImplementedError

    @staticmethod
    def attn_out() -> List:
        """
        Attention output projection layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def mlp_in() -> List:
        """
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def mlp_out() -> List:
        """
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def embedding() -> List:
        """
        Partially slice the embedding layer:
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def unembedding() -> List:
        """
        Partially slice the unembedding (output projection) layer:
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError
```
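
As a concrete illustration, here is a minimal sketch of a custom policy for a BERT-like model. It follows the skeleton above, but the attribute names, suffix paths, and the assumed keyword arguments of Argument/Col_Layer/Row_Layer are illustrative guesses, not the shipped BertPolicy:

``` python
# A sketch only: attr_dict keys, weight/bias suffixes and constructor keywords
# below are assumptions based on the skeleton above, not the shipped policy.
from typing import Dict, List

import torch.nn as nn
from transformers.models.bert.modeling_bert import BertLayer

from colossalai.shardformer.policies.basepolicy import Argument, Col_Layer, Policy, Row_Layer


class MyBertPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        # shrink the per-device attention attributes on every BertLayer and
        # register the functions that locate the parameters to slice
        return {
            BertLayer: Argument(
                attr_dict={
                    "attention.self.num_attention_heads": model_config.num_attention_heads // shard_config,
                    "attention.self.all_head_size": model_config.hidden_size // shard_config,
                },
                param_funcs=[MyBertPolicy.attn_in, MyBertPolicy.attn_out],
            ),
        }

    @staticmethod
    def attn_in() -> List:
        # qkv projections are column-sliced; weight/bias are suffix paths
        # relative to the layer class listed in argument_policy
        return [
            Col_Layer(weight="attention.self.query.weight", bias="attention.self.query.bias"),
            Col_Layer(weight="attention.self.key.weight", bias="attention.self.key.bias"),
            Col_Layer(weight="attention.self.value.weight", bias="attention.self.value.bias"),
        ]

    @staticmethod
    def attn_out() -> List:
        # the output projection is row-sliced to match the column slicing above
        return [
            Row_Layer(weight="attention.output.dense.weight", bias="attention.output.dense.bias"),
        ]

    @staticmethod
    def binding_policy() -> Dict:
        # tie the sliced word embedding to the MLM decoder head (shared weight)
        return {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
```

With such a class in place, `shardmodel = ShardModel(model, MyBertPolicy).model` would shard the attention with the column/row split described above.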

2. Simple example

``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference

# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```
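
For reference, the example config.py sets `data=1` (no data parallelism), `pipeline=1` (no pipeline parallelism) and `tensor=dict(size=2, mode='1d')` (2-way 1D tensor parallelism, matching `--nproc_per_node 2` above). This commit also collapses it onto a single line: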
``` diff
@@ -1,5 +1 @@
-parallel = dict(
-    data=1,
-    pipeline=1,
-    tensor=dict(size=2, mode='1d')
-)
+parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
```