[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)

* add gpt2 policy and modify shard and slicer to support

* remove unused code

* polish code
pull/4157/head
FoolPlayer 2023-06-07 16:09:40 +08:00 committed by Frank Lee
parent 70173e3123
commit 79f8d5d54b
7 changed files with 233 additions and 44 deletions

View File

@@ -10,16 +10,26 @@ def build_policies():
     """
     auto_policy_dict = {}

-    from transformers.models.bert.modeling_bert import BertForMaskedLM
+    from transformers import BertForMaskedLM

     from .bert import BertForMaskedLMPolicy
     auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

-    from transformers.models.bert.modeling_bert import BertForSequenceClassification
+    from transformers import BertForSequenceClassification

     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

+    from transformers import GPT2Model
+
+    from .gpt2 import GPT2Policy
+    auto_policy_dict[GPT2Model] = GPT2Policy
+
+    from transformers import GPT2LMHeadModel
+
+    from .gpt2 import GPT2LMHeadModelPolicy
+    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
+
     return auto_policy_dict
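
Note: build_policies() returns a plain {model class: policy class} dict. The sharder resolves it through get_autopolicy (imported further down in this commit); the function body is not part of this diff, but the lookup is presumably keyed on the model's class, roughly:

    # A sketch of the lookup get_autopolicy performs; keying on model.__class__
    # is an assumption, since the function body is not shown in this diff.
    def get_autopolicy(model):
        auto_policy_dict = build_policies()
        policy = auto_policy_dict.get(model.__class__)
        if policy is None:
            raise NotImplementedError(f"No policy registered for {model.__class__}")
        return policy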

View File

@@ -1,11 +1,9 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Tuple, Type

-import torch
 import torch.nn as nn
-from transformers import AutoConfig


 @dataclass
@@ -31,11 +29,18 @@ class Layer:
         bias (str): The bias suffix of the layer
         replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
         ignore (bool): Whether to ignore this layer if it is not in the model
+        reversed (bool): Whether the weight in this layer is transposed; the weight of `torch.nn.Linear`
+            is commonly [out, in], but the GPT2 `Conv1D` layer stores it as [in, out], i.e. reversed.
+        n_cast (int): The number of fused weights the parameter packs, e.g. 3 for the fused q, k, v in an
+            attention layer. Commonly in TP we just chunk the weight by the number of devices, but a fused
+            weight must be chunked by devices * n_cast so that each device gets its part of Q, K and V.
     """
     weight: str = None
     bias: str = None
     replace_layer: Any = None
     ignore: bool = False
+    reversed: bool = False
+    n_cast: int = None


 @dataclass
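
Note: to see why n_cast is needed, consider a fused QKV projection. Chunking it by world_size alone would hand rank 0 all of Q and none of V; chunking by world_size * n_cast and interleaving (as the slicer changes below implement) gives every rank its slice of each of Q, K and V. A sketch with hypothetical sizes:

    import torch

    world_size, n_cast = 2, 3
    qkv = torch.arange(3 * 8)    # stands in for a fused Q|K|V dimension, each part of size 8

    # naive: rank 0 gets all of Q plus half of K, and no V at all
    naive_rank0 = qkv.chunk(world_size)[0]

    # n_cast-aware: split into world_size * n_cast pieces, each rank takes every
    # world_size-th piece, so rank 0 gets the first half of Q, of K and of V
    chunks = qkv.chunk(world_size * n_cast)
    rank0 = torch.cat([chunks[i] for i in range(0, len(chunks), world_size)])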
@@ -131,7 +136,7 @@ class Policy():
             (OrignModel, CustomModel)
             in `CustomModel`, we can overwrite the forward and backward process
         """
-        return ()
+        return None

     @staticmethod
     def binding_policy() -> Dict:
@@ -146,7 +151,7 @@ class Policy():
                 "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
             }
         """
-        return NotImplementedError
+        return None

     @staticmethod
     def attn_in() -> List:
@@ -209,4 +214,4 @@ class Policy():
         Return:
             List[Layer]: List of layer object
         """
-        return NotImplementedError
+        return None
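
Note: these hooks previously returned the NotImplementedError class object (without ever raising it); they now return None, and the sharder below treats None as "this policy has nothing to inject or bind". The pattern, in brief:

    # The None-as-opt-out contract used by the sharder (see the checks added
    # to inject_model and bind_layer further down in this commit):
    inject_policy = self.policy.inject_policy()
    if inject_policy is None:
        return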

View File

@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Tuple, Type

 import torch.nn as nn

View File

@@ -0,0 +1,118 @@
+from typing import Any, Callable, Dict, List, Tuple, Type
+
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+
+import colossalai.shardformer.layer.layers as col_nn
+
+from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
+
+
+class GPT2Policy(Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return {
+            GPT2Model:
+                Argument(attr_dict={}, param_funcs=[
+                    GPT2Policy.embedding,
+                ]),
+            GPT2Block:
+                Argument(
+                    attr_dict={
+                        # 1. reduce hidden size
+                        "attn.embed_dim": config.hidden_size // world_size,
+                        "attn.split_size": config.hidden_size // world_size,
+                        "crossattention.embed_dim": config.hidden_size // world_size,
+                        "crossattention.split_size": config.hidden_size // world_size,
+                        # 2. reduce number of heads
+                        "attn.num_heads": config.num_attention_heads // world_size,
+                        "crossattention.num_heads": config.num_attention_heads // world_size,
+                    },
+                    param_funcs=[
+                        GPT2Policy.attn_in,
+                        GPT2Policy.attn_out,
+                        GPT2Policy.mlp_in,
+                        GPT2Policy.mlp_out,
+                    ]),
+        }
+
+    @staticmethod
+    def attn_in() -> List:
+        return [
+            Col_Layer(weight="attn.c_attn.weight",
+                      bias="attn.c_attn.bias",
+                      n_cast=3,
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(weight="crossattention.c_attn.weight",
+                      bias="crossattention.c_attn.bias",
+                      n_cast=2,
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(weight="crossattention.q_attn.weight",
+                      bias="crossattention.q_attn.bias",
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Col)
+        ]
+
+    @staticmethod
+    def attn_out() -> List:
+        return [
+            Row_Layer(weight="attn.c_proj.weight",
+                      bias="attn.c_proj.bias",
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Row),
+            Row_Layer(weight="crossattention.c_proj.weight",
+                      bias="crossattention.c_proj.bias",
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Row)
+        ]
+
+    @staticmethod
+    def mlp_in() -> List:
+        return [
+            Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
+        ]
+
+    @staticmethod
+    def mlp_out() -> List:
+        return [
+            Row_Layer(weight="mlp.c_proj.weight",
+                      bias="mlp.c_proj.bias",
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Row)
+        ]
+
+    @staticmethod
+    def embedding() -> List:
+        return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
+
+
+from transformers import GPT2LMHeadModel
+
+
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = GPT2Policy.argument_policy(config, world_size)
+        argument = {
+            GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
+                GPT2LMHeadModelPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def unembedding() -> List:
+        return [
+            Col_Layer(weight="lm_head.weight",
+                      bias="lm_head.bias",
+                      replace_layer=col_nn.Linear1D_Col,
+                      gather_output=True)
+        ]
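
Note: reversed=True appears on every attention and MLP projection because GPT2 uses transformers' Conv1D rather than nn.Linear, and Conv1D stores its weight as [in_features, out_features], the transpose of nn.Linear's [out_features, in_features]. Likewise n_cast=3 marks the fused q/k/v projection, and n_cast=2 the cross-attention c_attn, which fuses only k and v (q comes from the separate q_attn). To illustrate the layout difference:

    import torch.nn as nn
    from transformers.pytorch_utils import Conv1D

    linear = nn.Linear(4, 8)    # weight: torch.Size([8, 4]) -> [out, in]
    conv1d = Conv1D(8, 4)       # Conv1D(nf=out, nx=in), weight: torch.Size([4, 8]) -> [in, out], i.e. reversed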

View File

@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List

 import torch
 import torch.nn as nn
+from transformers.pytorch_utils import Conv1D

 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy
@@ -35,10 +36,22 @@ class ModelSharder(object):
         self.model_config = self.model.config

     def shard(self) -> None:
+        self.reshape_embedding()
         self.inject_model(self.model)
         self.replace_layer(self.model)
         self.bind_layer(self.model)

+    def reshape_embedding(self,) -> None:
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        vocab_size = self.model_config.vocab_size
+        world_size = self.shard_config.world_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+            self.model_config = self.model.config
+
     def inject_model(
         self,
         model: nn.Module,
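
Note: reshape_embedding pads the vocabulary up to the next multiple of world_size so that VocabParallelEmbedding1D can split it evenly across ranks. With GPT2's vocabulary of 50257 on 2 GPUs, for example:

    # Vocabulary padding arithmetic from reshape_embedding; GPT2 numbers shown
    # for illustration.
    vocab_size, world_size = 50257, 2
    new_vocab_size = vocab_size + world_size - vocab_size % world_size    # 50258
    assert new_vocab_size % world_size == 0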
@@ -53,6 +66,8 @@ class ModelSharder(object):
         """
         inject_policy = self.policy.inject_policy()

+        if inject_policy is None:
+            return
         org_model_cls = inject_policy[0]
         shard_model_cls = inject_policy[1]
@@ -82,9 +97,9 @@ class ModelSharder(object):
             origin_layer_cls = argument_policy[0]
             attr_dict = argument_policy[1].attr_dict
             param_funcs = argument_policy[1].param_funcs
-            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)

-    def reverse_replace_layer(
+    def traverse_replace_layer(
         self,
         layer: nn.Module,
         origin_cls: nn.Module,
@@ -100,17 +115,12 @@ class ModelSharder(object):
             attr_dict (Dict): The attribute dict to modify
             policy_cls (:class:`Policy`): The policy class
         """
+        if layer.__class__ == origin_cls:
+            for k, v in attr_dict.items():
+                setattr_(layer, k, v, ignore=True)
+            self.shard_one_layer(layer, param_funcs)
         for name, child in layer.named_children():
-            if child.__class__ == origin_cls:
-                # replac_layer = child
-                for k, v in attr_dict.items():
-                    setattr_(child, k, v, ignore=True)
-                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
-                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
-                self.shard_one_layer(child, param_funcs)
-                continue
-            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
         return layer

     def shard_one_layer(
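
Note: the rewritten traversal checks the current module first and then recurses into every child unconditionally, instead of branching inside the loop. Reduced to its essentials (with visit standing in for the attribute patching plus the shard_one_layer call above):

    def traverse(module, origin_cls, visit):
        if module.__class__ == origin_cls:
            visit(module)
        for _, child in module.named_children():
            traverse(child, origin_cls, visit)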
@@ -126,7 +136,6 @@ class ModelSharder(object):
             param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class

         """
-        # print(org_layer)
         for func in param_funcs:
             policy_layers = func()
             for policy_layer in policy_layers:
@@ -136,9 +145,10 @@ class ModelSharder(object):
                 bias_attr = policy_layer.bias
                 replace_layer_cls = policy_layer.replace_layer
                 ignore = policy_layer.ignore
+                n_cast = policy_layer.n_cast
+                reversed = policy_layer.reversed
                 if policy_layer.__class__.__name__ == "Col_Layer":
                     gather_output = policy_layer.gather_output
-                    # print(gather_output)

                 if weight_attr is not None:
                     if hasattr_(org_layer, weight_attr):
@@ -161,13 +171,11 @@ class ModelSharder(object):
                 layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)

                 # slice weight and bias
-                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
-                # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
+                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)

                 # create new object to replace the origin layer
                 if replace_layer_cls is not None:
-                    # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
-                    if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
+                    if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
                         if replace_layer_cls.__name__ == "Linear1D_Row":
                             replace_layer = replace_layer_cls(weight.shape[1],
                                                               weight.shape[0],
@@ -235,6 +243,8 @@ class ModelSharder(object):
             model (:class:`torch.nn.Module`): The shard model
         """
         binding_map = self.policy.binding_policy()
+        if binding_map is None:
+            return
         for k, v in binding_map.items():
             param = getattr_(model, k)
             param = nn.Parameter(param)
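
Note: bind_layer implements weight tying through the binding map, e.g. the BertForMaskedLM example from basepolicy.py above:

    binding_map = {
        "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
    }
    # bind_layer assigns one nn.Parameter to both attribute paths

With the new None check, a policy such as GPT2Policy can simply leave binding_policy() unimplemented.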

View File

@@ -19,6 +19,8 @@ class Slicer():
         weight: torch.Tensor,
         bias: torch.Tensor,
         policy_layer_cls: Layer,
+        n_cast: int = None,
+        reversed: bool = False,
     ):
         r"""
         Slice the weight and bias according to policy layer cls
@@ -33,13 +35,18 @@ class Slicer():
         """
         if policy_layer_cls == Layer:
             return weight, bias
-        elif policy_layer_cls == Col_Layer:
-            weight = self.slice_tensor(weight, 1, False)
+
+        dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
+        # print(weight.shape, dim)
+        if policy_layer_cls == Col_Layer:
+            weight = self.slice_tensor(weight, dim, False, n_cast)
             bias = self.slice_tensor(bias, 0, True)
         elif policy_layer_cls == Row_Layer:
-            weight = self.slice_tensor(weight, 0, False)
+            weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
+        if reversed:
+            weight = weight.transpose(0, 1).contiguous()
         return weight, bias

     def slice_tensor(
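
Note: dim_mapping is defined outside this hunk; to reproduce the previously hard-coded dims it presumably maps Col_Layer to 1 and Row_Layer to 0, with the 1 - dim flip handling reversed (Conv1D-style) weights before the final transpose back to [out, in]. As a sketch:

    # Assumed mapping, matching the old hard-coded behaviour (an assumption,
    # since dim_mapping's definition is not part of this diff):
    dim_mapping = {Col_Layer: 1, Row_Layer: 0}
    dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])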
@@ -47,6 +54,7 @@ class Slicer():
         tensor_in: torch.Tensor,
         dim: int,
         is_bias: bool,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice tensor according to the config
@@ -59,14 +67,15 @@ class Slicer():
         if tensor_in is None:
             return None
         if not is_bias:
-            return self.slice_2d(tensor_in, dim)
+            return self.slice_2d(tensor_in, dim, n_cast)
         else:
-            return self.slice_1d(tensor_in)
+            return self.slice_1d(tensor_in, n_cast)

     def slice_2d(
         self,
         tensor: torch.Tensor,
         dim: int,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the 2D tensor
@@ -77,13 +86,14 @@ class Slicer():
         """
         assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
         if dim == 0:
-            return self.slice_row(tensor)
+            return self.slice_row(tensor, n_cast)
         elif dim == 1:
-            return self.slice_col(tensor)
+            return self.slice_col(tensor, n_cast)

     def slice_1d(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the 1D tensor
@@ -94,11 +104,19 @@ class Slicer():
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()

     def slice_col(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the tensor in column
@@ -110,11 +128,19 @@ class Slicer():
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()

     def slice_row(
         self,
         tensor: torch.Tensor,
+        n_cast: int = None,
     ) -> torch.Tensor:
         r"""
         Slice the tensor in column
@@ -125,4 +151,11 @@ class Slicer():
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=1).contiguous()

View File

@@ -6,24 +6,28 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler

 import colossalai
 from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


 def get_args():
     parser = colossalai.get_default_parser()
     parser.add_argument("--mode", type=str, default='inference')
     parser.add_argument("--save_model", action='store_true')
+    parser.add_argument("--model", type=str, default='bert-base-uncased')
     return parser.parse_args()


-def load_data():
+def load_data(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        # tokenizer.pad_token_id = 0
     datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
     # datasets=load_dataset("yelp_review_full")
     tokenized_datasets = datasets.map(
@@ -42,18 +46,23 @@ def load_data(args):


 def inference(model: nn.Module, args):
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    print(model)
+    # print(model.wte.weight.shape)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokenizer.pad_token_id = 0
     token = "Hello, my dog is cute"
     inputs = tokenizer(token, return_tensors="pt")
     inputs.to("cuda")
     model.eval()
     model.to("cuda")
     outputs = model(**inputs)
-    print(outputs)
+    print(outputs[0])


 def train(model: nn.Module, args, num_epoch: int = 3):
-    train_dataloader, eval_dataloader = load_data()
+    train_dataloader, eval_dataloader = load_data(args)
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
     num_training = num_epoch * len(train_dataloader)
     progress_bar = tqdm(range(num_training))
@@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):

 if __name__ == "__main__":
     args = get_args()
-    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     colossalai.launch_from_torch(config=args.config)
+    if args.model == 'bert-base-uncased':
+        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+    elif args.model == 'gpt2':
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+    else:
+        raise AttributeError("model not supported")
     shard_config = ShardConfig(
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),