[shardformer] init shardformer code structure (#3731)

* init shardformer code structure

* add implementation of the sharder (inject and replace)

* add implementation of replacing layers with Colossal-AI layers

* separate different layer policies, add some notes

* implement 1D and 2D slicers that can tell column from row

* fix bugs when slicing and injecting the model

* fix some bugs; add inference test example
pull/4157/head
FoolPlayer 2023-05-22 15:02:17 +08:00 committed by Frank Lee
parent 3d8d5d0d58
commit 8d68de767d
16 changed files with 1033 additions and 0 deletions

@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Dict, List, Type
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput
class BertForMaskedLM_(BertForMaskedLM):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
print("[Inject OK] Injected forward method")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
# if input_ids is not None:
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
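
For orientation, a hedged single-process sketch (not part of this diff) that exercises the customized forward directly, without any sharding; the checkpoint name and prompt are illustrative and the call downloads pretrained weights:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM_.from_pretrained("bert-base-uncased")    # subclass defined above
inputs = tokenizer("Hello, my dog is [MASK].", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])            # prints "[Inject OK] ..." from the overridden forward
print(outputs.loss, outputs.logits.shape)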

@@ -0,0 +1,41 @@
import torch.nn as nn
def build_policies():
"""
Build the policies for the model
Return:
The dict for the policies
"""
auto_policy_dict = {}
from transformers.models.bert.modeling_bert import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
return auto_policy_dict
def get_autopolicy(model:nn.Module):
"""
Return the auto policy for the model
Args:
model: The model to be used
Return:
The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
return policy
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
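
A hedged usage sketch mirroring the commented-out lines above; the checkpoint name is only an example:

from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
policy_cls = get_autopolicy(model)        # -> BertForMaskedLMPolicy
print(policy_cls.__qualname__)
# an unsupported model class raises NotImplementedError listing the supported models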

@@ -0,0 +1,182 @@
# part of code modified from https://github.com/tunib-ai/parallelformers
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from typing import Any, Dict, List, Type, Tuple, Callable
from transformers import AutoConfig
from dataclasses import dataclass, field
@dataclass
class Argument:
attr_dict: Dict[str, Any]
param_funcs: List[Callable]
binding_layers: List[nn.Module] = field(default_factory=list)
@dataclass
class Layer:
"""
The layer object for the policy
Args:
weight: The weight name of the layer
bias: The bias name of the layer
replace_layer: The layer to replace the original layer
ignore: Whether to ignore this layer if it is not in the model
"""
weight: str = None
bias: str = None
replace_layer: Any = None
ignore: bool = False
@dataclass
class Col_Layer(Layer):
"""
Class for col shard layer in MegatronLM
"""
gather_output: bool = False
@dataclass
class Row_Layer(Layer):
"""
Class for row shard layer in MegatronLM
"""
pass
class Policy():
"""
The base class for all the policies
For each different model, it should have a different policy class, like BertPolicy for Bert Model
or OPTPolicy for OPT model.
AutoPolicy:
shardformer already defined some policies for huggingface model, just set custom_policy = None
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
and overwrite the method inject_policy
CustomPolicy:
"""
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
"""
Return a dict whose keys are the layer classes to be modified and whose values are Argument objects
holding the attribute settings and the parameter functions
Args:
model_config: The config of the transformer model
shard_config: The config of the distributed model
Return:
Dict for the modification policy, structured as
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
...
}
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
"""
Return the pair of classes used for model injection
Return:
A tuple (original model class, sharded model class); attributes such as forward of the
sharded class are copied onto the original class
"""
return ()
@staticmethod
def attn_in() -> List:
"""
Attention qkv layer
Returns:
List[Layer]: List of layer objects, each describing one qkv projection to shard
"""
raise NotImplementedError
@staticmethod
def attn_out() -> List:
"""
Attention output projection layer
Returns:
List[Layer]: List of layer objects
"""
raise NotImplementedError
@staticmethod
def mlp_in() -> List:
"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer objects
"""
raise NotImplementedError
@staticmethod
def mlp_out() -> List:
"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer objects
"""
raise NotImplementedError
@staticmethod
def embedding() -> List:
"""
Partially slice the embedding layer
vocab_size -> vocab_size // gpu_nums
Return:
List[Layer]: List of layer objects
"""
raise NotImplementedError
@staticmethod
def unembedding() -> List:
"""
Partially slice the unembedding (LM head) layer
vocab_size -> vocab_size // gpu_nums
Return:
List[Layer]: List of layer objects
"""
raise NotImplementedError
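
To make the contract above concrete, here is a minimal hypothetical policy sketch for torch's nn.TransformerEncoderLayer; the layer class, its attribute paths, and the model_config fields are assumptions for illustration, not part of shardformer:

import torch.nn as nn

class ToyPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, world_size: int):
        # keys: layer classes to modify; values: attribute overrides plus the
        # static methods that describe which parameters to slice
        return {
            nn.TransformerEncoderLayer: Argument(
                attr_dict={"self_attn.num_heads": model_config.num_attention_heads // world_size},
                param_funcs=[ToyPolicy.attn_in, ToyPolicy.attn_out],
            )
        }

    @staticmethod
    def attn_in():
        # fused qkv projection of nn.MultiheadAttention, sharded column-wise
        return [Col_Layer(weight="self_attn.in_proj_weight", bias="self_attn.in_proj_bias")]

    @staticmethod
    def attn_out():
        # output projection, sharded row-wise
        return [Row_Layer(weight="self_attn.out_proj.weight", bias="self_attn.out_proj.bias")]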

@@ -0,0 +1,168 @@
from typing import Dict, List, Tuple, Type, Any, Callable
import torch.nn as nn
from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
import colossalai.nn as col_nn
from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
from dataclasses import dataclass
class BertPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
return {
BertLayer: Argument(
attr_dict = {
# 1. shard hidden size
"attention.self.all_head_size": config.hidden_size // world_size,
"crossattention.self.all_head_size": config.hidden_size // world_size,
# 2. shard number of heads
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
},
param_funcs = [
BertPolicy.attn_in,
BertPolicy.attn_out,
BertPolicy.mlp_in,
BertPolicy.mlp_out
]
),
BertEmbeddings: Argument(
attr_dict = {
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
"word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
},
param_funcs = [
BertPolicy.embedding,
],
binding_layers = [
BertLMPredictionHead,
]
),
BertLMPredictionHead: Argument(
attr_dict = {
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
},
param_funcs = [
BertPolicy.unembedding,
]
)
}
@staticmethod
def attn_in() -> List:
return [
Col_Layer(
weight="attention.self.query.weight",
bias="attention.self.query.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.key.weight",
bias="attention.self.key.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.value.weight",
bias="attention.self.value.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="crossattention.self.query.weight",
bias="crossattention.self.query.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.key.weight",
bias="crossattention.self.key.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.value.weight",
bias="crossattention.self.value.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
]
@staticmethod
def attn_out() -> List:
return [
Row_Layer(
weight="attention.output.dense.weight",
bias="attention.output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
),
Row_Layer(
weight="crossattention.output.dense.weight",
bias="crossattention.output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
ignore=True,
),
]
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(
weight="intermediate.dense.weight",
bias="intermediate.dense.bias",
replace_layer=col_nn.Linear1D_Col,
),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(
weight="output.dense.weight",
bias="output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
),
]
@staticmethod
def embedding() -> List:
return [
Col_Layer(
weight="word_embeddings.weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)
]
@staticmethod
def unembedding() -> List:
return [
Col_Layer(
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]
from transformers import BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
class BertForSequenceClassificationPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Dict:
return {}
# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# _ = BertForMaskedLMPolicy(model)
# print(isinstance(model,list(_.inject_policy().keys())[0]))
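
A hedged single-process sketch (not part of the diff) that only inspects what this policy would modify for 2-way tensor parallelism, without touching any weights:

from transformers import BertConfig

config = BertConfig()
arguments = BertPolicy.argument_policy(config, world_size=2)
for layer_cls, arg in arguments.items():
    print(layer_cls.__name__, arg.attr_dict)
# expected keys: BertLayer, BertEmbeddings, BertLMPredictionHead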

@@ -0,0 +1,18 @@
from dataclasses import dataclass
@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend="nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500

@@ -0,0 +1,238 @@
import torch
import torch.nn as nn
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable
from .shardconfig import ShardConfig
from dataclasses import dataclass
from ..policies.basepolicy import Policy, Layer
from ..policies.autopolicy import get_autopolicy
from .slicer import Slicer
from ..utils.utils import hasattr_, setattr_, getattr_
import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger
import os
logger = get_dist_logger()
class ModelSharder(object):
"""
Shard the original huggingface model according to the policy
Args:
model: The model to shard
policy: The policy to shard the model
shard_config: The config of the distributed setting
"""
def __init__(
self,
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.slicer = Slicer(shard_config)
self.shard_config = shard_config
self.model_config = self.model.config
self.binding_map = {}
def shard(self) -> None:
self.inject_model(self.model)
self.replace_layer(self.model)
def inject_model(
self,
model: nn.Module,
) -> None:
"""
Replace the model's methods with those of the policy-defined model
Mainly modify the forward (and backward) to fit the distributed model
e.g.
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]
if model.__class__ == org_model_cls:
for key in shard_model_cls.__dict__.keys():
if hasattr(model.__class__, key):
setattr(
model.__class__,
key,
getattr(shard_model_cls,key),
)
else:
raise NotImplementedError(f"{model.__class__} is not implemented so far")
def replace_layer(
self,
model: nn.Module,
) -> None:
"""
Replace the layers according to the policy, one layer class at a time
Args:
model: The model whose layers will be replaced
"""
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
for argument_policy in argument_policies.items():
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
binding_layers = argument_policy[1].binding_layers
# if binding_layer is not None:
# self.binding_map[origin_layer_cls] = binding_layer
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)
def reverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
attr_dict: Dict[str, Any],
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
Recursively traverse the module and replace every child of class origin_cls according to the policy
Args:
layer: The module object to traverse
origin_cls: The original layer class to match
attr_dict: The attribute dict to modify
param_funcs: The policy functions describing which parameters to shard
binding_layers: The layer classes whose parameters are bound to this layer's shards
"""
for name, child in layer.named_children():
if child.__class__ == origin_cls:
# replac_layer = child
for k, v in attr_dict.items():
setattr_(child, k, v, ignore=True)
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
self.shard_one_layer(child, param_funcs, binding_layers)
continue
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)
return layer
def shard_one_layer(
self,
org_layer: nn.Module,
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
Shard one layer according to the policy; the layer should be the same class as a key in the policy's argument_policy return dict
Args:
org_layer: The original layer object to shard
param_funcs: The function list used to get shard information from the policy class
binding_layers: The layer classes whose parameters are bound to this layer's shards
"""
# print(org_layer)
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
weight = None
bias = None
weight_attr = policy_layer.weight
bias_attr = policy_layer.bias
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
print(gather_output)
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
weight = getattr_(org_layer, weight_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
if bias_attr is not None:
if hasattr_(org_layer, bias_attr):
bias = getattr_(org_layer, bias_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
# dont have the attribute in policy, and ignore is true
if weight is None and bias is None and ignore:
continue
# set the sliced weight and bias to the new nn_col layer
assert weight is not None or bias is not None
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
# save the binding information
for binding_layer in binding_layers:
self.binding_map[binding_layer] = dict(weight=weight, bias=bias)
# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output)
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight=weight, bias=bias)
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight=weight, bias=bias)
else:
raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
# do not replace the layer object, just replace the weight and bias
else:
self.set_param(org_layer, layer_attr, weight, bias)
def set_param(
self,
layer: Any,
layer_attr: str = "",
weight: torch.Tensor = None,
bias: torch.Tensor = None
) -> None:
"""
Reset the weight and bias of the layer object
Args:
layer: The layer object
layer_attr: The attribute name of the layer
weight: The weight of the layer
bias: The bias of the layer
"""
assert weight is not None or bias is not None
if weight is not None:
setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight))
self.set_layer_size(layer, layer_attr, weight.shape)
if bias is not None:
setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias))
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
"""
Set the layer attribute
Args:
layer: The layer object
layer_attr: The attribute name of the layer
size: The torch.Size of the sharded weight
"""
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
attrs = ["out_features", "in_features"]
for i, attr in enumerate(attrs):
if hasattr_(layer, f"{layer_attr}.{attr}"):
setattr_(layer, f"{layer_attr}.{attr}", size[i])

@@ -0,0 +1,58 @@
import os
import torch
import torch.nn as nn
import transformers
import torch.distributed as dist
from dataclasses import dataclass
from contextlib import suppress
from colossalai.tensor.d_tensor.layout import Layout
from ..policies.basepolicy import Policy
from .sharder import ModelSharder
from .shardconfig import ShardConfig
class ShardModel(object):
"""
The class for sharding the huggingface model, self.model is the sharded model
Just creat a new ShardModel object to shard huggingface model
Args:
model: the origin huggingface model
dist_config: the config for distribute information
custom_policy: the custom policy for sharding
"""
def __init__(
self,
model: nn.Module,
shard_config: ShardConfig = None, # TODO
custom_policy: Policy = None,
) -> None:
self.model = model
self.shard_config = shard_config
self.policy = custom_policy
# self.layout=, # TODO
sharder = ModelSharder(
model=self.model,
policy=self.policy,
shard_config=self.shard_config,
)
sharder.shard()
def set_environ(self) -> None:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
os.environ["RANK"] = str(self.dist_config.rank)
os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
if not dist.is_initialized():
dist.init_process_group(backend=self.dist_config.backend)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
def back_to_org(self) -> None:
pass

@@ -0,0 +1,167 @@
import os
from typing import Dict, Tuple
from dataclasses import dataclass
import torch
import torch.distributed as dist
from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
from .shardconfig import ShardConfig
dim_mapping = {Col_Layer: 1, Row_Layer: 0}
class Slicer():
def __init__(
self,
shardconfig: ShardConfig #TODO
) -> None:
self.shardconfig = shardconfig
def slice_weight_bias(
self,
weight: torch.Tensor,
bias: torch.Tensor,
policy_layer_cls: Layer,
):
"""
Slice the weight and bias according to policy layer cls
Layer -> do nothing
Col_Layer -> slice both the weight and the bias
Row_Layer -> slice the weight only and leave the bias unsliced
Args:
weight: The weight of the layer
bias: The bias of the layer
policy_layer_class: The class represent how to slice the tensor
"""
if policy_layer_cls == Layer:
return weight, bias
elif policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, 1, False)
bias = self.slice_tensor(bias, 0, True)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, 0, False)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
return weight, bias
def slice_weight(
self,
weight: torch.Tensor,
policy_layer_cls: Layer,
) -> torch.Tensor:
"""
Slice the weight according to the shardconfig
Args:
weight: The weight of the layer
policy_layer_class: The class representing how to slice the tensor
"""
if weight is not None:
dim = dim_mapping[policy_layer_cls]
weight = self.slice_tensor(weight, dim, False)
return weight
def slice_bias(
self,
bias: torch.Tensor,
) -> torch.Tensor:
"""
Slice the bias according to the shardconfig
Args:
bias: The bias of the layer
"""
assert bias is not None, "The bias is None"
if bias is not None:
bias = self.slice_tensor(bias, 1, True)
return bias
def slice_tensor(
self,
tensor_in: torch.Tensor,
dim: int,
is_bias: bool,
) -> torch.Tensor:
"""
Slice tensor according to the config
"""
if tensor_in is None:
return None
if not is_bias:
return self.slice_2d(tensor_in, dim)
else:
return self.slice_1d(tensor_in)
def slice_2d(
self,
tensor: torch.Tensor,
dim: int,
) -> torch.Tensor:
"""
Slice the 2D tensor
Args:
tensor: The tensor to slice
dim: 0 to use slice_row, 1 to use slice_col
"""
assert dim in [0, 1], f"Only support dim 0 or 1, but got dim={dim}"
if dim == 0:
return self.slice_row(tensor)
elif dim == 1:
return self.slice_col(tensor)
def slice_1d(
self,
tensor: torch.Tensor,
dim: int = None,
) -> torch.Tensor:
"""
Slice the 1D tensor
Args:
tensor: The tensor to slice
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx]
def slice_col(
self,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Slice the tensor in column
Args:
tensor: The tensor to slice
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx,:]
def slice_row(
self,
tensor: torch.Tensor,
) -> torch.Tensor:
"""
Slice the tensor in row
Args:
tensor: The tensor to slice
"""
delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[:,down_idx:up_idx]
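
A standalone hedged sketch of how the two slicing helpers partition a toy 4x6 tensor across two ranks; ShardConfig here only supplies rank and world_size, and no distributed launch is needed:

import torch

full = torch.arange(24).reshape(4, 6)
for rank in range(2):
    slicer = Slicer(ShardConfig(rank=rank, world_size=2))
    col_shard = slicer.slice_col(full)   # keeps rows [2*rank, 2*rank+2) -> shape (2, 6)
    row_shard = slicer.slice_row(full)   # keeps cols [3*rank, 3*rank+3) -> shape (4, 3)
    print(rank, col_shard.shape, row_shard.shape)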

@@ -0,0 +1,5 @@
parallel = dict(
data=1,
pipeline=1,
tensor=dict(size=2, mode='1d')
)

@@ -0,0 +1,37 @@
from transformers import AutoTokenizer
from transformers import BertForMaskedLM
import colossalai
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
import inspect
import argparse
import torch.nn as nn
import os
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def get_args():
parser = colossalai.get_default_parser()
return parser.parse_args()
def inference(model: nn.Module):
# print(model)
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.to("cuda")
outputs = model(**inputs)
print(outputs)
if __name__ == "__main__":
args = get_args()
colossalai.launch_from_torch(config=args.config)
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shard_config = ShardConfig(
rank = int(str(get_current_device()).split(':')[-1]),
world_size= int(os.environ['WORLD_SIZE']),
)
shardmodel = ShardModel(model, shard_config)
inference(shardmodel.model)

@@ -0,0 +1,56 @@
def hasattr_(obj, attr: str):
"""
Check whether the object has the multi sublevel attr
Args:
obj: The object to check
attr: The multi level attr to check
"""
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
except AttributeError:
return False
return True
def setattr_(obj, attr: str, value, ignore: bool=False):
"""
Set the object's multi sublevel attr to value; if ignore is True, silently return when the attr doesn't exist
Args:
obj: The object to set
attr: The multi level attr to set
value: The value to set
ignore: Whether to ignore when the attr doesn't exist
"""
attrs = attr.split('.')
for a in attrs[:-1]:
try:
obj = getattr(obj, a)
except AttributeError:
if ignore:
return
raise AttributeError(f"Object {obj} has no attribute {attr}")
setattr(obj, attrs[-1], value)
def getattr_(obj, attr: str, ignore: bool = False):
"""
Get the object's multi sublevel attr
Args:
obj: The object to get from
attr: The multi level attr to get
ignore: Whether to ignore when the attr doesn't exist
"""
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
except AttributeError:
if ignore:
return None
raise AttributeError(f"Object {obj} has no attribute {attr}")
return obj
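
A short hedged usage sketch for the dotted-path helpers above; the Sequential/Linear module is just an illustration:

import torch
import torch.nn as nn

m = nn.Sequential()
m.block = nn.Linear(4, 4)
print(hasattr_(m, "block.weight"))                  # True
print(getattr_(m, "block.weight").shape)            # torch.Size([4, 4])
setattr_(m, "block.bias", nn.Parameter(torch.zeros(4)))
print(getattr_(m, "block.missing", ignore=True))    # None instead of raising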