mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] init shardformer code structure (#3731)
* init shardformer code structure
* add implementation of sharder (inject and replace)
* add implementation of replacing layers with colossal layers
* separate the policies for different layers, add some notes
* implement 1d and 2d slicers that can tell column from row
* fix bugs when slicing and injecting the model
* fix some bugs; add an inference test example

pull/3943/head
parent a98e16ed07
commit 6a69b44dfc
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Dict, List, Type

from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput


class BertForMaskedLM_(BertForMaskedLM):

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        print("[Inject OK] Injected forward method")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None

        # if input_ids is not None:
        #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
        if labels is not None:
            loss_fct = CrossEntropyLoss()    # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
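The class above is never instantiated directly: ModelSharder.inject_model (later in this diff) copies its methods onto the original BertForMaskedLM class object, so existing instances pick up the new forward. A minimal, self-contained sketch of that class-level method swap (hypothetical toy classes, not part of this PR):

class Original:
    def forward(self):
        return "original"

class Replacement:
    def forward(self):
        return "injected"

m = Original()
# swap the method on the class, exactly as inject_model does with setattr/getattr
setattr(Original, "forward", getattr(Replacement, "forward"))
print(m.forward())    # "injected" -- existing instances see the swapped method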
@@ -0,0 +1,41 @@
import torch.nn as nn


def build_policies():
    """
    Build the policies for the model

    Return:
        The dict for the policies
    """
    auto_policy_dict = {}

    from transformers.models.bert.modeling_bert import BertForMaskedLM
    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers.models.bert.modeling_bert import BertForSequenceClassification
    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    return auto_policy_dict


def get_autopolicy(model: nn.Module):
    """
    Return the auto policy for the model

    Args:
        model: The model to get the auto policy for

    Return:
        The auto policy for the model
    """
    auto_policy_dict = build_policies()
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
    return policy


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
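A quick sketch of the lookup (the module path colossalai.shardformer.policies.autopolicy is inferred from the relative imports in this diff; constructing from a bare config avoids downloading weights):

from transformers import BertConfig, BertForMaskedLM
from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForMaskedLM(BertConfig())
policy = get_autopolicy(model)    # -> BertForMaskedLMPolicy, keyed on model.__class__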
@@ -0,0 +1,182 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

import torch
import torch.nn as nn
import colossalai.nn as col_nn
from typing import Any, Dict, List, Type, Tuple, Callable
from transformers import AutoConfig
from dataclasses import dataclass, field


@dataclass
class Argument:
    attr_dict: Dict[str, Any]
    param_funcs: List[Callable]
    binding_layers: List[nn.Module] = field(default_factory=list)


@dataclass
class Layer:
    """
    The layer object for the policy

    Args:
        weight: The weight name of the layer
        bias: The bias name of the layer
        replace_layer: The layer to replace the original layer
        ignore: Whether to ignore this layer if it is not in the model
    """
    weight: str = None
    bias: str = None
    replace_layer: Any = None
    ignore: bool = False


@dataclass
class Col_Layer(Layer):
    """
    Class for column-sharded layers in MegatronLM
    """
    gather_output: bool = False


@dataclass
class Row_Layer(Layer):
    """
    Class for row-sharded layers in MegatronLM
    """
    pass


class Policy():
    """
    The base class for all the policies.
    Each model type should have its own policy class, like BertPolicy for the Bert model
    or OPTPolicy for the OPT model.

    AutoPolicy:
        Shardformer already defines some policies for huggingface models; just set
        custom_policy = None to use the auto policy. In the shardformer auto policy, we
        define a base policy for one model type, like BertPolicy, and for each different
        Bert model in huggingface, like BertForMaskedLM, BertForSequenceClassification,
        etc., we define a separate policy class and overwrite the method inject_policy.

    CustomPolicy:
    """
    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        """
        Return a dict; the key is the layer class to be modified and the value is an
        Argument object carrying the attribute settings and the parameter functions

        Args:
            model_config: The config of the transformer model
            shard_config: The config of the distributed model

        Return:
            Dict for the modify policy,
            {
                origin layer class1 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                origin layer class2 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                ...
            }

        """
        raise NotImplementedError


    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        Return the pair of classes for model injection

        Return:
            A tuple (original model class, new shard model class)
        """
        return ()


    @staticmethod
    def attn_in() -> List:
        """
        Attention qkv layer

        Returns:
            List[Layer]: List of layer objects describing the attention input projections
        """
        raise NotImplementedError


    @staticmethod
    def attn_out() -> List:
        """
        Attention output projection layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError


    @staticmethod
    def mlp_in() -> List:
        """
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError


    @staticmethod
    def mlp_out() -> List:
        """
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError


    @staticmethod
    def embedding() -> List:
        """
        Partially slice the embedding layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError


    @staticmethod
    def unembedding() -> List:
        """
        Partially slice the unembedding (LM head) layer
        vocab_size -> vocab_size // gpu_nums

        Return:
            List[Layer]: List of layer objects
        """
        raise NotImplementedError
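To illustrate the contract above, a minimal hypothetical custom policy; MyModelBlock and the dotted attribute names are placeholders, not part of this PR:

class MyModelPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
        # each matched MyModelBlock gets its head count rescaled, then its
        # projections sharded by the param_funcs below
        return {
            MyModelBlock: Argument(
                attr_dict={"attn.num_heads": model_config.num_heads // world_size},
                param_funcs=[MyModelPolicy.attn_in, MyModelPolicy.attn_out],
            )
        }

    @staticmethod
    def attn_in() -> List:
        return [Col_Layer(weight="attn.qkv.weight", bias="attn.qkv.bias",
                          replace_layer=col_nn.Linear1D_Col)]

    @staticmethod
    def attn_out() -> List:
        return [Row_Layer(weight="attn.proj.weight", bias="attn.proj.bias",
                          replace_layer=col_nn.Linear1D_Row)]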
@@ -0,0 +1,168 @@
from typing import Dict, List, Tuple, Type, Any, Callable
import torch.nn as nn
from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
import colossalai.nn as col_nn
from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
from dataclasses import dataclass


class BertPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        return {
            BertLayer: Argument(
                attr_dict={
                    # 1. shard hidden size
                    "attention.self.all_head_size": config.hidden_size // world_size,
                    "crossattention.self.all_head_size": config.hidden_size // world_size,
                    # 2. shard number of heads
                    "attention.self.num_attention_heads": config.num_attention_heads // world_size,
                    "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
                },
                param_funcs=[
                    BertPolicy.attn_in,
                    BertPolicy.attn_out,
                    BertPolicy.mlp_in,
                    BertPolicy.mlp_out
                ]
            ),
            BertEmbeddings: Argument(
                attr_dict={
                    # 1. shard vocab size
                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                    # 2. the per-rank size of the sliced embedding layer (ceil division; the last slice may be smaller)
                    "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
                },
                param_funcs=[
                    BertPolicy.embedding,
                ],
                binding_layers=[
                    BertLMPredictionHead,
                ]
            ),
            BertLMPredictionHead: Argument(
                attr_dict={
                    # 1. shard vocab size
                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
                },
                param_funcs=[
                    BertPolicy.unembedding,
                ]
            )
        }

    @staticmethod
    def attn_in() -> List:
        return [
            Col_Layer(
                weight="attention.self.query.weight",
                bias="attention.self.query.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                weight="attention.self.key.weight",
                bias="attention.self.key.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                weight="attention.self.value.weight",
                bias="attention.self.value.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                weight="crossattention.self.query.weight",
                bias="crossattention.self.query.bias",
                replace_layer=col_nn.Linear1D_Col,
                ignore=True,
            ),
            Col_Layer(
                weight="crossattention.self.key.weight",
                bias="crossattention.self.key.bias",
                replace_layer=col_nn.Linear1D_Col,
                ignore=True,
            ),
            Col_Layer(
                weight="crossattention.self.value.weight",
                bias="crossattention.self.value.bias",
                replace_layer=col_nn.Linear1D_Col,
                ignore=True,
            ),
        ]

    @staticmethod
    def attn_out() -> List:
        return [
            Row_Layer(
                weight="attention.output.dense.weight",
                bias="attention.output.dense.bias",
                replace_layer=col_nn.Linear1D_Row,
            ),
            Row_Layer(
                weight="crossattention.output.dense.weight",
                bias="crossattention.output.dense.bias",
                replace_layer=col_nn.Linear1D_Row,
                ignore=True,
            ),
        ]

    @staticmethod
    def mlp_in() -> List:
        return [
            Col_Layer(
                weight="intermediate.dense.weight",
                bias="intermediate.dense.bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
        ]

    @staticmethod
    def mlp_out() -> List:
        return [
            Row_Layer(
                weight="output.dense.weight",
                bias="output.dense.bias",
                replace_layer=col_nn.Linear1D_Row,
            ),
        ]

    @staticmethod
    def embedding() -> List:
        return [
            Col_Layer(
                weight="word_embeddings.weight",
                replace_layer=col_nn.VocabParallelEmbedding1D,
            )
        ]

    @staticmethod
    def unembedding() -> List:
        return [
            Col_Layer(
                weight="decoder.weight",
                bias="decoder.bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            )
        ]


from transformers import BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


class BertForMaskedLMPolicy(BertPolicy):

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        return (BertForMaskedLM, BertForMaskedLM_)


class BertForSequenceClassificationPolicy(BertPolicy):

    @staticmethod
    def inject_policy() -> Dict:
        return {}


# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# _ = BertForMaskedLMPolicy(model)
# print(isinstance(model, list(_.inject_policy().keys())[0]))
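A worked example of the attr_dict arithmetic in BertPolicy.argument_policy for bert-base defaults (hidden_size=768, num_attention_heads=12, vocab_size=30522) across two ranks:

from transformers import BertConfig

config = BertConfig()    # bert-base defaults
world_size = 2
print(config.hidden_size // world_size)                       # 384 -> per-rank all_head_size
print(config.num_attention_heads // world_size)               # 6   -> per-rank attention heads
print((config.vocab_size + world_size - 1) // world_size)     # 15261 -> per-rank vocab slice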
@@ -0,0 +1,18 @@
from dataclasses import dataclass


@dataclass
class ShardConfig:
    """
    The config for sharding the huggingface model, used for testing
    """
    rank: int
    fp16: bool = True
    num_gpus: int = 2
    world_size: int = 2
    backend: str = "nccl"
    verbose: str = 'simple'
    seed: int = None
    require_grad: bool = False
    master_addr: str = "127.0.0.1"
    master_port: int = 29500
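A minimal sketch of constructing the config by hand (values illustrative; in a real run the rank comes from the launcher, as in the test script later in this diff):

config = ShardConfig(rank=0, world_size=2)
print(config.backend, config.master_port)    # nccl 29500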
@@ -0,0 +1,238 @@
import torch
import torch.nn as nn
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable
from .shardconfig import ShardConfig
from dataclasses import dataclass
from ..policies.basepolicy import Policy, Layer
from ..policies.autopolicy import get_autopolicy
from .slicer import Slicer
from ..utils.utils import hasattr_, setattr_, getattr_
import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger
import os


logger = get_dist_logger()

class ModelSharder(object):
    """
    Shard the original huggingface model according to the policy

    Args:
        model: The model to shard
        policy: The policy to shard the model with
        shard_config: The setting of the distributed model
    """
    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        shard_config: ShardConfig = None,    # TODO
    ) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model) if policy is None else policy
        self.slicer = Slicer(shard_config)
        self.shard_config = shard_config
        self.model_config = self.model.config
        self.binding_map = {}


    def shard(self) -> None:
        self.inject_model(self.model)
        self.replace_layer(self.model)


    def inject_model(
        self,
        model: nn.Module,
    ) -> None:
        """
        Replace the model's methods with those of the policy-defined model class.
        Mainly modifies the forward and backward passes to fit the distributed model

        e.g.
            BertForMaskedLM.forward -> BertForMaskedLM_.forward
        """
        inject_policy = self.policy.inject_policy()

        org_model_cls = inject_policy[0]
        shard_model_cls = inject_policy[1]

        if model.__class__ == org_model_cls:
            for key in shard_model_cls.__dict__.keys():
                if hasattr(model.__class__, key):
                    setattr(
                        model.__class__,
                        key,
                        getattr(shard_model_cls, key),
                    )
        else:
            raise NotImplementedError(f"{model.__class__} is not implemented so far")


    def replace_layer(
        self,
        model: nn.Module,
    ) -> None:
        """
        Replace the layers according to the policy, one by one

        Args:
            model: The model whose layers are to be sharded
        """
        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
        for origin_layer_cls, argument in argument_policies.items():
            attr_dict = argument.attr_dict
            param_funcs = argument.param_funcs
            binding_layers = argument.binding_layers
            # if binding_layer is not None:
            #     self.binding_map[origin_layer_cls] = binding_layer
            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)


    def reverse_replace_layer(
        self,
        layer: nn.Module,
        origin_cls: nn.Module,
        attr_dict: Dict[str, Any],
        param_funcs: List[Callable],
        binding_layers: List[nn.Module]
    ) -> None:
        """
        Recursively traverse the submodules and shard every layer that matches origin_cls

        Args:
            layer: The layer object to shard
            origin_cls: The original layer class
            attr_dict: The attribute dict to modify
            param_funcs: The functions that describe which parameters to shard
        """
        for name, child in layer.named_children():
            if child.__class__ == origin_cls:
                # replac_layer = child
                for k, v in attr_dict.items():
                    setattr_(child, k, v, ignore=True)
                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
                self.shard_one_layer(child, param_funcs, binding_layers)
                continue

            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)


    def shard_one_layer(
        self,
        org_layer: nn.Module,
        param_funcs: List[Callable],
        binding_layers: List[nn.Module]
    ) -> None:
        """
        Shard one layer according to the policy; the layer should be the same class as a key in the policy's argument_policy return dict

        Args:
            org_layer: The original layer object to shard
            param_funcs: The function list that yields shard information from the policy class

        """
        # print(org_layer)
        for func in param_funcs:
            policy_layers = func()
            for policy_layer in policy_layers:
                weight = None
                bias = None
                weight_attr = policy_layer.weight
                bias_attr = policy_layer.bias
                replace_layer_cls = policy_layer.replace_layer
                ignore = policy_layer.ignore
                if policy_layer.__class__.__name__ == "Col_Layer":
                    gather_output = policy_layer.gather_output
                    print(gather_output)

                if weight_attr is not None:
                    if hasattr_(org_layer, weight_attr):
                        weight = getattr_(org_layer, weight_attr)
                    elif not ignore:
                        raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")

                if bias_attr is not None:
                    if hasattr_(org_layer, bias_attr):
                        bias = getattr_(org_layer, bias_attr)
                    elif not ignore:
                        raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")

                # the layer doesn't have the attribute in the policy, and ignore is true
                if weight is None and bias is None and ignore:
                    continue

                # set the sliced weight and bias to the new nn_col layer
                assert weight is not None or bias is not None
                layer_attr = (weight_attr or bias_attr).rsplit(".", 1)[0]

                # slice weight and bias
                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
                print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
                # save the binding information
                for binding_layer in binding_layers:
                    self.binding_map[binding_layer] = dict(weight=weight, bias=bias)

                # create a new object to replace the original layer
                if replace_layer_cls is not None:
                    # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
                    if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
                        if replace_layer_cls.__name__ == "Linear1D_Row":
                            replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=bias is not None)
                        elif replace_layer_cls.__name__ == "Linear1D_Col":
                            replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=bias is not None, gather_output=gather_output)
                        setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                        self.set_param(replace_layer, weight=weight, bias=bias)
                    elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
                        setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
                        self.set_param(replace_layer, weight=weight, bias=bias)
                    else:
                        raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
                # do not replace the layer object, just replace the weight and bias
                else:
                    self.set_param(org_layer, layer_attr, weight, bias)


    def set_param(
        self,
        layer: Any,
        layer_attr: str = "",
        weight: torch.Tensor = None,
        bias: torch.Tensor = None
    ) -> None:
        """
        Reset the weight and bias of the layer object

        Args:
            layer: The layer object
            layer_attr: The attribute name of the layer
            weight: The weight of the layer
            bias: The bias of the layer
        """
        assert weight is not None or bias is not None
        if weight is not None:
            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight))
            self.set_layer_size(layer, layer_attr, weight.shape)
        if bias is not None:
            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias))


    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
        """
        Set the layer attribute

        Args:
            layer: The layer object
            layer_attr: The attribute name of the layer
            size: The torch.Size to record on the layer
        """
        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
        attrs = ["out_features", "in_features"]
        for i, attr in enumerate(attrs):
            if hasattr_(layer, f"{layer_attr}.{attr}"):
                setattr_(layer, f"{layer_attr}.{attr}", size[i])
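A worked example of the replacement sizing in shard_one_layer, for bert-base's query projection (768 x 768) with world_size=2; the sharder then builds replace_layer_cls(weight.shape[0], weight.shape[1], ...) and copies the sliced tensors in via set_param:

import torch

full_weight = torch.empty(768, 768)    # nn.Linear stores (out_features, in_features)
sliced = full_weight[:768 // 2, :]     # the column-parallel slice kept by rank 0
print(sliced.shape)                    # torch.Size([384, 768])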
@@ -0,0 +1,58 @@
import os
import torch
import torch.nn as nn
import transformers
import torch.distributed as dist
from dataclasses import dataclass
from contextlib import suppress

from colossalai.tensor.d_tensor.layout import Layout
from ..policies.basepolicy import Policy
from .sharder import ModelSharder
from .shardconfig import ShardConfig


class ShardModel(object):
    """
    The class for sharding the huggingface model; self.model is the sharded model.
    Just create a new ShardModel object to shard a huggingface model

    Args:
        model: the original huggingface model
        shard_config: the config for distribution information
        custom_policy: the custom policy for sharding
    """
    def __init__(
        self,
        model: nn.Module,
        shard_config: ShardConfig = None,    # TODO
        custom_policy: Policy = None,
    ) -> None:
        self.model = model
        self.shard_config = shard_config
        self.policy = custom_policy
        # self.layout=, # TODO

        sharder = ModelSharder(
            model=self.model,
            policy=self.policy,
            shard_config=self.shard_config,
        )
        sharder.shard()


    def set_environ(self) -> None:
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
        os.environ["MASTER_ADDR"] = str(self.shard_config.master_addr)
        os.environ["MASTER_PORT"] = str(self.shard_config.master_port)
        os.environ["WORLD_SIZE"] = str(self.shard_config.num_gpus)
        os.environ["RANK"] = str(self.shard_config.rank)
        os.environ["LOCAL_RANK"] = str(self.shard_config.rank)
        if not dist.is_initialized():
            dist.init_process_group(backend=self.shard_config.backend)

        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    def back_to_org(self) -> None:
        pass
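A sketch of the intended usage with the auto policy (assumes a two-rank distributed context is already initialized; the hardcoded rank is illustrative, see the test script later in this diff):

from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
config = ShardConfig(rank=0, world_size=2)
shardmodel = ShardModel(model, shard_config=config, custom_policy=None)    # None -> auto policy
# shardmodel.model is the same object, with its sub-layers replaced in place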
@@ -0,0 +1,167 @@
import os
from typing import Dict, Tuple
from dataclasses import dataclass

import torch
import torch.distributed as dist
from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
from .shardconfig import ShardConfig


dim_mapping = {Col_Layer: 1, Row_Layer: 0}

class Slicer():

    def __init__(
        self,
        shardconfig: ShardConfig    # TODO
    ) -> None:
        self.shardconfig = shardconfig


    def slice_weight_bias(
        self,
        weight: torch.Tensor,
        bias: torch.Tensor,
        policy_layer_cls: Layer,
    ):
        """
        Slice the weight and bias according to the policy layer class
        Layer -> do nothing
        Col_Layer -> slice the weight along the output dimension and slice the bias
        Row_Layer -> slice the weight along the input dimension and do not slice the bias

        Args:
            weight: The weight of the layer
            bias: The bias of the layer
            policy_layer_cls: The class that represents how to slice the tensor
        """
        if policy_layer_cls == Layer:
            return weight, bias
        elif policy_layer_cls == Col_Layer:
            weight = self.slice_tensor(weight, 1, False)
            bias = self.slice_tensor(bias, 0, True)
        elif policy_layer_cls == Row_Layer:
            weight = self.slice_tensor(weight, 0, False)
        else:
            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
        return weight, bias


    def slice_weight(
        self,
        weight: torch.Tensor,
        policy_layer_cls: Layer,
    ) -> torch.Tensor:
        """
        Slice the weight according to the shardconfig

        Args:
            weight: The weight of the layer
            policy_layer_cls: The class that represents how to slice the tensor
        """
        if weight is not None:
            dim = dim_mapping[policy_layer_cls]
            weight = self.slice_tensor(weight, dim, False)
        return weight


    def slice_bias(
        self,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        """
        Slice the bias according to the shardconfig

        Args:
            bias: The bias of the layer
        """
        assert bias is not None, "The bias is None"
        if bias is not None:
            bias = self.slice_tensor(bias, 1, True)
        return bias


    def slice_tensor(
        self,
        tensor_in: torch.Tensor,
        dim: int,
        is_bias: bool,
    ) -> torch.Tensor:
        """
        Slice a tensor according to the config
        """
        if tensor_in is None:
            return None
        if not is_bias:
            return self.slice_2d(tensor_in, dim)
        else:
            return self.slice_1d(tensor_in)


    def slice_2d(
        self,
        tensor: torch.Tensor,
        dim: int,
    ) -> torch.Tensor:
        """
        Slice a 2D tensor

        Args:
            tensor: The tensor to slice
            dim: 0 for row-parallel slicing, 1 for column-parallel slicing
        """
        assert dim in [0, 1], f"Only support 2D tensors, but got dim {dim}"
        if dim == 0:
            return self.slice_row(tensor)
        elif dim == 1:
            return self.slice_col(tensor)


    def slice_1d(
        self,
        tensor: torch.Tensor,
        dim: int = None,
    ) -> torch.Tensor:
        """
        Slice a 1D tensor

        Args:
            tensor: The tensor to slice
        """
        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[down_idx:up_idx]

    def slice_col(
        self,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        """
        Slice the tensor along dim 0, used for column-parallel layers

        Args:
            tensor: The tensor to slice
        """
        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[down_idx:up_idx, :]


    def slice_row(
        self,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        """
        Slice the tensor along dim 1, used for row-parallel layers

        Args:
            tensor: The tensor to slice
        """
        delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
        down_idx = self.shardconfig.rank * delta
        up_idx = down_idx + delta
        return tensor[:, down_idx:up_idx]
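A quick sketch of the ceil-division slicing used throughout Slicer, for a dimension (5) that does not divide evenly across world_size=2 (illustrative values):

import torch

world_size, length = 2, 5
delta = (length + world_size - 1) // world_size    # ceil(5 / 2) = 3
t = torch.arange(5)
print(t[0 * delta:0 * delta + delta])    # rank 0 -> tensor([0, 1, 2])
print(t[1 * delta:1 * delta + delta])    # rank 1 -> tensor([3, 4]), the last slice is smaller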
@@ -0,0 +1,5 @@
parallel = dict(
    data=1,
    pipeline=1,
    tensor=dict(size=2, mode='1d')
)
@@ -0,0 +1,37 @@
from transformers import AutoTokenizer
from transformers import BertForMaskedLM
import colossalai
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
import inspect
import argparse
import torch.nn as nn
import os

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def get_args():
    parser = colossalai.get_default_parser()
    return parser.parse_args()

def inference(model: nn.Module):
    # print(model)
    token = "Hello, my dog is cute"
    inputs = tokenizer(token, return_tensors="pt")
    inputs.to("cuda")
    model.to("cuda")
    outputs = model(**inputs)
    print(outputs)

if __name__ == "__main__":
    args = get_args()
    colossalai.launch_from_torch(config=args.config)
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    shard_config = ShardConfig(
        rank=int(str(get_current_device()).split(':')[-1]),
        world_size=int(os.environ['WORLD_SIZE']),
    )
    shardmodel = ShardModel(model, shard_config)
    inference(shardmodel.model)
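With the 1D tensor-parallel config above (size=2), this script would typically be launched on two GPUs via torchrun, since launch_from_torch reads the environment variables that torchrun sets (file names here are illustrative):

# torchrun --nproc_per_node=2 inference_test.py --config ./config.py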
@@ -0,0 +1,56 @@
def hasattr_(obj, attr: str):
    """
    Check whether the object has the multi-sublevel attr

    Args:
        obj: The object to check
        attr: The multi-level attr to check
    """
    attrs = attr.split('.')
    for a in attrs:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            return False
    return True

def setattr_(obj, attr: str, value, ignore: bool = False):
    """
    Set the object's multi-sublevel attr to value; if ignore is true, do nothing when the attr doesn't exist

    Args:
        obj: The object to set
        attr: The multi-level attr to set
        value: The value to set
        ignore: Whether to ignore when the attr doesn't exist
    """

    attrs = attr.split('.')
    for a in attrs[:-1]:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            if ignore:
                return
            raise AttributeError(f"Object {obj} has no attribute {attr}")
    setattr(obj, attrs[-1], value)

def getattr_(obj, attr: str, ignore: bool = False):
    """
    Get the object's multi-sublevel attr

    Args:
        obj: The object to get the attribute from
        attr: The multi-level attr to get
        ignore: Whether to ignore when the attr doesn't exist
    """

    attrs = attr.split('.')
    for a in attrs:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            if ignore:
                return None
            raise AttributeError(f"Object {obj} has no attribute {attr}")
    return obj
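A quick sketch of these dotted-attribute helpers on a small module (the import path colossalai.shardformer.utils.utils is inferred from the relative imports in this diff):

import torch
import torch.nn as nn
from colossalai.shardformer.utils.utils import hasattr_, setattr_, getattr_

model = nn.Sequential(nn.Linear(4, 4))
print(hasattr_(model, "0.weight"))                  # True
w = getattr_(model, "0.weight")                     # the Linear's weight parameter
setattr_(model, "0.weight", nn.Parameter(torch.zeros_like(w)))    # swap it out
print(getattr_(model, "0.missing", ignore=True))    # None instead of raising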