integrate with dist layer (#4011)

pull/4157/head
FoolPlayer 2023-06-16 11:23:30 +08:00 committed by Frank Lee
parent 015af592f8
commit dfca9678fa
3 changed files with 42 additions and 24 deletions

View File

@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ParallelModule():
def __init__(self):
pass
class BertPolicy(Policy): class BertPolicy(Policy):
def preprocess(self, shard_config: ShardConfig = None): def preprocess(self, shard_config: ShardConfig = None):
@ -49,7 +43,27 @@ class BertPolicy(Policy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.query", suffix="attention.self.query",
target_module=ParallelModule, target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
), ),
]) ])
} }

View File

@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager from colossalai.cluster.process_group_manager import ProcessGroupManager
from ..policies.autopolicy import get_autopolicy from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from ..utils.utils import setattr_ from ..utils.utils import getattr_, setattr_
from .shard_config import ShardConfig from .shard_config import ShardConfig
__all__ = ['ModelSharder', 'shard_model'] __all__ = ['ModelSharder', 'shard_model']
@ -90,9 +90,7 @@ class ModelSharder(object):
Args: Args:
model (:class:`torch.nn.Module`): The model to shard model (:class:`torch.nn.Module`): The model to shard
""" """
print(self.policy)
module_descriptions = self.policy.module_policy(self.shard_config) module_descriptions = self.policy.module_policy(self.shard_config)
print(f"*******{module_descriptions}")
for module_description in module_descriptions.items(): for module_description in module_descriptions.items():
origin_layer_cls = module_description[0] origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement attr_replacement = module_description[1].attribute_replacement
@ -160,7 +158,7 @@ class ModelSharder(object):
def _replace_sub_module( def _replace_sub_module(
self, self,
org_layer: nn.Module, org_layer: nn.Module,
sub_module_replacement: List[Callable], sub_module_replacement: List[SubModuleReplacementDescription],
) -> None: ) -> None:
r""" r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@ -177,7 +175,8 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None' assert target_module is not None, 'target_module should not be None'
# TODO: integrate with new layer # TODO: support different parallel mode
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) native_sub_module = getattr_(org_layer, suffix)
replace_layer = None replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
setattr_(org_layer, suffix, replace_layer) setattr_(org_layer, suffix, replace_layer)

View File

@ -17,7 +17,7 @@ from transformers import (
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
config.hidden_dropout_prob = 0 config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0 config.attention_probs_dropout_prob = 0
org_model = model(config=config) org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model_forshard = copy.deepcopy(org_model) org_model_forshard = copy.deepcopy(org_model)
org_model = org_model.to('cuda') org_model.to('cuda')
shardconfig = ShardConfig( # TODO: no need to transfer to cuda
rank=rank, org_model_forshard.to('cuda')
world_size=world_size, shard_config = ShardConfig(tensor_parallel_size=2,
gather_output=True, data_parallel_size=1,
) pipeline_parallel_size=1,
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model return org_model, sharded_model