mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)

* add gpt2 policy and modify shard and slicer to support
* remove unused code
* polish code

parent 70173e3123
commit 79f8d5d54b
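At a high level, this change teaches shardformer's policy machinery about GPT-2: the auto-policy registry gains GPT2Model/GPT2LMHeadModel entries, the Layer description learns about transposed Conv1D weights and fused QKV projections, and the sharder/slicer are extended to honor those hints. As a rough sketch of the intended usage, based on the test script changed at the bottom of this diff (the actual shard_model call is not visible in the truncated hunk, so the last line and the config path are assumptions):

    import os
    import colossalai
    from transformers import GPT2LMHeadModel
    from colossalai.shardformer.shard import ShardConfig, shard_model
    from colossalai.utils import get_current_device

    colossalai.launch_from_torch(config='config.py')              # placeholder config path
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    shard_config = ShardConfig(rank=int(str(get_current_device()).split(':')[-1]),
                               world_size=int(os.environ['WORLD_SIZE']))
    sharded = shard_model(model, shard_config)                    # assumed call; should pick GPT2LMHeadModelPolicy automatically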
@@ -10,16 +10,26 @@ def build_policies():
     """
     auto_policy_dict = {}
 
-    from transformers.models.bert.modeling_bert import BertForMaskedLM
+    from transformers import BertForMaskedLM
 
     from .bert import BertForMaskedLMPolicy
     auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
 
-    from transformers.models.bert.modeling_bert import BertForSequenceClassification
+    from transformers import BertForSequenceClassification
 
     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
 
+    from transformers import GPT2Model
+
+    from .gpt2 import GPT2Policy
+    auto_policy_dict[GPT2Model] = GPT2Policy
+
+    from transformers import GPT2LMHeadModel
+
+    from .gpt2 import GPT2LMHeadModelPolicy
+    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
+
     return auto_policy_dict
 
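Dispatch keys on the concrete transformers class, so picking a policy should reduce to a dictionary lookup on the dict built above. A small illustrative check (hypothetical usage, not part of this diff):

    from transformers import GPT2LMHeadModel, GPT2Model

    policies = build_policies()
    assert policies[GPT2Model].__name__ == "GPT2Policy"
    assert policies[GPT2LMHeadModel].__name__ == "GPT2LMHeadModelPolicy"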
@@ -1,11 +1,9 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Tuple, Type
 
-import torch
 import torch.nn as nn
-from transformers import AutoConfig
 
 
 @dataclass

@@ -31,11 +29,18 @@ class Layer:
         bias (str): The bias suffix of the layer
         replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
         ignore (bool): Whether to ignore this layer if it is not in the model
+        reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
+            but in GPT2 `Conv1D` layer is [in, out] which is reversed.
+        n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
+            but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
+            each device should have a part of Q, K and V weight.
     """
     weight: str = None
     bias: str = None
     replace_layer: Any = None
     ignore: bool = False
+    reversed: bool = False
+    n_cast: int = None
 
 
 @dataclass
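The reversed flag is needed because transformers' Conv1D, which GPT-2 uses for all of its projections, stores its weight transposed relative to torch.nn.Linear; n_cast covers fused projections such as c_attn, where Q, K and V are concatenated and each rank must receive a slice of all three (see the slicer sketch further down). A quick shape comparison for the reversed case (illustrative sizes):

    import torch.nn as nn
    from transformers.pytorch_utils import Conv1D

    linear = nn.Linear(in_features=8, out_features=24)
    conv1d = Conv1D(nf=24, nx=8)                      # GPT-2 style projection: nf = out, nx = in
    assert tuple(linear.weight.shape) == (24, 8)      # [out, in]
    assert tuple(conv1d.weight.shape) == (8, 24)      # [in, out], hence reversed=True when slicing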
@@ -131,7 +136,7 @@ class Policy():
            (OrignModel, CustomModel)
        in `CustomModel`, we can overwrite the forward and backward process
        """
-        return ()
+        return None
 
     @staticmethod
     def binding_policy() -> Dict:

@@ -146,7 +151,7 @@ class Policy():
                "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
            }
        """
-        return NotImplementedError
+        return None
 
     @staticmethod
     def attn_in() -> List:

@@ -209,4 +214,4 @@ class Policy():
        Return:
            List[Layer]: List of layer object
        """
-        return NotImplementedError
+        return None

@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Tuple, Type
 
 import torch.nn as nn
@@ -0,0 +1,118 @@
+from typing import Any, Callable, Dict, List, Tuple, Type
+
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+
+import colossalai.shardformer.layer.layers as col_nn
+
+from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
+
+
+class GPT2Policy(Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        return {
+            GPT2Model:
+                Argument(attr_dict={}, param_funcs=[
+                    GPT2Policy.embedding,
+                ]),
+            GPT2Block:
+                Argument(
+                    attr_dict={
+                        # 1. reduce hidden size
+                        "attn.embed_dim": config.hidden_size // world_size,
+                        "attn.split_size": config.hidden_size // world_size,
+                        "crossattention.embed_dim": config.hidden_size // world_size,
+                        "crossattention.split_size": config.hidden_size // world_size,
+                        # 2. reduce number of heads
+                        "attn.num_heads": config.num_attention_heads // world_size,
+                        "crossattention.num_heads": config.num_attention_heads // world_size,
+                    },
+                    param_funcs=[
+                        GPT2Policy.attn_in,
+                        GPT2Policy.attn_out,
+                        GPT2Policy.mlp_in,
+                        GPT2Policy.mlp_out,
+                    ]),
+        }
+
+    @staticmethod
+    def attn_in() -> List:
+        return [
+            Col_Layer(weight="attn.c_attn.weight",
+                      bias="attn.c_attn.bias",
+                      n_cast=3,
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(weight="crossattention.c_attn.weight",
+                      bias="crossattention.c_attn.bias",
+                      n_cast=2,
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Col),
+            Col_Layer(weight="crossattention.q_attn.weight",
+                      bias="crossattention.q_attn.bias",
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Col)
+        ]
+
+    @staticmethod
+    def attn_out() -> List:
+        return [
+            Row_Layer(weight="attn.c_proj.weight",
+                      bias="attn.c_proj.bias",
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Row),
+            Row_Layer(weight="crossattention.c_proj.weight",
+                      bias="crossattention.c_proj.bias",
+                      reversed=True,
+                      ignore=True,
+                      replace_layer=col_nn.Linear1D_Row)
+        ]
+
+    @staticmethod
+    def mlp_in() -> List:
+        return [
+            Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
+        ]
+
+    @staticmethod
+    def mlp_out() -> List:
+        return [
+            Row_Layer(weight="mlp.c_proj.weight",
+                      bias="mlp.c_proj.bias",
+                      reversed=True,
+                      replace_layer=col_nn.Linear1D_Row)
+        ]
+
+    @staticmethod
+    def embedding() -> List:
+        return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
+
+
+from transformers import GPT2LMHeadModel
+
+
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = GPT2Policy.argument_policy(config, world_size)
+        argument = {
+            GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
+                GPT2LMHeadModelPolicy.unembedding,
+            ]),
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def unembedding() -> List:
+        return [
+            Col_Layer(weight="lm_head.weight",
+                      bias="lm_head.bias",
+                      replace_layer=col_nn.Linear1D_Col,
+                      gather_output=True)
+        ]
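To make the GPT2Block attr_dict concrete: with the stock "gpt2" configuration (hidden_size 768, 12 attention heads) and a tensor-parallel world size of 4, each block would be patched so its attention module works on the local shard. Roughly (example numbers, same arithmetic as the policy above):

    hidden_size, num_attention_heads, world_size = 768, 12, 4    # stock "gpt2" config, example world size
    assert hidden_size // world_size == 192            # new attn.embed_dim / attn.split_size per rank
    assert num_attention_heads // world_size == 3      # new attn.num_heads per rank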
@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
 
 import torch
 import torch.nn as nn
+from transformers.pytorch_utils import Conv1D
 
 from ..policies.autopolicy import get_autopolicy
 from ..policies.basepolicy import Policy

@@ -35,10 +36,22 @@ class ModelSharder(object):
         self.model_config = self.model.config
 
     def shard(self) -> None:
+        self.reshape_embedding()
         self.inject_model(self.model)
         self.replace_layer(self.model)
         self.bind_layer(self.model)
 
+    def reshape_embedding(self,) -> None:
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        vocab_size = self.model_config.vocab_size
+        world_size = self.shard_config.world_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+            self.model_config = self.model.config
+
     def inject_model(
         self,
         model: nn.Module,
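reshape_embedding pads the vocabulary up to the next multiple of world_size so the embedding can be split evenly across ranks. For example, with GPT-2's 50257-token vocabulary and 4 ranks (example numbers, same arithmetic as the hunk above):

    vocab_size, world_size = 50257, 4
    new_vocab_size = vocab_size + world_size - vocab_size % world_size   # only applied when not divisible
    assert new_vocab_size == 50260                     # 50257 % 4 == 1, so three padding rows are added
    assert new_vocab_size % world_size == 0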
@@ -53,6 +66,8 @@ class ModelSharder(object):
        """
        inject_policy = self.policy.inject_policy()
 
+        if inject_policy is None:
+            return
        org_model_cls = inject_policy[0]
        shard_model_cls = inject_policy[1]
 

@@ -82,9 +97,9 @@ class ModelSharder(object):
            origin_layer_cls = argument_policy[0]
            attr_dict = argument_policy[1].attr_dict
            param_funcs = argument_policy[1].param_funcs
-            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
 
-    def reverse_replace_layer(
+    def traverse_replace_layer(
        self,
        layer: nn.Module,
        origin_cls: nn.Module,

@@ -100,17 +115,12 @@ class ModelSharder(object):
            attr_dict (Dict): The attribute dict to modify
            policy_cls (:class:`Policy`): The policy class
        """
+        if layer.__class__ == origin_cls:
+            for k, v in attr_dict.items():
+                setattr_(layer, k, v, ignore=True)
+            self.shard_one_layer(layer, param_funcs)
        for name, child in layer.named_children():
-            if child.__class__ == origin_cls:
-                # replac_layer = child
-                for k, v in attr_dict.items():
-                    setattr_(child, k, v, ignore=True)
-                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
-                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
-                self.shard_one_layer(child, param_funcs)
-                continue
-
-            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
        return layer
 
     def shard_one_layer(
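The rename from reverse_replace_layer to traverse_replace_layer also changes the walk: a module whose class matches origin_cls is now patched and sharded at the top of the call, and recursion then continues into every child instead of skipping a matched child's subtree. A stripped-down sketch of the new control flow (simplified; the real method also threads attr_dict and param_funcs through):

    def traverse(layer, origin_cls):
        if layer.__class__ == origin_cls:
            pass                              # patch attributes, then shard this module's parameters
        for child in layer.children():
            traverse(child, origin_cls)       # visit the whole subtree, matched or not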
@@ -126,7 +136,6 @@ class ModelSharder(object):
            param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
 
        """
-        # print(org_layer)
        for func in param_funcs:
            policy_layers = func()
            for policy_layer in policy_layers:

@@ -136,9 +145,10 @@ class ModelSharder(object):
                bias_attr = policy_layer.bias
                replace_layer_cls = policy_layer.replace_layer
                ignore = policy_layer.ignore
+                n_cast = policy_layer.n_cast
+                reversed = policy_layer.reversed
                if policy_layer.__class__.__name__ == "Col_Layer":
                    gather_output = policy_layer.gather_output
-                    # print(gather_output)
 
                if weight_attr is not None:
                    if hasattr_(org_layer, weight_attr):

@@ -161,13 +171,11 @@ class ModelSharder(object):
                layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
 
                # slice weight and bias
-                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
-                # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
+                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
 
                # create new object to replace the origin layer
                if replace_layer_cls is not None:
-                    # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
-                    if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
+                    if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
                        if replace_layer_cls.__name__ == "Linear1D_Row":
                            replace_layer = replace_layer_cls(weight.shape[1],
                                                              weight.shape[0],

@@ -235,6 +243,8 @@ class ModelSharder(object):
            model (:class:`torch.nn.Module`): The shard model
        """
        binding_map = self.policy.binding_policy()
+        if binding_map is None:
+            return
        for k, v in binding_map.items():
            param = getattr_(model, k)
            param = nn.Parameter(param)
@@ -19,6 +19,8 @@ class Slicer():
        weight: torch.Tensor,
        bias: torch.Tensor,
        policy_layer_cls: Layer,
+        n_cast: int = None,
+        reversed: bool = False,
    ):
        r"""
        Slice the weight and bias according to policy layer cls

@@ -33,13 +35,18 @@ class Slicer():
        """
        if policy_layer_cls == Layer:
            return weight, bias
-        elif policy_layer_cls == Col_Layer:
-            weight = self.slice_tensor(weight, 1, False)
+
+        dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
+        # print(weight.shape, dim)
+        if policy_layer_cls == Col_Layer:
+            weight = self.slice_tensor(weight, dim, False, n_cast)
            bias = self.slice_tensor(bias, 0, True)
        elif policy_layer_cls == Row_Layer:
-            weight = self.slice_tensor(weight, 0, False)
+            weight = self.slice_tensor(weight, dim, False, n_cast)
        else:
            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
+        if reversed:
+            weight = weight.transpose(0, 1).contiguous()
        return weight, bias
 
    def slice_tensor(

@@ -47,6 +54,7 @@ class Slicer():
        tensor_in: torch.Tensor,
        dim: int,
        is_bias: bool,
+        n_cast: int = None,
    ) -> torch.Tensor:
        r"""
        Slice tensor according to the config

@@ -59,14 +67,15 @@ class Slicer():
        if tensor_in is None:
            return None
        if not is_bias:
-            return self.slice_2d(tensor_in, dim)
+            return self.slice_2d(tensor_in, dim, n_cast)
        else:
-            return self.slice_1d(tensor_in)
+            return self.slice_1d(tensor_in, n_cast)
 
    def slice_2d(
        self,
        tensor: torch.Tensor,
        dim: int,
+        n_cast: int = None,
    ) -> torch.Tensor:
        r"""
        Slice the 2D tensor

@@ -77,13 +86,14 @@ class Slicer():
        """
        assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
        if dim == 0:
-            return self.slice_row(tensor)
+            return self.slice_row(tensor, n_cast)
        elif dim == 1:
-            return self.slice_col(tensor)
+            return self.slice_col(tensor, n_cast)
 
    def slice_1d(
        self,
        tensor: torch.Tensor,
+        n_cast: int = None,
    ) -> torch.Tensor:
        r"""
        Slice the 1D tensor

@@ -94,11 +104,19 @@ class Slicer():
        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()
 
    def slice_col(
        self,
        tensor: torch.Tensor,
+        n_cast: int = None,
    ) -> torch.Tensor:
        r"""
        Slice the tensor in column

@@ -110,11 +128,19 @@ class Slicer():
            :class:`torch.Tensor`: The sliced tensor
 
        """
-        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=0).contiguous()
 
    def slice_row(
        self,
        tensor: torch.Tensor,
+        n_cast: int = None,
    ) -> torch.Tensor:
        r"""
        Slice the tensor in column

@@ -125,4 +151,11 @@ class Slicer():
        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
-        return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        if n_cast is None:
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        else:
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            chunk_list = [
+                tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
+            ]
+            return torch.cat(chunk_list, dim=1).contiguous()
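Putting the slicer pieces together for the fused attn.c_attn weight (Col_Layer with reversed=True and n_cast=3), the intended outcome is: the tensor arrives in Conv1D's [in, out] layout, is chunked along the output dimension into world_size * n_cast pieces, every world_size-th piece is kept so each rank gets a slice of Q, K and V, and the result is transposed back to the [out, in] layout the Linear1D_Col replacement expects. A standalone sketch of that path with plain tensors (dim_mapping and the Layer classes are internal to the Slicer, so the dimension is hard-coded here and the values are illustrative):

    import torch

    world_size, rank, n_cast = 2, 0, 3
    hidden = 4
    weight = torch.randn(hidden, 3 * hidden)           # Conv1D layout: [in, out] with Q|K|V fused on dim 1

    dim = 1                                            # reversed Col_Layer: slice the output dimension
    chunks = weight.chunk(world_size * n_cast, dim=dim)
    local = torch.cat(list(chunks[rank::world_size]), dim=dim)   # interleaved pick: part of Q, K and V
    local = local.transpose(0, 1).contiguous()         # reversed=True: back to [out, in]
    assert local.shape == (3 * hidden // world_size, hidden)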
@@ -6,24 +6,28 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
 
 import colossalai
 from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
 def get_args():
     parser = colossalai.get_default_parser()
     parser.add_argument("--mode", type=str, default='inference')
     parser.add_argument("--save_model", action='store_true')
+    parser.add_argument("--model", type=str, default='bert-base-uncased')
     return parser.parse_args()
 
 
-def load_data():
+def load_data(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        # tokenizer.pad_token_id = 0
     datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
     # datasets=load_dataset("yelp_review_full")
     tokenized_datasets = datasets.map(

@@ -42,18 +46,23 @@ def load_data():
 
 
 def inference(model: nn.Module, args):
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    print(model)
+    # print(model.wte.weight.shape)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokenizer.pad_token_id = 0
     token = "Hello, my dog is cute"
     inputs = tokenizer(token, return_tensors="pt")
     inputs.to("cuda")
     model.eval()
     model.to("cuda")
     outputs = model(**inputs)
-    print(outputs)
+    print(outputs[0])
 
 
 def train(model: nn.Module, args, num_epoch: int = 3):
-    train_dataloader, eval_dataloader = load_data()
+    train_dataloader, eval_dataloader = load_data(args)
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
     num_training = num_epoch * len(train_dataloader)
     progress_bar = tqdm(range(num_training))

@@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):
 
 if __name__ == "__main__":
     args = get_args()
-    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     colossalai.launch_from_torch(config=args.config)
+    if args.model == 'bert-base-uncased':
+        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+    elif args.model == 'gpt2':
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+    else:
+        raise AttributeError("model not supported")
     shard_config = ShardConfig(
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),