mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)
* add gpt2 policy and modify shard and slicer to support * remove unused code * polish codepull/3943/head
parent
6370a935f6
commit
ef1537759c
|
@ -10,16 +10,26 @@ def build_policies():
|
|||
"""
|
||||
auto_policy_dict = {}
|
||||
|
||||
from transformers.models.bert.modeling_bert import BertForMaskedLM
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
from .bert import BertForMaskedLMPolicy
|
||||
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
|
||||
|
||||
from transformers.models.bert.modeling_bert import BertForSequenceClassification
|
||||
from transformers import BertForSequenceClassification
|
||||
|
||||
from .bert import BertForSequenceClassificationPolicy
|
||||
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
|
||||
|
||||
from transformers import GPT2Model
|
||||
|
||||
from .gpt2 import GPT2Policy
|
||||
auto_policy_dict[GPT2Model] = GPT2Policy
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
from .gpt2 import GPT2LMHeadModelPolicy
|
||||
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
|
||||
|
||||
return auto_policy_dict
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
# part of code modified from https://github.com/tunib-ai/parallelformers
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -31,11 +29,18 @@ class Layer:
|
|||
bias (str): The bias suffix of the layer
|
||||
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
|
||||
ignore (bool): Whether to ignore this layer if it is not in the model
|
||||
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
|
||||
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
|
||||
n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
|
||||
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
|
||||
each device should have a part of Q, K and V weight.
|
||||
"""
|
||||
weight: str = None
|
||||
bias: str = None
|
||||
replace_layer: Any = None
|
||||
ignore: bool = False
|
||||
reversed: bool = False
|
||||
n_cast: int = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -131,7 +136,7 @@ class Policy():
|
|||
(OrignModel, CustomModel)
|
||||
in `CustomModel`, we can overwrite the forward and backward process
|
||||
"""
|
||||
return ()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def binding_policy() -> Dict:
|
||||
|
@ -146,7 +151,7 @@ class Policy():
|
|||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
"""
|
||||
return NotImplementedError
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> List:
|
||||
|
@ -209,4 +214,4 @@ class Policy():
|
|||
Return:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return NotImplementedError
|
||||
return None
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
|
||||
|
||||
|
||||
class GPT2Policy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return {
|
||||
GPT2Model:
|
||||
Argument(attr_dict={}, param_funcs=[
|
||||
GPT2Policy.embedding,
|
||||
]),
|
||||
GPT2Block:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. reduce hidden size
|
||||
"attn.embed_dim": config.hidden_size // world_size,
|
||||
"attn.split_size": config.hidden_size // world_size,
|
||||
"crossattention.embed_dim": config.hidden_size // world_size,
|
||||
"crossattention.split_size": config.hidden_size // world_size,
|
||||
# 2. reduce number of heads
|
||||
"attn.num_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.num_heads": config.num_attention_heads // world_size,
|
||||
},
|
||||
param_funcs=[
|
||||
GPT2Policy.attn_in,
|
||||
GPT2Policy.attn_out,
|
||||
GPT2Policy.mlp_in,
|
||||
GPT2Policy.mlp_out,
|
||||
]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> List:
|
||||
return [
|
||||
Col_Layer(weight="attn.c_attn.weight",
|
||||
bias="attn.c_attn.bias",
|
||||
n_cast=3,
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Col),
|
||||
Col_Layer(weight="crossattention.c_attn.weight",
|
||||
bias="crossattention.c_attn.bias",
|
||||
n_cast=2,
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Col),
|
||||
Col_Layer(weight="crossattention.q_attn.weight",
|
||||
bias="crossattention.q_attn.bias",
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Col)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> List:
|
||||
return [
|
||||
Row_Layer(weight="attn.c_proj.weight",
|
||||
bias="attn.c_proj.bias",
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Row),
|
||||
Row_Layer(weight="crossattention.c_proj.weight",
|
||||
bias="crossattention.c_proj.bias",
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Row)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> List:
|
||||
return [
|
||||
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> List:
|
||||
return [
|
||||
Row_Layer(weight="mlp.c_proj.weight",
|
||||
bias="mlp.c_proj.bias",
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Row)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> List:
|
||||
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
|
||||
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
class GPT2LMHeadModelPolicy(GPT2Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = GPT2Policy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
|
||||
GPT2LMHeadModelPolicy.unembedding,
|
||||
]),
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> List:
|
||||
return [
|
||||
Col_Layer(weight="lm_head.weight",
|
||||
bias="lm_head.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True)
|
||||
]
|
|
@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Policy
|
||||
|
@ -35,10 +36,22 @@ class ModelSharder(object):
|
|||
self.model_config = self.model.config
|
||||
|
||||
def shard(self) -> None:
|
||||
self.reshape_embedding()
|
||||
self.inject_model(self.model)
|
||||
self.replace_layer(self.model)
|
||||
self.bind_layer(self.model)
|
||||
|
||||
def reshape_embedding(self,) -> None:
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
vocab_size = self.model_config.vocab_size
|
||||
world_size = self.shard_config.world_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.model_config = self.model.config
|
||||
|
||||
def inject_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
@ -53,6 +66,8 @@ class ModelSharder(object):
|
|||
"""
|
||||
inject_policy = self.policy.inject_policy()
|
||||
|
||||
if inject_policy is None:
|
||||
return
|
||||
org_model_cls = inject_policy[0]
|
||||
shard_model_cls = inject_policy[1]
|
||||
|
||||
|
@ -82,9 +97,9 @@ class ModelSharder(object):
|
|||
origin_layer_cls = argument_policy[0]
|
||||
attr_dict = argument_policy[1].attr_dict
|
||||
param_funcs = argument_policy[1].param_funcs
|
||||
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
|
||||
def reverse_replace_layer(
|
||||
def traverse_replace_layer(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
origin_cls: nn.Module,
|
||||
|
@ -100,17 +115,12 @@ class ModelSharder(object):
|
|||
attr_dict (Dict): The attribute dict to modify
|
||||
policy_cls (:class:`Policy`): The policy class
|
||||
"""
|
||||
if layer.__class__ == origin_cls:
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(layer, k, v, ignore=True)
|
||||
self.shard_one_layer(layer, param_funcs)
|
||||
for name, child in layer.named_children():
|
||||
if child.__class__ == origin_cls:
|
||||
# replac_layer = child
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(child, k, v, ignore=True)
|
||||
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
|
||||
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
|
||||
self.shard_one_layer(child, param_funcs)
|
||||
continue
|
||||
|
||||
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
return layer
|
||||
|
||||
def shard_one_layer(
|
||||
|
@ -126,7 +136,6 @@ class ModelSharder(object):
|
|||
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
|
||||
|
||||
"""
|
||||
# print(org_layer)
|
||||
for func in param_funcs:
|
||||
policy_layers = func()
|
||||
for policy_layer in policy_layers:
|
||||
|
@ -136,9 +145,10 @@ class ModelSharder(object):
|
|||
bias_attr = policy_layer.bias
|
||||
replace_layer_cls = policy_layer.replace_layer
|
||||
ignore = policy_layer.ignore
|
||||
n_cast = policy_layer.n_cast
|
||||
reversed = policy_layer.reversed
|
||||
if policy_layer.__class__.__name__ == "Col_Layer":
|
||||
gather_output = policy_layer.gather_output
|
||||
# print(gather_output)
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
|
@ -161,13 +171,11 @@ class ModelSharder(object):
|
|||
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
|
||||
# print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
|
||||
# create new object to replace the origin layer
|
||||
if replace_layer_cls is not None:
|
||||
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
|
||||
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
|
||||
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
|
@ -235,6 +243,8 @@ class ModelSharder(object):
|
|||
model (:class:`torch.nn.Module`): The shard model
|
||||
"""
|
||||
binding_map = self.policy.binding_policy()
|
||||
if binding_map is None:
|
||||
return
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(model, k)
|
||||
param = nn.Parameter(param)
|
||||
|
|
|
@ -19,6 +19,8 @@ class Slicer():
|
|||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
policy_layer_cls: Layer,
|
||||
n_cast: int = None,
|
||||
reversed: bool = False,
|
||||
):
|
||||
r"""
|
||||
Slice the weight and bias according to policy layer cls
|
||||
|
@ -33,13 +35,18 @@ class Slicer():
|
|||
"""
|
||||
if policy_layer_cls == Layer:
|
||||
return weight, bias
|
||||
elif policy_layer_cls == Col_Layer:
|
||||
weight = self.slice_tensor(weight, 1, False)
|
||||
|
||||
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
|
||||
# print(weight.shape, dim)
|
||||
if policy_layer_cls == Col_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
bias = self.slice_tensor(bias, 0, True)
|
||||
elif policy_layer_cls == Row_Layer:
|
||||
weight = self.slice_tensor(weight, 0, False)
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
else:
|
||||
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
|
||||
if reversed:
|
||||
weight = weight.transpose(0, 1).contiguous()
|
||||
return weight, bias
|
||||
|
||||
def slice_tensor(
|
||||
|
@ -47,6 +54,7 @@ class Slicer():
|
|||
tensor_in: torch.Tensor,
|
||||
dim: int,
|
||||
is_bias: bool,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice tensor according to the config
|
||||
|
@ -59,14 +67,15 @@ class Slicer():
|
|||
if tensor_in is None:
|
||||
return None
|
||||
if not is_bias:
|
||||
return self.slice_2d(tensor_in, dim)
|
||||
return self.slice_2d(tensor_in, dim, n_cast)
|
||||
else:
|
||||
return self.slice_1d(tensor_in)
|
||||
return self.slice_1d(tensor_in, n_cast)
|
||||
|
||||
def slice_2d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
dim: int,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 2D tensor
|
||||
|
@ -77,13 +86,14 @@ class Slicer():
|
|||
"""
|
||||
assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor"
|
||||
if dim == 0:
|
||||
return self.slice_row(tensor)
|
||||
return self.slice_row(tensor, n_cast)
|
||||
elif dim == 1:
|
||||
return self.slice_col(tensor)
|
||||
return self.slice_col(tensor, n_cast)
|
||||
|
||||
def slice_1d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 1D tensor
|
||||
|
@ -94,11 +104,19 @@ class Slicer():
|
|||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
||||
|
||||
def slice_col(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in column
|
||||
|
@ -110,11 +128,19 @@ class Slicer():
|
|||
:class:`torch.Tensor`: The sliced tensor
|
||||
|
||||
"""
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
||||
|
||||
def slice_row(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in column
|
||||
|
@ -125,4 +151,11 @@ class Slicer():
|
|||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=1).contiguous()
|
||||
|
|
|
@ -6,24 +6,28 @@ import torch.nn as nn
|
|||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
|
||||
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.shard import ShardConfig, shard_model
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument("--mode", type=str, default='inference')
|
||||
parser.add_argument("--save_model", action='store_true')
|
||||
parser.add_argument("--model", type=str, default='bert-base-uncased')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_data():
|
||||
def load_data(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
# tokenizer.pad_token_id = 0
|
||||
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
|
||||
# datasets=load_dataset("yelp_review_full")
|
||||
tokenized_datasets = datasets.map(
|
||||
|
@ -42,18 +46,23 @@ def load_data():
|
|||
|
||||
|
||||
def inference(model: nn.Module, args):
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
print(model)
|
||||
# print(model.wte.weight.shape)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
tokenizer.pad_token_id = 0
|
||||
token = "Hello, my dog is cute"
|
||||
inputs = tokenizer(token, return_tensors="pt")
|
||||
inputs.to("cuda")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
outputs = model(**inputs)
|
||||
print(outputs)
|
||||
print(outputs[0])
|
||||
|
||||
|
||||
def train(model: nn.Module, args, num_epoch: int = 3):
|
||||
train_dataloader, eval_dataloader = load_data()
|
||||
train_dataloader, eval_dataloader = load_data(args)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
num_training = num_epoch * len(train_dataloader)
|
||||
progress_bar = tqdm(range(num_training))
|
||||
|
@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
if args.model == 'bert-base-uncased':
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
elif args.model == 'gpt2':
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
else:
|
||||
raise AttributeError("model not supported")
|
||||
shard_config = ShardConfig(
|
||||
rank=int(str(get_current_device()).split(':')[-1]),
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
|
|
Loading…
Reference in New Issue