
[shardformer]: Feature/shardformer, add some docstring and readme (#3816)

* init shardformer code structure

* add implementation of the sharder (inject and replace)

* add implementation of replacing layers with colossal layers

* separate different layer policies, add some notes

* implement 1D and 2D slicers that can tell column from row

* fix bugs when slicing and injecting the model

* fix some bugs; add an inference test example

* add shared weight and a train example

* add train

* add docstring and readme

* add docstring for other files

* pre-commit
Branch: pull/4157/head
Authored by FoolPlayer 2 years ago, committed by Frank Lee
Commit 8cc11235c0
Changed files (14):

1. colossalai/nn/layer/parallel_1d/_operation.py (2 lines changed)
2. colossalai/nn/layer/parallel_1d/layers.py (9 lines changed)
3. colossalai/shardformer/README.md (177 lines added)
4. colossalai/shardformer/model/modeling_bert.py (16 lines changed)
5. colossalai/shardformer/policies/autopolicy.py (25 lines changed)
6. colossalai/shardformer/policies/basepolicy.py (128 lines changed)
7. colossalai/shardformer/policies/bert.py (125 lines changed)
8. colossalai/shardformer/shard/shardconfig.py (4 lines changed)
9. colossalai/shardformer/shard/sharder.py (187 lines changed)
10. colossalai/shardformer/shard/shardmodel.py (36 lines changed)
11. colossalai/shardformer/shard/slicer.py (113 lines changed)
12. colossalai/shardformer/test/config.py (6 lines changed)
13. colossalai/shardformer/test/test.py (87 lines changed)
14. colossalai/shardformer/utils/utils.py (36 lines changed)

colossalai/nn/layer/parallel_1d/_operation.py (2 lines changed)

@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 try:
@@ -72,6 +73,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         total_input = input
         grad_input = grad_output.matmul(weight)
+        grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
         total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

colossalai/nn/layer/parallel_1d/layers.py (9 lines changed)

@@ -469,7 +469,8 @@ class Linear1D_Col(ParallelLayer):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
-        self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
+        # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
+        self.out_features_per_partition = out_features
         # Parameters.
         # Initialize weight.
@@ -612,7 +613,8 @@ class Linear1D_Row(ParallelLayer):
             raise ValueError('cannot skip bias addition if bias is None')
         # Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
+        # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
+        self.input_size_per_partition = in_features
         # Parameters.
         # Initialize weight.
@@ -884,7 +886,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
         tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
         tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings_per_partition = num_embeddings
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

colossalai/shardformer/README.md (177 lines added)

@@ -0,0 +1,177 @@
## ShardFormer
### Intro
Make models from huggingface.co parallelizable and usable with Colossal-AI according to a custom policy.
### Quick start
1. Usage
- Use
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from transformers import BertForMaskedLM
# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# parallelize the huggingface model with ShardModel
# auto policy:
shardmodel = ShardModel(model).model
# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model
# do anything as normal
...
```
- Policy
If you want to parallelize the model in a custom way, just overwrite the policy class for the huggingface model.
You should:
1. Inherit the Policy class
2. Overwrite the argument_policy method
- In this method you need to list which layer classes you want to modify and the attributes and parameters of those layers.
3. Overwrite the inject_policy method [Optional]
- Do this if you need to modify the forward or backward process.
4. Overwrite or add the param recording functions
- These functions use a suffix to record the path of the weight or bias for the layer.
5. Overwrite the binding_policy method
More details can be found in shardformer/policies/basepolicy.py
``` python
from typing import Dict, List, Tuple

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument


class CustomPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        """
        Return a dict; each key is a layer class that will be modified and each value is an Argument
        object holding the attribute settings and param functions

        Args:
            model_config: The config of the transformer model
            shard_config: The config of the distributed model

        Return:
            Dict for the modify policy,
            {
                origin layer class1 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                origin layer class2 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                ...
            }
        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        Return the pair for the injected model

        Return:
            A tuple of (original model class, new shard model class)
        """
        return ()

    @staticmethod
    def binding_policy() -> Dict:
        """
        Return the dict for the binding model
        """
        return NotImplementedError

    @staticmethod
    def attn_in() -> List:
        """
        Attention qkv layer

        Returns:
            List[Layer]: List of layer objects, each describing a new layer
        """
        return NotImplementedError

    @staticmethod
    def attn_out() -> List:
        """
        Attention output projection layer

        Returns:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def mlp_in() -> List:
        """
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def mlp_out() -> List:
        """
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def embedding() -> List:
        """
        Partially slice the embedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def unembedding() -> List:
        """
        Partially slice the unembedding (LM head) layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        return NotImplementedError
```
2. Simple example
``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference
# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```

colossalai/shardformer/model/modeling_bert.py (16 lines changed)

@@ -1,12 +1,14 @@
+from typing import Any, Dict, List, Type
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-from typing import Any, Dict, List, Type
 from transformers import BertForMaskedLM
 from transformers.models.bert.modeling_bert import MaskedLMOutput

 class BertForMaskedLM_(BertForMaskedLM):

     def forward(
         self,
         input_ids=None,
@@ -23,7 +25,7 @@ class BertForMaskedLM_(BertForMaskedLM):
         return_dict=None,
         **kwargs,
     ):
-        print("[Inject OK] Injected forward method")
+        # print("[Inject OK] Injected forward method")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.bert(
@@ -46,9 +48,9 @@ class BertForMaskedLM_(BertForMaskedLM):
         masked_lm_loss = None

         # if input_ids is not None:
         #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
         if labels is not None:
             loss_fct = CrossEntropyLoss()    # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -60,4 +62,4 @@ class BertForMaskedLM_(BertForMaskedLM):
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )

colossalai/shardformer/policies/autopolicy.py (25 lines changed)

@@ -1,40 +1,47 @@
 import torch.nn as nn

 def build_policies():
-    """
+    r"""
     Build the policies for the model

     Return:
         The dict for the policies
     """
     auto_policy_dict = {}

     from transformers.models.bert.modeling_bert import BertForMaskedLM
     from .bert import BertForMaskedLMPolicy
     auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

     from transformers.models.bert.modeling_bert import BertForSequenceClassification
     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

     return auto_policy_dict

-def get_autopolicy(model:nn.Module):
-    """
+def get_autopolicy(model: nn.Module):
+    r"""
     Return the auto policy for the model

     Args:
-        model: The model to be used
+        model (:class:`nn.Module`): The model to get the auto policy

     Return:
-        The auto policy for the model
+        :class:`Policy`: The auto policy for the model
     """
     auto_policy_dict = build_policies()
     policy = auto_policy_dict.get(model.__class__, None)
     if policy is None:
-        raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
+        raise NotImplementedError(
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
+        )
     return policy

 # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
 # model = BertForPreTraining
 # policy = get_autopolicy(model)

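For quick reference, the auto-policy lookup above can be exercised on its own. The sketch below uses only names that appear in this commit and assumes transformers can download the checkpoint:

``` python
from transformers import BertForMaskedLM

from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# resolves the policy by the model's class; BertForMaskedLM maps to BertForMaskedLMPolicy
policy = get_autopolicy(model)
print(policy.__qualname__)
```

An unsupported model class raises the NotImplementedError shown above, which lists the currently supported classes.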
colossalai/shardformer/policies/basepolicy.py (128 lines changed)

@@ -1,28 +1,38 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Tuple, Type
 import torch
 import torch.nn as nn
-import colossalai.nn as col_nn
-from typing import Any, Dict, List, Type, Tuple, Callable
 from transformers import AutoConfig
-from dataclasses import dataclass, field
+import colossalai.nn as col_nn

 @dataclass
 class Argument:
-    attr_dict : Dict[str, Any]
-    param_funcs : List[Callable]
-    binding_layers : List[nn.Module] = field(default_factory=list)
+    r"""
+    The argument class for the policy
+
+    Args:
+        attr_dict (Dict[str, Any]): The dict for the param setting
+        param_funcs (:class:`List[Callable]`): The list for the param functions
+    """
+    attr_dict: Dict[str, Any]
+    param_funcs: List[Callable]

 @dataclass
 class Layer:
-    """
+    r"""
     The layer object for the policy

     Args:
-        weight: The weight name of the layer
-        bias: The bias name of the layer
-        replace_layer: The layer to replace the original layer
-        ignore: Whether to ignore this layer if it is not in the model
+        weight (str): The weight suffix of the layer
+        bias (str): The bias suffix of the layer
+        replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
+        ignore (bool): Whether to ignore this layer if it is not in the model
     """
     weight: str = None
     bias: str = None
@@ -32,45 +42,55 @@ class Layer:
 @dataclass
 class Col_Layer(Layer):
-    """
+    r"""
     Class for col shard layer in MegatronLM
+
+    Args:
+        gather_output (bool): Whether to gather the output of the layer
     """
     gather_output: bool = False

 @dataclass
 class Row_Layer(Layer):
-    """
+    r"""
     Class for col shard layer in MegatronLM
     """
     pass

 class Policy():
-    """
+    r"""
     The base class for all the policies
     For each different model, it should have a different policy class, like BertPolicy for Bert Model
     or OPTPolicy for OPT model.
     AutoPolicy:
-        shardformer already defined some policies for huggingface model, just set custom_policy = None
+        Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
         to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
         like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
         BertForSequenceClassification, etc., for each different Bert model we difine different policy class
-        and overwrite the method inject_policy
+        and overwrite the method like ``inject_policy`` to modify the forward and backward process.
     CustomPolicy:
+        If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
+        all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
+        class for the example.
     """

     @staticmethod
-    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
-        """
-        Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions
+    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
+        r"""
+        Return the dict for the modify policy, the key is the original layer class and the value is the
+        argument for the modify layer

         Args:
-            model_config: The config of transformer model
-            shard_setting: The config of distributed model
+            model_config (:class:`tansformer.Config`): The config of transformer model
+            shard_config (:class:`ShardConfig`): The config for sharding model

         Return:
             Dict for the modify policy,
+        ::
             {
                 origin layer class1 (nn.Module): Argument(
                     attr_dict = {
@@ -101,33 +121,51 @@ class Policy():
         """
         raise NotImplementedError

     @staticmethod
     def inject_policy() -> Tuple[nn.Module, nn.Module]:
-        """
+        r"""
         Return the dict for the inject model

         Return:
             The injected model, key is the original model and value is the new shardmodel
+        ::
+            (OrignModel, CustomModel)
+        in `CustomModel`, we can overwrite the forward and backward process
         """
         return ()

     @staticmethod
-    def attn_in() -> List:
-        """
+    def binding_policy() -> Dict:
+        r"""
+        Return the dict for the binding model
+
+        Return:
+            This method should return the binding relationship for some layers share the weight or bias,
+            the key and value is the suffix of the weight or bias of the model
+        ::
+            return {
+                "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+            }
+        """
+        return NotImplementedError
+
+    @staticmethod
+    def attn_in() -> List:
+        r"""
         Attention qkv layer
+        In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
+        ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
+        in ``Layer`` object can refer to the ``Layer`` class.

         Returns:
             List[Layer]: List of layer object, each layer is the new
         """
         return NotImplementedError

     @staticmethod
     def attn_out() -> List:
-        """
+        r"""
         Attention output projection layer

         Returns:
@@ -135,46 +173,40 @@ class Policy():
         """
         return NotImplementedError

     @staticmethod
     def mlp_in() -> List:
-        """
+        r"""
         h -> 4h mlp layer

         Returns:
             List[Layer]: List of layer object
         """
         return NotImplementedError

     @staticmethod
     def mlp_out() -> List:
-        """
+        r"""
         4h -> h mlp layer

         Returns:
             List[Layer]: List of layer object
         """
         return NotImplementedError

     @staticmethod
-    def embedding()->List:
-        """
+    def embedding() -> List:
+        r"""
         Partially slice the embedding layer
-        vocab_size->vocab_size//gpu_nums

         Return:
             List[Layer]: List of layer object
         """
         return NotImplementedError

     @staticmethod
-    def unembedding()->List:
-        """
+    def unembedding() -> List:
+        r"""
         Partially slice the embedding layer
-        vocab_size->vocab_size//gpu_nums

         Return:
             List[Layer]: List of layer object

colossalai/shardformer/policies/bert.py (125 lines changed)

@@ -1,56 +1,57 @@
-from typing import Dict, List, Tuple, Type, Any, Callable
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Tuple, Type
 import torch.nn as nn
-from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
+from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
 import colossalai.nn as col_nn
-from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
-from dataclasses import dataclass
+from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer

 class BertPolicy(Policy):

     @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
+    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
         return {
-            BertLayer: Argument(
-                attr_dict = {
-                    # 1. shard hidden size
-                    "attention.self.all_head_size": config.hidden_size // world_size,
-                    "crossattention.self.all_head_size": config.hidden_size // world_size,
-                    # 2. shard number of heads
-                    "attention.self.num_attention_heads": config.num_attention_heads // world_size,
-                    "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
-                },
-                param_funcs = [
-                    BertPolicy.attn_in,
-                    BertPolicy.attn_out,
-                    BertPolicy.mlp_in,
-                    BertPolicy.mlp_out
-                ]
-            ),
-            BertEmbeddings: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                    "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
-                },
-                param_funcs = [
-                    BertPolicy.embedding,
-                ],
-                binding_layers = [
-                    BertLMPredictionHead,
-                ]
-            ),
-            BertLMPredictionHead: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                },
-                param_funcs = [
-                    BertPolicy.unembedding,
-                ]
-            )
+            BertLayer:
+                Argument(
+                    attr_dict={
+                        # 1. shard hidden size
+                        "attention.self.all_head_size": config.hidden_size // world_size,
+                        "crossattention.self.all_head_size": config.hidden_size // world_size,
+                        # 2. shard number of heads
+                        "attention.self.num_attention_heads": config.num_attention_heads // world_size,
+                        "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
+                    },
+                    param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
+            BertEmbeddings:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                        "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
+                    },
+                    param_funcs=[
+                        BertPolicy.embedding,
+                    ]),
+            BertLMPredictionHead:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                    },
+                    param_funcs=[
+                        BertPolicy.unembedding,
+                    ])
         }
+
+    @staticmethod
+    def binding_policy() -> Dict:
+        return {
+            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+        }

     @staticmethod
@@ -89,9 +90,8 @@ class BertPolicy(Policy):
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
         ]

     @staticmethod
     def attn_out() -> List:
         return [
@@ -107,17 +107,17 @@ class BertPolicy(Policy):
                 ignore=True,
             ),
         ]

     @staticmethod
     def mlp_in() -> List:
         return [
             Col_Layer(
                 weight="intermediate.dense.weight",
                 bias="intermediate.dense.bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
         ]

     @staticmethod
     def mlp_out() -> List:
         return [
@@ -130,13 +130,11 @@ class BertPolicy(Policy):
     @staticmethod
     def embedding() -> List:
-        return [
-            Col_Layer(
-                weight="word_embeddings.weight",
-                replace_layer=col_nn.VocabParallelEmbedding1D,
-            )
-        ]
+        return [Col_Layer(
+            weight="word_embeddings.weight",
+            replace_layer=col_nn.VocabParallelEmbedding1D,
+        )]

     @staticmethod
     def unembedding() -> List:
         return [
@@ -148,16 +146,21 @@ class BertPolicy(Policy):
             )
         ]

 from transformers import BertForMaskedLM
 from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_

 class BertForMaskedLMPolicy(BertPolicy):

     @staticmethod
     def inject_policy() -> Tuple[nn.Module, nn.Module]:
         return (BertForMaskedLM, BertForMaskedLM_)

 class BertForSequenceClassificationPolicy(BertPolicy):

     @staticmethod
     def inject_policy() -> Dict:
         return {}

@@ -165,4 +168,4 @@ class BertForSequenceClassificationPolicy(BertPolicy):
 # model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 # _ = BertForMaskedLMPolicy(model)
 # print(isinstance(model,list(_.inject_policy().keys())[0]))

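To tie the BERT policy above back to the entry point, here is a hedged sketch of selecting it explicitly instead of relying on the auto policy. It mirrors test.py later in this commit and assumes the ShardModel signature from shardmodel.py (model, shard_config, custom_policy) plus a distributed launch (for example via colossalai run) that sets WORLD_SIZE:

``` python
import os

from transformers import BertForMaskedLM

import colossalai
from colossalai.shardformer.policies.bert import BertForMaskedLMPolicy
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device

colossalai.launch_from_torch(config='config.py')
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shard_config = ShardConfig(
    rank=int(str(get_current_device()).split(':')[-1]),
    world_size=int(os.environ['WORLD_SIZE']),
)
# pass the policy class explicitly; with custom_policy=None the auto policy is used instead
shardmodel = ShardModel(model, shard_config, custom_policy=BertForMaskedLMPolicy)
```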
colossalai/shardformer/shard/shardconfig.py (4 lines changed)

@@ -10,9 +10,9 @@ class ShardConfig:
     fp16: bool = True
     num_gpus: int = 2
     world_size: int = 2
-    backend="nccl"
+    backend = "nccl"
     verbose: str = 'simple'
     seed: int = None
     require_grad: bool = False
     master_addr: str = "127.0.0.1"
     master_port: int = 29500

colossalai/shardformer/shard/sharder.py (187 lines changed)

@@ -1,56 +1,59 @@
+import os
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable
-from .shardconfig import ShardConfig
-from dataclasses import dataclass
-from ..policies.basepolicy import Policy, Layer
-from ..policies.autopolicy import get_autopolicy
-from .slicer import Slicer
-from ..utils.utils import hasattr_, setattr_, getattr_
 import colossalai.nn as col_nn
 from colossalai.logging import get_dist_logger
-import os
+from ..policies.autopolicy import get_autopolicy
+from ..policies.basepolicy import Layer, Policy
+from ..utils.utils import getattr_, hasattr_, setattr_
+from .shardconfig import ShardConfig
+from .slicer import Slicer

 logger = get_dist_logger()

 class ModelSharder(object):
-    """
+    r"""
     Shard the original huggingface model according to the policy

     Args:
-        policy: The policy to shard the model
-        model: The model to shard
-        dist_setting: The setting of distributed model
+        policy (:class:`Policy`): The policy to shard the model
+        model (:class:`torch.Module`): The model to shard
+        shard_config: The setting of distributed model
     """

     def __init__(
         self,
         model: nn.Module,
         policy: Policy,
         shard_config: ShardConfig = None,    # TODO
     ) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.slicer = Slicer(shard_config)
         self.shard_config = shard_config
         self.model_config = self.model.config
-        self.binding_map = {}

     def shard(self) -> None:
         self.inject_model(self.model)
         self.replace_layer(self.model)
+        self.bind_layer(self.model)

     def inject_model(
         self,
         model: nn.Module,
     ) -> None:
-        """
+        r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model

         e.g.
+        ::
             BertForMaskedLM.forward -> BertForMaskedLM_.forward
         """
         inject_policy = self.policy.inject_policy()
@@ -64,49 +67,43 @@ class ModelSharder(object):
             setattr(
                 model.__class__,
                 key,
-                getattr(shard_model_cls,key),
+                getattr(shard_model_cls, key),
             )
         else:
             raise NotImplementedError(f"{model.__class__} is not implemented so far")

     def replace_layer(
         self,
         model: nn.Module,
     ) -> None:
-        """
+        r"""
         Replace the layer according to the policy, and replace the layer one by one

         Args:
-            layer: The layer to shard
+            model (:class:`torch.nn.Module`): The layer to shard
         """
         argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
         for argument_policy in argument_policies.items():
             origin_layer_cls = argument_policy[0]
             attr_dict = argument_policy[1].attr_dict
             param_funcs = argument_policy[1].param_funcs
-            binding_layers = argument_policy[1].binding_layers
-            # if binding_layer is not None:
-            #     self.binding_map[origin_layer_cls] = binding_layer
-            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)
+            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)

     def reverse_replace_layer(
         self,
         layer: nn.Module,
         origin_cls: nn.Module,
         attr_dict: Dict[str, Any],
         param_funcs: List[Callable],
-        binding_layers: List[nn.Module]
     ) -> None:
-        """
+        r"""
         Reverse the replace layer operation

         Args:
-            layer: The object of layer to shard
-            origin_cls: The origin layer class
-            attr_dict: The attribute dict to modify
-            policy_cls: The policy class
+            layer (:class:`torch.nn.Module`): The object of layer to shard
+            origin_cls (:class:`transformers.model`): The origin layer class
+            attr_dict (Dict): The attribute dict to modify
+            policy_cls (:class:`Policy`): The policy class
         """
         for name, child in layer.named_children():
             if child.__class__ == origin_cls:
@@ -115,25 +112,23 @@ class ModelSharder(object):
                     setattr_(child, k, v, ignore=True)
                 # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
                 # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
-                self.shard_one_layer(child, param_funcs, binding_layers)
+                self.shard_one_layer(child, param_funcs)
                 continue

-            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)
+            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
         return layer

     def shard_one_layer(
         self,
         org_layer: nn.Module,
         param_funcs: List[Callable],
-        binding_layers: List[nn.Module]
     ) -> None:
-        """
+        r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

         Args:
-            org_layer: The origin layer object to shard
-            param_funcs: The function list to get shard information in policy class
+            org_layer (:class:`torch.nn.Module`): The origin layer object to shard
+            param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
         """
         # print(org_layer)
@@ -148,7 +143,7 @@ class ModelSharder(object):
             ignore = policy_layer.ignore
             if policy_layer.__class__.__name__ == "Col_Layer":
                 gather_output = policy_layer.gather_output
-                print(gather_output)
+                # print(gather_output)

             if weight_attr is not None:
                 if hasattr_(org_layer, weight_attr):
@@ -172,67 +167,81 @@ class ModelSharder(object):
             # slice weight and bias
             weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
-            print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
-
-            # save the binding information
-            for binding_layer in binding_layers:
-                self.binding_map[binding_layer] = dict(weight=weight, bias=bias)
+            # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)

             # create new object to replace the origin layer
             if replace_layer_cls is not None:
                 # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
                 if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
                     if replace_layer_cls.__name__ == "Linear1D_Row":
-                        replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True)
+                        replace_layer = replace_layer_cls(weight.shape[1],
+                                                          weight.shape[0],
+                                                          bias=False if bias is None else True)
                     elif replace_layer_cls.__name__ == "Linear1D_Col":
-                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output)
+                        replace_layer = replace_layer_cls(weight.shape[0],
+                                                          weight.shape[1],
+                                                          bias=False if bias is None else True,
+                                                          gather_output=gather_output)
                     setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                     self.set_param(replace_layer, weight, bias)
                 elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
-                    replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
+                    replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
                                                      getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
                     setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                     self.set_param(replace_layer, weight, bias)
                 else:
-                    raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
+                    raise NotImplementedError(
+                        f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
             # do not replace the layer object, just replace the weight and bias
             else:
                 self.set_param(org_layer, layer_attr, weight, bias)

-    def set_param(
-        self,
-        layer: Any,
-        layer_attr: str = "",
-        weight: torch.Tensor = None,
-        bias: torch.Tensor = None
-    ) -> None:
-        """
+    def set_param(self,
+                  layer: Any,
+                  weight: torch.Tensor = None,
+                  bias: torch.Tensor = None,
+                  layer_attr: str = "") -> None:
+        r"""
         Reset the weight and bias of the layer object

         Args:
-            layer: The layer object
-            layer_attr: The attribute name of the layer
-            weight: The weight of the layer
-            bias: The bias of the layer
+            layer (:class:`torch.nn.Module`): The layer object
+            layer_attr (str): The attribute name of the layer
+            weight (:class:`torch.Tensor`): The weight of the layer
+            bias (:class:`torch.Tensor`): The bias of the layer
         """
         assert weight is not None or bias is not None
         if weight is not None:
-            setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight))
+            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
             self.set_layer_size(layer, layer_attr, weight.shape)
         if bias is not None:
-            setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias))
+            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))

     def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
-        """
+        r"""
         Set the layer attribute

         Args:
-            layer: The layer object
-            layer_attr: The attribute name of the layer
-            size: Torch.size
+            layer (:class:`torch.nn.Module`): The layer object
+            layer_attr (str): The attribute name of the layer
+            size (:class:`torch.Size`): The size of the tensor
        """
         # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
         attrs = ["out_features", "in_features"]
         for i, attr in enumerate(attrs):
             if hasattr_(layer, f"{layer_attr}.{attr}"):
                 setattr_(layer, f"{layer_attr}.{attr}", size[i])

+    def bind_layer(self, model: nn.Module) -> None:
+        r"""
+        Bind the layer according to the binding policy
+
+        Args:
+            model (:class:`torch.nn.Module`): The shard model
+        """
+        binding_map = self.policy.binding_policy()
+        for k, v in binding_map.items():
+            param = getattr_(model, k)
+            param = nn.Parameter(param)
+            setattr_(model, k, param)
+            setattr_(model, v, param)

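The bind_layer step above is what keeps tied weights tied after sharding: the suffix pairs returned by binding_policy are pointed at one shared nn.Parameter. Below is a standalone sketch of the same idea, using the dotted-path helpers from utils.py and the BERT suffix pair from this commit (the tie_weights name is illustrative, not part of the library):

``` python
import torch.nn as nn

from colossalai.shardformer.utils.utils import getattr_, setattr_


def tie_weights(model: nn.Module, binding_map: dict) -> None:
    # make the destination suffix share the exact Parameter object of the source suffix
    for src, dst in binding_map.items():
        param = nn.Parameter(getattr_(model, src))
        setattr_(model, src, param)
        setattr_(model, dst, param)


# the suffix pair below comes from BertPolicy.binding_policy()
# tie_weights(bert_model, {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"})
```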
colossalai/shardformer/shard/shardmodel.py (36 lines changed)

@@ -1,46 +1,48 @@
 import os
+from contextlib import suppress
+from dataclasses import dataclass
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import transformers
-import torch.distributed as dist
-from dataclasses import dataclass
-from contextlib import suppress
 from colossalai.tensor.d_tensor.layout import Layout

 from ..policies.basepolicy import Policy
-from .sharder import ModelSharder
 from .shardconfig import ShardConfig
+from .sharder import ModelSharder

 class ShardModel(object):
-    """
-    The class for sharding the huggingface model, self.model is the sharded model
+    r"""
+    The class for sharding the huggingface model, ''self.model'' is the sharded model
     Just creat a new ShardModel object to shard huggingface model

     Args:
-        model: the origin huggingface model
-        dist_config: the config for distribute information
-        custom_policy: the custom policy for sharding
+        model (:class:`torch.nn.Model`): the origin huggingface model
+        dist_config (:class:`ShardConfig`): the config for distribute information
+        custom_policy (:class:`Policy`): the custom policy for sharding
     """

     def __init__(
         self,
         model: nn.Module,
         shard_config: ShardConfig = None,    # TODO
         custom_policy: Policy = None,
     ) -> None:
         self.model = model
         self.shard_config = shard_config
         self.policy = custom_policy
         # self.layout=,    # TODO

-        sharder=ModelSharder(
+        sharder = ModelSharder(
             model=self.model,
             policy=self.policy,
             shard_config=self.shard_config,
         )
         sharder.shard()

     def set_environ(self) -> None:
         os.environ["TOKENIZERS_PARALLELISM"] = "true"
         os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
@@ -55,4 +57,4 @@ class ShardModel(object):
         torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

     def back_to_org() -> None:
         pass

colossalai/shardformer/shard/slicer.py (113 lines changed)

@@ -1,40 +1,40 @@
 import os
-from typing import Dict, Tuple
 from dataclasses import dataclass
+from typing import Dict, Tuple
 import torch
 import torch.distributed as dist

-from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
+from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
 from .shardconfig import ShardConfig

 dim_mapping = {Col_Layer: 1, Row_Layer: 0}

 class Slicer():

     def __init__(
         self,
         shardconfig: ShardConfig    #TODO
     ) -> None:
         self.shardconfig = shardconfig

     def slice_weight_bias(
         self,
         weight: torch.Tensor,
         bias: torch.Tensor,
         policy_layer_cls: Layer,
     ):
-        """
+        r"""
         Slice the weight and bias according to policy layer cls
-        Layer -> do nothing
-        Col_Layer -> slice the weight and bias along dim 1
-        Row_Layer -> slice the weight along dim 0 and do not slice bias
+        ``Layer`` -> do nothing
+        ``Col_Layer`` -> slice the weight and bias along dim 1
+        ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias

         Args:
-            weight: The weight of the layer
-            bias: The bias of the layer
-            policy_layer_class: The class represent how to slice the tensor
+            weight (:class:`torch.nn.Module`): The weight of the layer
+            bias: (:class:`torch.nn.Module`): The bias of the layer
+            policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
         """
         if policy_layer_cls == Layer:
             return weight, bias
@@ -46,42 +46,6 @@ class Slicer():
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
         return weight, bias

-    def slice_weight(
-        self,
-        weight: torch.Tensor,
-        policy_layer_cls: Layer,
-    ) -> torch.Tensor:
-        """
-        Slice the weight and bias according to the shardconfig
-
-        Args:
-            weight: The weight of the layer
-            bias: The bias of the layer
-            policy_layer_class: The class represent how to slice the tensor
-        """
-        if weight is not None:
-            dim = dim_mapping[policy_layer_cls]
-            weight = self.slice_tensor(weight, dim, False)
-        return weight
-
-    def slice_bias(
-        self,
-        bias: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Slice the bias according to the shardconfig
-
-        Args:
-            bias: The bias of the layer
-        """
-        assert bias is not None, "The bias is None"
-        if bias is not None:
-            bias = self.slice_tensor(bias, 1, True)
-        return bias

     def slice_tensor(
         self,
@@ -89,8 +53,13 @@ class Slicer():
         dim: int,
         is_bias: bool,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice tensor according to the config
+
+        Args:
+            tensor_in (:class:`torch.Tensor`): The tensor to slice
+            dim (int): The dimension to slice
+            is_bias (bool): Whether the tensor is bias
         """
         if tensor_in is None:
             return None
@@ -99,69 +68,75 @@ class Slicer():
         else:
             return self.slice_1d(tensor_in)

     def slice_2d(
         self,
         tensor: torch.Tensor,
         dim: int,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the 2D tensor

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+            dim (int): The dimension to slice
         """
-        assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor"
+        assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
         if dim == 0:
             return self.slice_row(tensor)
         elif dim == 1:
             return self.slice_col(tensor)

     def slice_1d(
         self,
         tensor: torch.Tensor,
-        dim: int = None,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the 1D tensor

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
         """
         delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[down_idx:up_idx]
+        return tensor[down_idx:up_idx].contiguous()

     def slice_col(
         self,
         tensor: torch.Tensor,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the tensor in column

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
         """
         delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[down_idx:up_idx,:]
+        return tensor[down_idx:up_idx, :].contiguous()

     def slice_row(
         self,
         tensor: torch.Tensor,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the tensor in column

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
         """
         delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[:,down_idx:up_idx]
+        return tensor[:, down_idx:up_idx].contiguous()

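To make the slicing arithmetic above concrete, here is a small self-contained sketch of the ceil-divide partitioning used by slice_col (rank and world_size are illustrative values rather than ones read from a launcher):

``` python
import torch


def slice_dim0(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # ceil-divide so every rank gets the same chunk size; the last chunk may come up short
    delta = (tensor.shape[0] + world_size - 1) // world_size
    start = rank * delta
    return tensor[start:start + delta, :].contiguous()


weight = torch.randn(10, 4)                        # a toy 10 x 4 weight
shard = slice_dim0(weight, rank=1, world_size=2)
print(shard.shape)                                 # torch.Size([5, 4])
```

In the Slicer above, slice_col indexes dim 0 of the tensor and slice_row indexes dim 1; dim_mapping selects which of the two runs for Col_Layer and Row_Layer.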
colossalai/shardformer/test/config.py (6 lines changed)

@@ -1,5 +1 @@
-parallel = dict(
-    data=1,
-    pipeline=1,
-    tensor=dict(size=2, mode='1d')
-)
+parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))

colossalai/shardformer/test/test.py (87 lines changed)

@@ -1,23 +1,51 @@
-from transformers import AutoTokenizer
-from transformers import BertForMaskedLM
+import argparse
+import inspect
+import os
+
+import torch
+import torch.nn as nn
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+
 import colossalai
-from colossalai.shardformer.shard.shardmodel import ShardModel
-from colossalai.utils import get_current_device, print_rank_0
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer.shard.shardconfig import ShardConfig
-import inspect
-import argparse
-import torch.nn as nn
-import os
+from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.utils import get_current_device, print_rank_0
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

 def get_args():
     parser = colossalai.get_default_parser()
+    parser.add_argument("--mode", type=str, default='inference')
     return parser.parse_args()

+def load_data():
+    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
+    # datasets=load_dataset("yelp_review_full")
+    tokenized_datasets = datasets.map(
+        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
+    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+    # tokenized_datasets=tokenized_datasets.rename_column("label","labels")
+    tokenized_datasets.set_format("torch")
+
+    train_dataset = tokenized_datasets["train"].select(range(500))
+    test_dataset = tokenized_datasets["test"].select(range(100))
+
+    datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
+    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
+    eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
+    return train_dataloader, eval_dataloader

 def inference(model: nn.Module):
-    # print(model)
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    print(model)
     token = "Hello, my dog is cute"
     inputs = tokenizer(token, return_tensors="pt")
     inputs.to("cuda")
@@ -25,13 +53,48 @@ def inference(model: nn.Module):
     outputs = model(**inputs)
     print(outputs)

+def train(model: nn.Module, num_epoch: int = 2):
+    train_dataloader, eval_dataloader = load_data()
+    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
+    progress_bar = tqdm(range((num_epoch) * len(train_dataloader)))
+    criterion = nn.CrossEntropyLoss()
+    model.to("cuda")
+    model.train()
+    for epoch in range(num_epoch):
+        progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
+        for batch in train_dataloader:
+            optimizer.zero_grad()
+            batch = {k: v.to('cuda') for k, v in batch.items()}
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            progress_bar.update(1)
+        train_loss = loss
+
+        loss = 0.0
+        for batch in eval_dataloader:
+            batch = {k: v.to('cuda') for k, v in batch.items()}
+            outputs = model(**batch)
+            # loss = outputs.loss
+            loss += outputs.loss.item()
+            # loss = criterion(outputs.logits, batch["input_ids"])
+        test_loss = loss / len(eval_dataloader)
+        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")

 if __name__ == "__main__":
     args = get_args()
     colossalai.launch_from_torch(config=args.config)
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     shard_config = ShardConfig(
-        rank = int(str(get_current_device()).split(':')[-1]),
-        world_size= int(os.environ['WORLD_SIZE']),
+        rank=int(str(get_current_device()).split(':')[-1]),
+        world_size=int(os.environ['WORLD_SIZE']),
     )
     shardmodel = ShardModel(model, shard_config)
-    inference(shardmodel.model)
+    if args.mode == "train":
+        train(shardmodel.model)
+    elif args.mode == "inference":
+        inference(shardmodel.model)

colossalai/shardformer/utils/utils.py (36 lines changed)

@@ -1,10 +1,10 @@
 def hasattr_(obj, attr: str):
-    """
+    r"""
     Check whether the object has the multi sublevel attr

     Args:
-        obj: The object to check
-        attr: The multi level attr to check
+        obj (object): The object to check
+        attr (str): The multi level attr to check
     """
     attrs = attr.split('.')
     for a in attrs:
@@ -14,15 +14,16 @@ def hasattr_(obj, attr: str):
             return False
     return True

-def setattr_(obj, attr: str, value, ignore: bool=False):
-    """
+def setattr_(obj, attr: str, value, ignore: bool = False):
+    r"""
     Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist

     Args:
-        obj: The object to set
-        attr: The multi level attr to set
-        value: The value to set
-        ignore: Whether to ignore when the attr doesn't exist
+        obj (object): The object to set
+        attr (str): The multi level attr to set
+        value (Any): The value to set
+        ignore (bool): Whether to ignore when the attr doesn't exist
     """
     attrs = attr.split('.')
@@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
             obj = getattr(obj, a)
         except AttributeError:
             if ignore:
                 return
             raise AttributeError(f"Object {obj} has no attribute {attr}")
     setattr(obj, attrs[-1], value)

-def getattr_(obj, attr: str, ignore: bool=None):
-    """
+def getattr_(obj, attr: str, ignore: bool = None):
+    r"""
     Get the object's multi sublevel attr

     Args:
-        obj: The object to set
-        attr: The multi level attr to set
-        ignore: Whether to ignore when the attr doesn't exist
+        obj (object): The object to set
+        attr (str): The multi level attr to set
+        ignore (bool): Whether to ignore when the attr doesn't exist
     """
     attrs = attr.split('.')
@@ -53,4 +55,4 @@ def getattr_(obj, attr: str, ignore: bool = None):
         if ignore:
             return None
         raise AttributeError(f"Object {obj} has no attribute {attr}")
     return obj

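As a closing illustration of the dotted-path helpers above, the sketch below walks a small hypothetical module tree (Toy and Block are made-up classes for the example; only hasattr_/getattr_/setattr_ come from this commit):

``` python
import torch.nn as nn

from colossalai.shardformer.utils.utils import getattr_, hasattr_, setattr_


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = Block()


model = Toy()
print(hasattr_(model, "encoder.attn.weight"))             # True: the path is resolved level by level
weight = getattr_(model, "encoder.attn.weight")           # same tensor as model.encoder.attn.weight
setattr_(model, "encoder.attn.bias", nn.Parameter(weight.new_zeros(8)))
print(getattr_(model, "encoder.missing", ignore=True))    # None instead of raising AttributeError
```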