Browse Source

integrate with dist layer (#4011)

pull/4157/head
FoolPlayer 1 year ago committed by Frank Lee
parent
commit
dfca9678fa
  1. 28
      colossalai/shardformer/policies/bert.py
  2. 15
      colossalai/shardformer/shard/sharder.py
  3. 23
      tests/test_shardformer/test_model/test_shard_bert.py

28
colossalai/shardformer/policies/bert.py

@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ParallelModule():
def __init__(self):
pass
class BertPolicy(Policy):
def preprocess(self, shard_config: ShardConfig = None):
@ -49,7 +43,27 @@ class BertPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=ParallelModule,
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
])
}

15
colossalai/shardformer/shard/sharder.py

@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
from ..utils.utils import setattr_
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from ..utils.utils import getattr_, setattr_
from .shard_config import ShardConfig
__all__ = ['ModelSharder', 'shard_model']
@ -90,9 +90,7 @@ class ModelSharder(object):
Args:
model (:class:`torch.nn.Module`): The model to shard
"""
print(self.policy)
module_descriptions = self.policy.module_policy(self.shard_config)
print(f"*******{module_descriptions}")
for module_description in module_descriptions.items():
origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement
@ -160,7 +158,7 @@ class ModelSharder(object):
def _replace_sub_module(
self,
org_layer: nn.Module,
sub_module_replacement: List[Callable],
sub_module_replacement: List[SubModuleReplacementDescription],
) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@ -177,7 +175,8 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None'
# TODO: integrate with new layer
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
replace_layer = None
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix)
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
setattr_(org_layer, suffix, replace_layer)

23
tests/test_shardformer/test_model/test_shard_bert.py

@ -17,7 +17,7 @@ from transformers import (
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
org_model = model(config=config)
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model = org_model.to('cuda')
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model

Loading…
Cancel
Save