mirror of https://github.com/hpcaitech/ColossalAI

[shardformer]: Feature/shardformer, add some docstring and readme (#3816)

* init shardformer code structure
* add implementation of sharder (inject and replace)
* add implementation of replacing layers with colossal layers
* separate different layer policies, add some notes
* implement 1d and 2d slicer, can tell col or row
* fix bug when slicing and injecting model
* fix some bugs; add inference test example
* add shared weight and train example
* add train
* add docstring and readme
* add docstring for other files
* pre-commit

branch: pull/4157/head
parent 8d68de767d
commit 8cc11235c0
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
@@ -72,6 +73,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        total_input = input
        grad_input = grad_output.matmul(weight)

        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
@@ -469,7 +469,8 @@ class Linear1D_Col(ParallelLayer):
        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
        # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
        self.out_features_per_partition = out_features

        # Parameters.
        # Initialize weight.
@@ -612,7 +613,8 @@ class Linear1D_Row(ParallelLayer):
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
        # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
        self.input_size_per_partition = in_features

        # Parameters.
        # Initialize weight.
@@ -884,7 +886,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
        self.num_embeddings_per_partition = num_embeddings
        self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
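The 2D view in the `LinearWithAsyncCommunication` hunk above collapses the batch and sequence dimensions so the weight gradient becomes a single GEMM. A standalone sketch of the same shape manipulation (all sizes here are assumed, for illustration only):

``` python
import torch

# assumed shapes: batch=4, seq=128, in_features=768, out_features=1024
total_input = torch.randn(4, 128, 768)     # layer input saved for backward
weight = torch.randn(1024, 768)            # (out_features, in_features)
grad_output = torch.randn(4, 128, 1024)    # upstream gradient

grad_input = grad_output.matmul(weight)    # (4, 128, 768)

# collapse batch and sequence dims to 2D, as the autograd function does
grad_output_2d = grad_output.contiguous().view(4 * 128, 1024)
total_input_2d = total_input.view(4 * 128, 768)
grad_weight = grad_output_2d.t().matmul(total_input_2d)    # (1024, 768), same shape as weight
```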
@@ -0,0 +1,177 @@
## ShardFormer

### Intro

Make models from huggingface.co parallelizable and usable with colossalai according to a custom policy.

### Quick start

1. Usage
- Use
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from transformers import BertForMaskedLM

# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# parallelize the huggingface model into a ShardModel
# auto policy:
shardmodel = ShardModel(model).model

# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, custom_policy=<POLICYCLASS>).model

# do anything as normal
...
```
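For a distributed run, `ShardModel` also takes a `ShardConfig` carrying the rank and world size. A minimal sketch, reusing `model` from above and mirroring the test script later in this commit (it assumes a launcher such as `colossalai run` has set `WORLD_SIZE` and the device):

``` python
import os

from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.utils import get_current_device

shard_config = ShardConfig(
    rank=int(str(get_current_device()).split(':')[-1]),
    world_size=int(os.environ['WORLD_SIZE']),
)
shardmodel = ShardModel(model, shard_config)
```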
- Policy

If you want to parallelize the model in a custom way, just overwrite the policy class for the huggingface model.

You should do:

1. Inherit the Policy class
2. Overwrite the argument_policy method
- In this method you need to list which layer classes you want to modify and the attributes and parameters in those layers, as in the skeleton below.
3. Overwrite the inject_policy method [Optional]
- If you need to modify the forward or backward process.
4. Overwrite or add the param recording functions
- These functions use a suffix to record the path of the weight or bias for the layer.
5. Overwrite binding

More details can be found in shardformer/policies/basepolicy.py
``` python
from typing import Any, Callable, Dict, List, Tuple

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument


class CustomPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        """
        Return a dict; the key is the layer class to be modified and the value is the Argument class
        with the param settings and param functions

        Args:
            model_config: The config of the transformer model
            shard_config: The config of the distributed model

        Return:
            Dict for the modify policy,
            {
                origin layer class1 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                origin layer class2 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                ...
            }

        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        Return the dict for the inject model

        Return:
            The injected model, key is the original model and value is the new shardmodel
        """
        return ()

    @staticmethod
    def binding_policy() -> Dict:
        """
        Return the dict for the binding model
        """
        raise NotImplementedError

    @staticmethod
    def attn_in() -> List:
        """
        Attention qkv layer

        Returns:
            List[Layer]: List of layer objects, each layer is the new layer to use
        """
        raise NotImplementedError

    @staticmethod
    def attn_out() -> List:
        """
        Attention output projection layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def mlp_in() -> List:
        """
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def mlp_out() -> List:
        """
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def embedding() -> List:
        """
        Partially slice the embedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError

    @staticmethod
    def unembedding() -> List:
        """
        Partially slice the unembedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError
```
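As a concrete illustration, here is what one overwritten param recording function can look like. The suffixes and replacement layer are taken from the `BertPolicy` included in this commit; the class name is hypothetical:

``` python
from typing import List

import colossalai.nn as col_nn
from colossalai.shardformer.policies.basepolicy import Col_Layer, Policy


class MyBertPolicy(Policy):    # hypothetical subclass, for illustration

    @staticmethod
    def mlp_in() -> List:
        # record the h -> 4h projection by its suffix and shard it column-wise
        return [
            Col_Layer(
                weight="intermediate.dense.weight",
                bias="intermediate.dense.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
        ]
```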
2. Simple example
``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference
# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```
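The `config.py` passed to `--config` can be as small as the example config shipped in this commit:

``` python
# config.py: 1D tensor parallelism of size 2, no data/pipeline parallelism
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
```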
@@ -1,12 +1,14 @@
from typing import Any, Dict, List, Type

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput


class BertForMaskedLM_(BertForMaskedLM):

    def forward(
        self,
        input_ids=None,
@@ -23,7 +25,7 @@ class BertForMaskedLM_(BertForMaskedLM):
        return_dict=None,
        **kwargs,
    ):
        # print("[Inject OK] Injected forward method")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
@@ -46,9 +48,9 @@ class BertForMaskedLM_(BertForMaskedLM):
        masked_lm_loss = None

        # if input_ids is not None:
        #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
        if labels is not None:
            loss_fct = CrossEntropyLoss()    # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
@@ -60,4 +62,4 @@ class BertForMaskedLM_(BertForMaskedLM):
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@@ -1,40 +1,47 @@
import torch.nn as nn


def build_policies():
    r"""
    Build the policies for the model

    Return:
        The dict for the policies
    """
    auto_policy_dict = {}

    from transformers.models.bert.modeling_bert import BertForMaskedLM

    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers.models.bert.modeling_bert import BertForSequenceClassification

    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    return auto_policy_dict


def get_autopolicy(model: nn.Module):
    r"""
    Return the auto policy for the model

    Args:
        model (:class:`nn.Module`): The model to get the auto policy

    Return:
        :class:`Policy`: The auto policy for the model
    """
    auto_policy_dict = build_policies()
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
        )
    return policy


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
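Supporting another architecture follows the same register-in-`build_policies` pattern. A hedged sketch (the GPT-2 policy module and class named here do not exist in this commit; they are assumptions):

``` python
    # inside build_policies(), after the BERT entries
    from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

    from .gpt2 import GPT2LMHeadModelPolicy    # hypothetical module and class
    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
```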
@@ -1,28 +1,38 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple, Type

import torch
import torch.nn as nn
from transformers import AutoConfig

import colossalai.nn as col_nn


@dataclass
class Argument:
    r"""
    The argument class for the policy

    Args:
        attr_dict (Dict[str, Any]): The dict for the param setting
        param_funcs (:class:`List[Callable]`): The list for the param functions
    """
    attr_dict: Dict[str, Any]
    param_funcs: List[Callable]


@dataclass
class Layer:
    r"""
    The layer object for the policy

    Args:
        weight (str): The weight suffix of the layer
        bias (str): The bias suffix of the layer
        replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
        ignore (bool): Whether to ignore this layer if it is not in the model
    """
    weight: str = None
    bias: str = None
@@ -32,45 +42,55 @@ class Layer:

@dataclass
class Col_Layer(Layer):
    r"""
    Class for col shard layer in MegatronLM

    Args:
        gather_output (bool): Whether to gather the output of the layer
    """
    gather_output: bool = False


@dataclass
class Row_Layer(Layer):
    r"""
    Class for row shard layer in MegatronLM
    """
    pass


class Policy():
    r"""
    The base class for all the policies.
    For each different model, it should have a different policy class, like BertPolicy for Bert Model
    or OPTPolicy for OPT model.

    AutoPolicy:
        Shardformer already defines some policies for huggingface models; just set ``custom_policy`` = None
        to use the auto policy. In shardformer autopolicy, we define a base policy for one type of model,
        like BertPolicy, and for each different Bert model in huggingface, like BertForMaskedLM,
        BertForSequenceClassification, etc., we define a different policy class
        and overwrite methods like ``inject_policy`` to modify the forward and backward process.

    CustomPolicy:
        If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
        all the methods in the ``Policy`` class. You can refer to any policy we defined, like the ``BertPolicy``
        class, for an example.

    """

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        r"""
        Return the dict for the modify policy; the key is the original layer class and the value is the
        argument for the modified layer

        Args:
            model_config (:class:`transformers.Config`): The config of the transformer model
            shard_config (:class:`ShardConfig`): The config for sharding the model

        Return:
            Dict for the modify policy,
            ::
                {
                    origin layer class1 (nn.Module): Argument(
                        attr_dict = {
@@ -101,33 +121,51 @@ class Policy():

        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        r"""
        Return the dict for the inject model

        Return:
            The injected model, key is the original model and value is the new shardmodel
            ::
                (OriginModel, CustomModel)
            in ``CustomModel``, we can overwrite the forward and backward process
        """
        return ()

    @staticmethod
    def binding_policy() -> Dict:
        r"""
        Return the dict for the binding model

        Return:
            This method should return the binding relationship for some layers that share the weight or bias;
            the key and value are the suffixes of the weight or bias of the model
            ::
                return {
                    "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
                }
        """
        return NotImplementedError

    @staticmethod
    def attn_in() -> List:
        r"""
        Attention qkv layer
        In this kind of method, we should return a list of ``Layer`` objects; each ``Layer`` object should be
        ``Layer`` for no slicing, ``Col_Layer`` for col slicing, or ``Row_Layer`` for row slicing. The parameters
        of a ``Layer`` object can be found in the ``Layer`` class.

        Returns:
            List[Layer]: List of layer objects, each layer is the new layer to use
        """
        return NotImplementedError

    @staticmethod
    def attn_out() -> List:
        r"""
        Attention output projection layer

        Returns:
@@ -135,46 +173,40 @@ class Policy():
        """
        return NotImplementedError

    @staticmethod
    def mlp_in() -> List:
        r"""
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def mlp_out() -> List:
        r"""
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def embedding() -> List:
        r"""
        Partially slice the embedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        return NotImplementedError

    @staticmethod
    def unembedding() -> List:
        r"""
        Partially slice the unembedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
@@ -1,56 +1,59 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead

import colossalai.nn as col_nn

from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer


class BertPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        return {
            BertLayer:
                Argument(
                    attr_dict={
                        # 1. shard hidden size
                        "attention.self.all_head_size": config.hidden_size // world_size,
                        "crossattention.self.all_head_size": config.hidden_size // world_size,
                        # 2. shard number of heads
                        "attention.self.num_attention_heads": config.num_attention_heads // world_size,
                        "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
                    },
                    param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
            BertEmbeddings:
                Argument(
                    attr_dict={
                        # 1. shard vocab size
                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                        # 2. add the size of the sliced embedding layer excluding the last slice
                        "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
                    },
                    param_funcs=[
                        BertPolicy.embedding,
                    ]),
            BertLMPredictionHead:
                Argument(
                    attr_dict={
                        # 1. shard vocab size
                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                        # 2. add the size of the sliced embedding layer excluding the last slice
                    },
                    param_funcs=[
                        BertPolicy.unembedding,
                    ])
        }

    @staticmethod
    def binding_policy() -> Dict:
        return {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }

    @staticmethod
@@ -89,9 +90,8 @@ class BertPolicy(Policy):
                replace_layer=col_nn.Linear1D_Col,
                ignore=True,
            ),
        ]

    @staticmethod
    def attn_out() -> List:
        return [
@@ -107,17 +107,17 @@ class BertPolicy(Policy):
                ignore=True,
            ),
        ]

    @staticmethod
    def mlp_in() -> List:
        return [
            Col_Layer(
                weight="intermediate.dense.weight",
                bias="intermediate.dense.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
        ]

    @staticmethod
    def mlp_out() -> List:
        return [
@@ -130,13 +130,11 @@ class BertPolicy(Policy):

    @staticmethod
    def embedding() -> List:
        return [Col_Layer(
            weight="word_embeddings.weight",
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]

    @staticmethod
    def unembedding() -> List:
        return [
@@ -148,16 +146,21 @@ class BertPolicy(Policy):
            )
        ]


from transformers import BertForMaskedLM

from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        return (BertForMaskedLM, BertForMaskedLM_)


class BertForSequenceClassificationPolicy(BertPolicy):

    @staticmethod
    def inject_policy() -> Dict:
        return {}

@@ -165,4 +168,4 @@ class BertForSequenceClassificationPolicy(BertPolicy):

# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# _ = BertForMaskedLMPolicy(model)
# print(isinstance(model, list(_.inject_policy().keys())[0]))
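The `binding_policy` above re-ties the input embedding to the LM-head decoder after sharding. A quick sanity sketch of that relationship on the unsharded huggingface model (assuming the default weight tying in `transformers`):

``` python
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# these two suffixes are exactly the key/value pair returned by binding_policy
assert model.bert.embeddings.word_embeddings.weight is model.cls.predictions.decoder.weight
```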
@@ -10,9 +10,9 @@ class ShardConfig:
    fp16: bool = True
    num_gpus: int = 2
    world_size: int = 2
    backend = "nccl"
    verbose: str = 'simple'
    seed: int = None
    require_grad: bool = False
    master_addr: str = "127.0.0.1"
    master_port: int = 29500
@@ -1,56 +1,59 @@
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Layer, Policy
from ..utils.utils import getattr_, hasattr_, setattr_
from .shardconfig import ShardConfig
from .slicer import Slicer

logger = get_dist_logger()


class ModelSharder(object):
    r"""
    Shard the original huggingface model according to the policy

    Args:
        policy (:class:`Policy`): The policy to shard the model
        model (:class:`torch.nn.Module`): The model to shard
        shard_config: The setting of the distributed model
    """

    def __init__(
            self,
            model: nn.Module,
            policy: Policy,
            shard_config: ShardConfig = None,    # TODO
    ) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model) if policy is None else policy
        self.slicer = Slicer(shard_config)
        self.shard_config = shard_config
        self.model_config = self.model.config
        self.binding_map = {}

    def shard(self) -> None:
        self.inject_model(self.model)
        self.replace_layer(self.model)
        self.bind_layer(self.model)

    def inject_model(
            self,
            model: nn.Module,
    ) -> None:
        r"""
        Replace the model with the policy-defined model
        Mainly modify the forward and backward to fit the distributed model

        e.g.
        ::
            BertForMaskedLM.forward -> BertForMaskedLM_.forward
        """
        inject_policy = self.policy.inject_policy()
@@ -64,49 +67,43 @@ class ModelSharder(object):
                setattr(
                    model.__class__,
                    key,
                    getattr(shard_model_cls, key),
                )
            else:
                raise NotImplementedError(f"{model.__class__} is not implemented so far")

    def replace_layer(
            self,
            model: nn.Module,
    ) -> None:
        r"""
        Replace the layers according to the policy, one by one

        Args:
            model (:class:`torch.nn.Module`): The layer to shard
        """
        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
        for argument_policy in argument_policies.items():
            origin_layer_cls = argument_policy[0]
            attr_dict = argument_policy[1].attr_dict
            param_funcs = argument_policy[1].param_funcs
            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)

    def reverse_replace_layer(
            self,
            layer: nn.Module,
            origin_cls: nn.Module,
            attr_dict: Dict[str, Any],
            param_funcs: List[Callable],
    ) -> None:
        r"""
        Reverse the replace layer operation

        Args:
            layer (:class:`torch.nn.Module`): The object of layer to shard
            origin_cls (:class:`transformers.model`): The origin layer class
            attr_dict (Dict): The attribute dict to modify
            param_funcs (:class:`List[typing.Callable]`): The function list to get shard information from the policy class
        """
        for name, child in layer.named_children():
            if child.__class__ == origin_cls:
@@ -115,25 +112,23 @@ class ModelSharder(object):
                    setattr_(child, k, v, ignore=True)
                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
                self.shard_one_layer(child, param_funcs)
                continue

            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
        return layer

    def shard_one_layer(
            self,
            org_layer: nn.Module,
            param_funcs: List[Callable],
    ) -> None:
        r"""
        Shard one layer according to the policy; the layer should be the same class as a key in the policy's argument_policy return dict

        Args:
            org_layer (:class:`torch.nn.Module`): The origin layer object to shard
            param_funcs (:class:`List[typing.Callable]`): The function list to get shard information from the policy class

        """
        # print(org_layer)
@@ -148,7 +143,7 @@ class ModelSharder(object):
            ignore = policy_layer.ignore
            if policy_layer.__class__.__name__ == "Col_Layer":
                gather_output = policy_layer.gather_output
                # print(gather_output)

            if weight_attr is not None:
                if hasattr_(org_layer, weight_attr):
@@ -172,67 +167,81 @@ class ModelSharder(object):

            # slice weight and bias
            weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
            # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)

            # create a new object to replace the origin layer
            if replace_layer_cls is not None:
                # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
                if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
                    if replace_layer_cls.__name__ == "Linear1D_Row":
                        replace_layer = replace_layer_cls(weight.shape[1],
                                                          weight.shape[0],
                                                          bias=False if bias is None else True)
                    elif replace_layer_cls.__name__ == "Linear1D_Col":
                        replace_layer = replace_layer_cls(weight.shape[0],
                                                          weight.shape[1],
                                                          bias=False if bias is None else True,
                                                          gather_output=gather_output)
                    setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                    self.set_param(replace_layer, weight, bias)
                elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
                    replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
                                                      getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
                    setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                    self.set_param(replace_layer, weight, bias)
                else:
                    raise NotImplementedError(
                        f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
            # do not replace the layer object, just replace the weight and bias
            else:
                self.set_param(org_layer, layer_attr, weight, bias)

    def set_param(self,
                  layer: Any,
                  weight: torch.Tensor = None,
                  bias: torch.Tensor = None,
                  layer_attr: str = "") -> None:
        r"""
        Reset the weight and bias of the layer object

        Args:
            layer (:class:`torch.nn.Module`): The layer object
            layer_attr (str): The attribute name of the layer
            weight (:class:`torch.Tensor`): The weight of the layer
            bias (:class:`torch.Tensor`): The bias of the layer
        """
        assert weight is not None or bias is not None
        if weight is not None:
            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
            self.set_layer_size(layer, layer_attr, weight.shape)
        if bias is not None:
            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))

    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
        r"""
        Set the layer attribute

        Args:
            layer (:class:`torch.nn.Module`): The layer object
            layer_attr (str): The attribute name of the layer
            size (:class:`torch.Size`): The size of the tensor
        """
        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
        attrs = ["out_features", "in_features"]
        for i, attr in enumerate(attrs):
            if hasattr_(layer, f"{layer_attr}.{attr}"):
                setattr_(layer, f"{layer_attr}.{attr}", size[i])

    def bind_layer(self, model: nn.Module) -> None:
        r"""
        Bind the layers according to the binding policy

        Args:
            model (:class:`torch.nn.Module`): The shard model
        """
        binding_map = self.policy.binding_policy()
        for k, v in binding_map.items():
            param = getattr_(model, k)
            param = nn.Parameter(param)
            setattr_(model, k, param)
            setattr_(model, v, param)
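Condensed, `shard_one_layer` slices a parameter, builds a replacement module from the sliced shapes, and re-registers the sliced tensors as parameters via `set_param`. A minimal sketch with a plain `nn.Linear` standing in for `col_nn.Linear1D_Col` (ranks and sizes are assumed):

``` python
import torch.nn as nn

org = nn.Linear(8, 8)
world_size, rank = 2, 0    # assumed 2-GPU run, looking at rank 0

# column-style slice of the (out_features, in_features) weight for this rank
weight = org.weight.detach()[:8 // world_size, :]
bias = org.bias.detach()[:8 // world_size]

# rebuild from the sliced shapes, then re-set the parameters, as set_param does
replacement = nn.Linear(weight.shape[1], weight.shape[0])
replacement.weight = nn.Parameter(weight.contiguous())
replacement.bias = nn.Parameter(bias.contiguous())
```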
@@ -1,46 +1,48 @@
import os
from contextlib import suppress
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.nn as nn
import transformers

from colossalai.tensor.d_tensor.layout import Layout

from ..policies.basepolicy import Policy
from .shardconfig import ShardConfig
from .sharder import ModelSharder


class ShardModel(object):
    r"""
    The class for sharding the huggingface model, ``self.model`` is the sharded model
    Just create a new ShardModel object to shard a huggingface model

    Args:
        model (:class:`torch.nn.Module`): the origin huggingface model
        dist_config (:class:`ShardConfig`): the config for distribute information
        custom_policy (:class:`Policy`): the custom policy for sharding
    """

    def __init__(
            self,
            model: nn.Module,
            shard_config: ShardConfig = None,    # TODO
            custom_policy: Policy = None,
    ) -> None:
        self.model = model
        self.shard_config = shard_config
        self.policy = custom_policy
        # self.layout=, # TODO

        sharder = ModelSharder(
            model=self.model,
            policy=self.policy,
            shard_config=self.shard_config,
        )
        sharder.shard()

    def set_environ(self) -> None:
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
@@ -55,4 +57,4 @@ class ShardModel(object):
        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    def back_to_org() -> None:
        pass
@@ -1,40 +1,40 @@
import os
from dataclasses import dataclass
from typing import Dict, Tuple

import torch
import torch.distributed as dist

from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shardconfig import ShardConfig

dim_mapping = {Col_Layer: 1, Row_Layer: 0}


class Slicer():

    def __init__(
            self,
            shardconfig: ShardConfig    # TODO
    ) -> None:
        self.shardconfig = shardconfig

    def slice_weight_bias(
        self,
        weight: torch.Tensor,
        bias: torch.Tensor,
        policy_layer_cls: Layer,
    ):
        r"""
        Slice the weight and bias according to the policy layer cls
        ``Layer`` -> do nothing
        ``Col_Layer`` -> slice the weight and bias along dim 1
        ``Row_Layer`` -> slice the weight along dim 0 and do not slice the bias

        Args:
            weight (:class:`torch.Tensor`): The weight of the layer
            bias (:class:`torch.Tensor`): The bias of the layer
            policy_layer_cls (:class:`Policy`): The class that represents how to slice the tensor
        """
        if policy_layer_cls == Layer:
            return weight, bias
@@ -46,42 +46,6 @@ class Slicer():
        else:
            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
        return weight, bias

    def slice_weight(
        self,
        weight: torch.Tensor,
        policy_layer_cls: Layer,
    ) -> torch.Tensor:
        """
        Slice the weight according to the shardconfig

        Args:
            weight: The weight of the layer
            policy_layer_cls: The class that represents how to slice the tensor
        """
        if weight is not None:
            dim = dim_mapping[policy_layer_cls]
            weight = self.slice_tensor(weight, dim, False)
        return weight

    def slice_bias(
        self,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        """
        Slice the bias according to the shardconfig

        Args:
            bias: The bias of the layer
        """
        assert bias is not None, "The bias is None"
        if bias is not None:
            bias = self.slice_tensor(bias, 1, True)
        return bias

    def slice_tensor(
        self,
@@ -89,8 +53,13 @@ class Slicer():
        dim: int,
        is_bias: bool,
    ) -> torch.Tensor:
        r"""
        Slice tensor according to the config

        Args:
            tensor_in (:class:`torch.Tensor`): The tensor to slice
            dim (int): The dimension to slice
            is_bias (bool): Whether the tensor is a bias
        """
        if tensor_in is None:
            return None
@@ -99,69 +68,75 @@ class Slicer():
        else:
            return self.slice_1d(tensor_in)

    def slice_2d(
        self,
        tensor: torch.Tensor,
        dim: int,
    ) -> torch.Tensor:
        r"""
        Slice the 2D tensor

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice
            dim (int): The dimension to slice
        """
        assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
        if dim == 0:
            return self.slice_row(tensor)
        elif dim == 1:
            return self.slice_col(tensor)

    def slice_1d(
        self,
        tensor: torch.Tensor,
        dim: int = None,
    ) -> torch.Tensor:
        r"""
        Slice the 1D tensor

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[down_idx:up_idx].contiguous()

    def slice_col(
        self,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Slice the tensor in column

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice

        Returns:
            :class:`torch.Tensor`: The sliced tensor

        """
        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[down_idx:up_idx, :].contiguous()

    def slice_row(
        self,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Slice the tensor in row

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[:, down_idx:up_idx].contiguous()
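The ceiling division above gives every rank (except possibly the last) the same `delta` rows. A standalone sketch of the indexing (a world size of 2 is assumed):

``` python
import torch

tensor = torch.arange(10)    # e.g. a bias of length 10
world_size, rank = 2, 1      # assumed 2-GPU run, looking at rank 1

delta = (tensor.shape[0] + world_size - 1) // world_size    # 5 rows per rank
down_idx = rank * delta
up_idx = down_idx + delta
print(tensor[down_idx:up_idx])    # tensor([5, 6, 7, 8, 9])
```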
@@ -1,5 +1 @@
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
@@ -1,23 +1,51 @@
import argparse
import inspect
import os

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def get_args():
    parser = colossalai.get_default_parser()
    parser.add_argument("--mode", type=str, default='inference')
    return parser.parse_args()


def load_data():
    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
    # datasets = load_dataset("yelp_review_full")
    tokenized_datasets = datasets.map(
        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    # tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    train_dataset = tokenized_datasets["train"].select(range(500))
    test_dataset = tokenized_datasets["test"].select(range(100))

    datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
    eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
    return train_dataloader, eval_dataloader


def inference(model: nn.Module):
    print(model)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    token = "Hello, my dog is cute"
    inputs = tokenizer(token, return_tensors="pt")
    inputs.to("cuda")
@@ -25,13 +53,48 @@ def inference(model: nn.Module):
    outputs = model(**inputs)
    print(outputs)


def train(model: nn.Module, num_epoch: int = 2):
    train_dataloader, eval_dataloader = load_data()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    progress_bar = tqdm(range(num_epoch * len(train_dataloader)))
    criterion = nn.CrossEntropyLoss()
    model.to("cuda")
    model.train()
    for epoch in range(num_epoch):
        progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")

        for batch in train_dataloader:
            optimizer.zero_grad()
            batch = {k: v.to('cuda') for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            progress_bar.update(1)
            train_loss = loss

        loss = 0.0
        for batch in eval_dataloader:
            batch = {k: v.to('cuda') for k, v in batch.items()}
            outputs = model(**batch)
            # loss = outputs.loss
            loss += outputs.loss.item()
            # loss = criterion(outputs.logits, batch["input_ids"])
        test_loss = loss / len(eval_dataloader)
        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f}")


if __name__ == "__main__":
    args = get_args()
    colossalai.launch_from_torch(config=args.config)
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    shard_config = ShardConfig(
        rank=int(str(get_current_device()).split(':')[-1]),
        world_size=int(os.environ['WORLD_SIZE']),
    )
    shardmodel = ShardModel(model, shard_config)
    if args.mode == "train":
        train(shardmodel.model)
    elif args.mode == "inference":
        inference(shardmodel.model)
@@ -1,10 +1,10 @@
def hasattr_(obj, attr: str):
    r"""
    Check whether the object has the multi sublevel attr

    Args:
        obj (object): The object to check
        attr (str): The multi level attr to check
    """
    attrs = attr.split('.')
    for a in attrs:
@@ -14,15 +14,16 @@ def hasattr_(obj, attr: str):
            return False
    return True


def setattr_(obj, attr: str, value, ignore: bool = False):
    r"""
    Set the object's multi sublevel attr to value; if ignore, ignore when it doesn't exist

    Args:
        obj (object): The object to set
        attr (str): The multi level attr to set
        value (Any): The value to set
        ignore (bool): Whether to ignore when the attr doesn't exist
    """

    attrs = attr.split('.')
@@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
            obj = getattr(obj, a)
        except AttributeError:
            if ignore:
                return
            raise AttributeError(f"Object {obj} has no attribute {attr}")
    setattr(obj, attrs[-1], value)


def getattr_(obj, attr: str, ignore: bool = None):
    r"""
    Get the object's multi sublevel attr

    Args:
        obj (object): The object to get
        attr (str): The multi level attr to get
        ignore (bool): Whether to ignore when the attr doesn't exist
    """

    attrs = attr.split('.')
@@ -53,4 +55,4 @@ def getattr_(obj, attr: str, ignore: bool = None):
        if ignore:
            return None
        raise AttributeError(f"Object {obj} has no attribute {attr}")
    return obj
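These helpers walk dotted attribute paths, which is how the sharder addresses layers by suffix. A short usage sketch against a stand-in module (the module names are illustrative only):

``` python
import torch.nn as nn

model = nn.Sequential()
model.add_module("encoder", nn.Sequential())
model.encoder.add_module("dense", nn.Linear(4, 4))

print(hasattr_(model, "encoder.dense.weight"))               # True
w = getattr_(model, "encoder.dense.weight")                  # the Linear's weight
setattr_(model, "encoder.dense.bias", nn.Parameter(w.new_zeros(4)))
print(getattr_(model, "encoder.missing", ignore=True))       # None instead of raising
```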