[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)

* add gpt2 policy and modify shard and slicer to support

* remove unused code

* polish code
pull/3943/head
FoolPlayer 2023-06-07 16:09:40 +08:00 committed by FrankLeeeee
parent 6370a935f6
commit ef1537759c
7 changed files with 233 additions and 44 deletions

View File

@ -10,16 +10,26 @@ def build_policies():
"""
auto_policy_dict = {}
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers import GPT2Model
from .gpt2 import GPT2Policy
auto_policy_dict[GPT2Model] = GPT2Policy
from transformers import GPT2LMHeadModel
from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict

View File

@ -1,11 +1,9 @@
# part of code modified from https://github.com/tunib-ai/parallelformers
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn
from transformers import AutoConfig
@dataclass
@ -31,11 +29,18 @@ class Layer:
bias (str): The bias suffix of the layer
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
weight: str = None
bias: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
n_cast: int = None
@dataclass
@ -131,7 +136,7 @@ class Policy():
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()
return None
@staticmethod
def binding_policy() -> Dict:
@ -146,7 +151,7 @@ class Policy():
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
return None
@staticmethod
def attn_in() -> List:
@ -209,4 +214,4 @@ class Policy():
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn

View File

@ -0,0 +1,118 @@
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class GPT2Policy(Policy):
@staticmethod
def argument_policy(config, world_size):
return {
GPT2Model:
Argument(attr_dict={}, param_funcs=[
GPT2Policy.embedding,
]),
GPT2Block:
Argument(
attr_dict={
# 1. reduce hidden size
"attn.embed_dim": config.hidden_size // world_size,
"attn.split_size": config.hidden_size // world_size,
"crossattention.embed_dim": config.hidden_size // world_size,
"crossattention.split_size": config.hidden_size // world_size,
# 2. reduce number of heads
"attn.num_heads": config.num_attention_heads // world_size,
"crossattention.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[
GPT2Policy.attn_in,
GPT2Policy.attn_out,
GPT2Policy.mlp_in,
GPT2Policy.mlp_out,
]),
}
@staticmethod
def attn_in() -> List:
return [
Col_Layer(weight="attn.c_attn.weight",
bias="attn.c_attn.bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.c_attn.weight",
bias="crossattention.c_attn.bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.q_attn.weight",
bias="crossattention.q_attn.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
]
@staticmethod
def attn_out() -> List:
return [
Row_Layer(weight="attn.c_proj.weight",
bias="attn.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(weight="crossattention.c_proj.weight",
bias="crossattention.c_proj.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(weight="mlp.c_proj.weight",
bias="mlp.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def embedding() -> List:
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
from transformers import GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
@staticmethod
def argument_policy(config, world_size):
base_argument = GPT2Policy.argument_policy(config, world_size)
argument = {
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
GPT2LMHeadModelPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def unembedding() -> List:
return [
Col_Layer(weight="lm_head.weight",
bias="lm_head.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]

View File

@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
@ -35,10 +36,22 @@ class ModelSharder(object):
self.model_config = self.model.config
def shard(self) -> None:
self.reshape_embedding()
self.inject_model(self.model)
self.replace_layer(self.model)
self.bind_layer(self.model)
def reshape_embedding(self,) -> None:
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
vocab_size = self.model_config.vocab_size
world_size = self.shard_config.world_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
def inject_model(
self,
model: nn.Module,
@ -53,6 +66,8 @@ class ModelSharder(object):
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]
@ -82,9 +97,9 @@ class ModelSharder(object):
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
def reverse_replace_layer(
def traverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
@ -100,17 +115,12 @@ class ModelSharder(object):
attr_dict (Dict): The attribute dict to modify
policy_cls (:class:`Policy`): The policy class
"""
if layer.__class__ == origin_cls:
for k, v in attr_dict.items():
setattr_(layer, k, v, ignore=True)
self.shard_one_layer(layer, param_funcs)
for name, child in layer.named_children():
if child.__class__ == origin_cls:
# replac_layer = child
for k, v in attr_dict.items():
setattr_(child, k, v, ignore=True)
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
self.shard_one_layer(child, param_funcs)
continue
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
return layer
def shard_one_layer(
@ -126,7 +136,6 @@ class ModelSharder(object):
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
"""
# print(org_layer)
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
@ -136,9 +145,10 @@ class ModelSharder(object):
bias_attr = policy_layer.bias
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
# print(gather_output)
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
@ -161,13 +171,11 @@ class ModelSharder(object):
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
# print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
@ -235,6 +243,8 @@ class ModelSharder(object):
model (:class:`torch.nn.Module`): The shard model
"""
binding_map = self.policy.binding_policy()
if binding_map is None:
return
for k, v in binding_map.items():
param = getattr_(model, k)
param = nn.Parameter(param)

View File

@ -19,6 +19,8 @@ class Slicer():
weight: torch.Tensor,
bias: torch.Tensor,
policy_layer_cls: Layer,
n_cast: int = None,
reversed: bool = False,
):
r"""
Slice the weight and bias according to policy layer cls
@ -33,13 +35,18 @@ class Slicer():
"""
if policy_layer_cls == Layer:
return weight, bias
elif policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, 1, False)
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
# print(weight.shape, dim)
if policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
bias = self.slice_tensor(bias, 0, True)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, 0, False)
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
if reversed:
weight = weight.transpose(0, 1).contiguous()
return weight, bias
def slice_tensor(
@ -47,6 +54,7 @@ class Slicer():
tensor_in: torch.Tensor,
dim: int,
is_bias: bool,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice tensor according to the config
@ -59,14 +67,15 @@ class Slicer():
if tensor_in is None:
return None
if not is_bias:
return self.slice_2d(tensor_in, dim)
return self.slice_2d(tensor_in, dim, n_cast)
else:
return self.slice_1d(tensor_in)
return self.slice_1d(tensor_in, n_cast)
def slice_2d(
self,
tensor: torch.Tensor,
dim: int,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 2D tensor
@ -77,13 +86,14 @@ class Slicer():
"""
assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
if dim == 0:
return self.slice_row(tensor)
return self.slice_row(tensor, n_cast)
elif dim == 1:
return self.slice_col(tensor)
return self.slice_col(tensor, n_cast)
def slice_1d(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 1D tensor
@ -94,11 +104,19 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()
def slice_col(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in column
@ -110,11 +128,19 @@ class Slicer():
:class:`torch.Tensor`: The sliced tensor
"""
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()
def slice_row(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in column
@ -125,4 +151,11 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=1).contiguous()

View File

@ -6,24 +6,28 @@ import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
parser.add_argument("--save_model", action='store_true')
parser.add_argument("--model", type=str, default='bert-base-uncased')
return parser.parse_args()
def load_data():
def load_data(args):
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# tokenizer.pad_token_id = 0
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
# datasets=load_dataset("yelp_review_full")
tokenized_datasets = datasets.map(
@ -42,18 +46,23 @@ def load_data():
def inference(model: nn.Module, args):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(model)
# print(model.wte.weight.shape)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.pad_token_id = 0
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.eval()
model.to("cuda")
outputs = model(**inputs)
print(outputs)
print(outputs[0])
def train(model: nn.Module, args, num_epoch: int = 3):
train_dataloader, eval_dataloader = load_data()
train_dataloader, eval_dataloader = load_data(args)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(num_training))
@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):
if __name__ == "__main__":
args = get_args()
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
colossalai.launch_from_torch(config=args.config)
if args.model == 'bert-base-uncased':
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
elif args.model == 'gpt2':
model = GPT2LMHeadModel.from_pretrained("gpt2")
else:
raise AttributeError("model not supported")
shard_config = ShardConfig(
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),