[shardformer] adapted llama to the new API (#4036)

pull/4157/head
Frank Lee 1 year ago
parent 74d176c8d8
commit c1d5453e9f

@ -1,64 +1,76 @@
import importlib
from dataclasses import dataclass
import torch.nn as nn
from .basepolicy import Policy
def build_policies():
r"""
Build the policies for the model
Return:
The dict for the policies
@dataclass
class PolicyLocation:
"""
auto_policy_dict = {}
from transformers import BertModel
from .bert import BertModelPolicy
auto_policy_dict[BertModel] = BertModelPolicy
from transformers import BertForPreTraining
from .bert import BertForPretrainingPolicy
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
from transformers import BertLMHeadModel
from .bert import BertLMHeadModelPolicy
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers import BertForNextSentencePrediction
PolicyLocation describes the location of a policy class.
from .bert import BertForNextSentencePredictionPolicy
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers.models.llama.modeling_llama import LlamaModel
Args:
file_name (str): The file name of the policy under colossalai.shardformer.policies
class_name (str): The class name of the policy class
"""
file_name: str
class_name: str
# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining":
PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertLMHeadModel":
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForSequenceClassification":
PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel":
PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
# T5
# GPT2
}
def import_policy(policy_location: PolicyLocation) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
# from .llama import LlamaPolicy
# auto_policy_dict[LlamaModel] = LlamaPolicy
# from transformers import LlamaForSequenceClassification
# from .llama import LlamaForSequenceClassificationPolicy
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
# from transformers import LlamaForCausalLM
# from .llama import LlamaForCausalLMPolicy
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
# from transformers import GPT2Model
# from .gpt2 import GPT2Policy
# auto_policy_dict[GPT2Model] = GPT2Policy
# from transformers import GPT2LMHeadModel
# from .gpt2 import GPT2LMHeadModelPolicy
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict
def _fullname(obj):
"""
Return the full name of an object, including the module name.
"""
klass = obj.__class__
module = klass.__module__
if module == 'builtins':
return klass.__qualname__ # avoid outputs like 'builtins.str'
return module + '.' + klass.__qualname__
def get_autopolicy(model: nn.Module) -> Policy:
@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy:
Return:
:class:`Policy`: The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
return policy()
return policy()
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)

@ -75,6 +75,7 @@ class Policy(ABC):
"""
def __init__(self) -> None:
self.shard_config = None
self.model = None
self.shard_config = None
@ -101,6 +102,7 @@ class Policy(ABC):
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""
pass
@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@ -135,6 +137,7 @@ class Policy(ABC):
...
}
"""
pass
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
@ -149,6 +152,7 @@ class Policy(ABC):
return BertModel_
```
"""
pass
@abstractmethod
def postprocess(self) -> nn.Module:
@ -156,3 +160,4 @@ class Policy(ABC):
Perform some postprocessing of the model, like binding the weight of embedding layer with
the classifier layer
"""
pass

@ -1,122 +1,121 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Dict, Union
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class LlamaPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
LlamaDecoderLayer:
Argument(attr_dict={
"self_attn.hidden_size": config.hidden_size // world_size,
"self_attn.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
)
],
),
LlamaModel:
Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
}
@staticmethod
def attn_layer() -> List:
return [
Col_Layer(
suffix="self_attn.q_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.k_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.v_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="self_attn.o_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
)
]
@staticmethod
def mlp_layer() -> List:
return [
Col_Layer(
suffix="mlp.gate_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
Col_Layer(
suffix="mlp.up_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
gather_output=True,
),
Col_Layer(
suffix="mlp.down_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
]
@staticmethod
def embeddings() -> List:
return [Col_Layer(
suffix="embed_tokens",
weight="weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
from transformers import LlamaForCausalLM
class LlamaForCausalLMPolicy(LlamaPolicy):
def new_model_class(self):
return None
@staticmethod
def argument(config, world_size):
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
argument.update(llamapolicy)
def postprocess(self):
return self.model
@staticmethod
def lm_head() -> List:
return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
class LlamaForCausalLMPolicy(LlamaPolicy):
from transformers import LlamaForSequenceClassification
def module_policy(self):
policy = super().module_policy()
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@staticmethod
def argument(config, world_size):
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
argument = {
def module_policy(self):
policy = super().module_policy()
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
argument.update(llamapolicy)
@staticmethod
def score() -> List:
return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
policy.update(new_item)
return policy

@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import List, Literal
from colossalai.cluster.dist_coordinator import DistCoordinator
__all__ = ['ShardConfig']
@ -19,9 +20,19 @@ class ShardConfig:
gather_output (bool): Whether to gather the output of the model of the last layer
"""
tensor_parallel_size: int
# TODO: add support for tensor parallel
# pipeline_parallel_size: int
# data_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True
gather_output: bool = True
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
# inference_only: bool = True
# gather_output: bool = True
def __post_init__(self):
coordinator = DistCoordinator()
# ensure the parallel size can match the world size
world_size = coordinator.world_size
self.data_parallel_size = world_size // self.tensor_parallel_size
assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"

@ -1,8 +1,6 @@
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager
@ -41,10 +39,10 @@ class ModelSharder(object):
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self.preprocess()
self.replace_model_class()
self.replace_module()
self.postprocess()
self._preprocess()
self._replace_model_class()
self._replace_module()
self._postprocess()
def reshape_embedding(self) -> None:
r"""
@ -57,13 +55,13 @@ class ModelSharder(object):
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
def preprocess(self) -> None:
def _preprocess(self) -> None:
self.model = self.policy.preprocess()
def postprocess(self) -> None:
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
def replace_model_class(self) -> None:
def _replace_model_class(self,) -> None:
r"""
Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model
@ -84,7 +82,7 @@ class ModelSharder(object):
getattr(new_model_class, key),
)
def replace_module(self) -> None:
def _replace_module(self,) -> None:
r"""
Replace the module according to the policy, and replace the module one by one

@ -47,10 +47,12 @@ class ShardFormer:
"""
Initialize the distributed process group according to the
"""
# create process group manager and 1d process group
# TODO: may need to support other parallel mode when the config has such as field
pg_manager = ProcessGroupManager()
if (self.shard_config.tensor_parallel_mode == '1d'):
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager
return pg_manager
def shard_model(self, model: nn.Module, policy: Policy = None):

@ -24,21 +24,18 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def build_model(rank, world_size, model):
config = BertConfig.from_pretrained('bert-base-uncased')
def build_model(world_size, model_fn):
config = BertConfig()
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model = model_fn(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(
tensor_parallel_size=2,
tensor_parallel_mode='1d',
)
shard_config = ShardConfig(tensor_parallel_size=world_size,)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
@ -99,15 +96,22 @@ def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
forward_list = [
BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
BertForSequenceClassification
BertForMaskedLM,
BertForPreTraining,
BertLMHeadModel,
# TODO: do not work yet
# BertModel,
# BertForSequenceClassification
# BertForNextSentencePrediction,
]
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
for model in forward_list:
org_model, sharded_model = build_model(rank, world_size, model)
for model_fn in forward_list:
org_model, sharded_model = build_model(model_fn)
check_forward(org_model, sharded_model)
if model in backward_lsit:
if model_fn in backward_lsit:
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()

@ -4,31 +4,28 @@ import random
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
def build_model(rank, world_size):
cfg = LlamaConfig(num_hidden_layers=16)
org_model = LlamaForCausalLM(cfg)
def build_model(world_size, model_fn):
# create new model
config = LlamaConfig(num_hidden_layers=8)
org_model = model_fn(config).cuda()
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
org_model = org_model.to('cuda')
org_model_forshard = copy.deepcopy(org_model)
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)
return org_model, sharded_model
@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model):
inputs = tokenizer(input, return_tensors='pt').to('cuda')
del inputs["token_type_ids"]
del inputs["attention_mask"]
#orgin model
org_model.eval()
org_out = org_model(**inputs)
@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model):
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model_list = [
LlamaForCausalLM,
# TODO: do not work yet
# LlamaModel,
# LlamaForSequenceClassification
]
org_model, sharded_model = build_model(rank, world_size)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
for model_fn in model_list:
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()

@ -8,7 +8,7 @@ from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer.shard import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -90,6 +90,7 @@ def check_t5(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
def test_t5():
spawn(check_t5, 2)

Loading…
Cancel
Save