mirror of https://github.com/hpcaitech/ColossalAI
integrate with dist layer (#4011)
parent 015af592f8
commit dfca9678fa
@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
-class ParallelModule():
-
-    def __init__(self):
-        pass
-
-
 class BertPolicy(Policy):
 
     def preprocess(self, shard_config: ShardConfig = None):
@@ -49,7 +43,27 @@ class BertPolicy(Policy):
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="attention.self.query",
-                            target_module=ParallelModule,
+                            target_module=col_nn.Linear1D_Col,
                         ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.key",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.self.value",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="attention.output.dense",
+                            target_module=col_nn.Linear1D_Row,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="intermediate.dense",
+                            target_module=col_nn.Linear1D_Col,
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="output.dense",
+                            target_module=col_nn.Linear1D_Row,
+                        ),
                     ])
         }
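The two hunks above drop the local ParallelModule placeholder and wire each BERT linear sub-layer to a real tensor-parallel layer: column-parallel (Linear1D_Col) for query/key/value and the intermediate projection, row-parallel (Linear1D_Row) for the two output projections. The single-process sketch below is plain PyTorch, not ColossalAI code; it only illustrates numerically why that Col/Row pairing works: the column shards can be activated independently and the row-parallel side needs just one reduction per block. All names in it are local to the sketch.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    world_size = 2
    hidden, intermediate = 8, 32

    dense_in = nn.Linear(hidden, intermediate)    # plays the role of "intermediate.dense" (Col)
    dense_out = nn.Linear(intermediate, hidden)   # plays the role of "output.dense" (Row)
    act = nn.GELU()
    x = torch.randn(4, hidden)

    # Column-parallel: each rank keeps a slice of the output features (rows of W).
    w_col = dense_in.weight.chunk(world_size, dim=0)
    b_col = dense_in.bias.chunk(world_size, dim=0)
    # Row-parallel: each rank keeps the matching slice of the input features (columns of W).
    w_row = dense_out.weight.chunk(world_size, dim=1)

    partials = []
    for r in range(world_size):
        h_r = act(x @ w_col[r].t() + b_col[r])   # local Col output; activation is elementwise, so it stays local
        partials.append(h_r @ w_row[r].t())      # local Row partial sum

    # The single all-reduce the row-parallel layer would perform across ranks:
    y_parallel = sum(partials) + dense_out.bias
    y_reference = dense_out(act(dense_in(x)))
    torch.testing.assert_close(y_parallel, y_reference)

Splitting the first projection by output features and the second by input features is what lets the whole attention/MLP block run with one communication step instead of two, which is why the policy pairs Linear1D_Col with Linear1D_Row.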
@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
 from colossalai.cluster.process_group_manager import ProcessGroupManager
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy
-from ..utils.utils import setattr_
+from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig
 
 __all__ = ['ModelSharder', 'shard_model']
@@ -90,9 +90,7 @@ class ModelSharder(object):
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        print(self.policy)
         module_descriptions = self.policy.module_policy(self.shard_config)
-        print(f"*******{module_descriptions}")
         for module_description in module_descriptions.items():
            origin_layer_cls = module_description[0]
            attr_replacement = module_description[1].attribute_replacement
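For context on the loop above: module_policy() returns a dict keyed by the origin layer class, and the sharder reads attribute_replacement here and sub_module_replacement further down. The stub below is a hypothetical stand-in that only mirrors those two fields; it is not the real ModulePolicyDescription, and the example entry is illustrative.

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    from transformers.models.bert.modeling_bert import BertLayer


    @dataclass
    class PolicyDescriptionStub:
        # Hypothetical stand-in; real descriptions typically rescale attention
        # head counts and sizes by the tensor-parallel size via attribute_replacement.
        attribute_replacement: Dict[str, Any] = field(default_factory=dict)
        sub_module_replacement: List[Any] = field(default_factory=list)


    # {origin layer class -> description}; the sharder unpacks each item as above.
    module_descriptions = {BertLayer: PolicyDescriptionStub()}
    for origin_layer_cls, description in module_descriptions.items():
        attr_replacement = description.attribute_replacement
        sub_modules = description.sub_module_replacement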
@@ -160,7 +158,7 @@ class ModelSharder(object):
     def _replace_sub_module(
         self,
         org_layer: nn.Module,
-        sub_module_replacement: List[Callable],
+        sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@@ -177,7 +175,8 @@ class ModelSharder(object):
 
         assert target_module is not None, 'target_module should not be None'
 
-        # TODO: integrate with new layer
-        # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
-        replace_layer = None
+        # TODO: support different parallel mode
+        native_sub_module = getattr_(org_layer, suffix)
+        replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
 
         setattr_(org_layer, suffix, replace_layer)
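The new replacement path resolves a dotted suffix such as "attention.self.query" with getattr_, builds the parallel layer through target_module.from_native_module, and rebinds it with setattr_. The helpers below are a plausible re-implementation of getattr_/setattr_ for illustration only, not the code in ..utils.utils, followed by the swap flow in miniature with a plain nn.Linear standing in for the parallel layer.

    import functools

    import torch.nn as nn


    def getattr_(obj: nn.Module, suffix: str) -> nn.Module:
        # Resolve "a.b.c" style paths one attribute at a time.
        return functools.reduce(getattr, suffix.split('.'), obj)


    def setattr_(obj: nn.Module, suffix: str, new_module: nn.Module) -> None:
        # Walk to the parent of the final attribute, then rebind it in place.
        parent_path, _, name = suffix.rpartition('.')
        parent = getattr_(obj, parent_path) if parent_path else obj
        setattr(parent, name, new_module)


    # The replacement flow in miniature: look up the native sub-module,
    # build a replacement from its shape, and swap it into the parent.
    layer = nn.Sequential()
    layer.add_module('dense', nn.Linear(8, 8))
    native = getattr_(layer, 'dense')
    setattr_(layer, 'dense', nn.Linear(native.in_features, native.out_features, bias=False))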
@@ -17,7 +17,7 @@ from transformers import (
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0
 
-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)
 
-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
 
     return org_model, sharded_model
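Since build_model now returns both the original and the sharded model, a test would typically run the same batch through both and compare the outputs; with gather_output=True the sharded forward produces full-size logits, so an element-wise check is possible. The sketch below is illustrative only: the tokenizer, prompt, and tolerance are assumptions, and it presumes the distributed setup from ShardFormer.init_distributed() is already in place.

    import torch
    from transformers import BertTokenizer


    def check_forward(org_model, sharded_model):
        # Run one identical batch through both models and compare logits.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to('cuda')
        with torch.no_grad():
            org_logits = org_model(**inputs).logits
            shard_logits = sharded_model(**inputs).logits
        assert torch.allclose(org_logits, shard_logits, atol=1e-5), 'sharded forward drifted from the original model'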