mirror of https://github.com/hpcaitech/ColossalAI
integrate with dist layer (#4011)
parent 015af592f8
commit dfca9678fa

In the BERT policy:

@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


-class ParallelModule():
-
-    def __init__(self):
-        pass
-
-
 class BertPolicy(Policy):

     def preprocess(self, shard_config: ShardConfig = None):

@@ -49,7 +43,27 @@ class BertPolicy(Policy):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="attention.self.query",
-                        target_module=ParallelModule,
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=col_nn.Linear1D_Row,
                     ),
                 ])
         }
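
The replacements above follow the usual Megatron-style pairing: the query/key/value projections and intermediate.dense become column-parallel, while attention.output.dense and output.dense become row-parallel, so the intermediate activation can stay sharded between the two linears. Below is a minimal single-process sketch, in plain PyTorch rather than the colossalai Linear1D kernels, checking that this column-then-row split reproduces the unsharded result (world_size=2 here mirrors the tensor_parallel_size=2 used in the test further down):

# Illustrative only: a numerical check of the column/row split that the policy
# pairs up (Linear1D_Col feeding Linear1D_Row), done on one process with plain
# PyTorch tensors instead of distributed colossalai layers.
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, intermediate, world_size = 8, 16, 2

dense_in = nn.Linear(hidden, intermediate)    # e.g. intermediate.dense -> column-parallel
dense_out = nn.Linear(intermediate, hidden)   # e.g. output.dense       -> row-parallel
x = torch.randn(4, hidden)
expected = dense_out(dense_in(x))

# Column-parallel: split the first weight (and bias) along the output features;
# row-parallel: split the second weight along the input features.
w1_shards = dense_in.weight.chunk(world_size, dim=0)
b1_shards = dense_in.bias.chunk(world_size, dim=0)
w2_shards = dense_out.weight.chunk(world_size, dim=1)

# Each "rank" keeps its shards and computes a partial result; the row-parallel
# output is the sum over ranks (an all-reduce in the real distributed setting).
partials = [(x @ w1.t() + b1) @ w2.t()
            for w1, b1, w2 in zip(w1_shards, b1_shards, w2_shards)]
sharded = sum(partials) + dense_out.bias

assert torch.allclose(expected, sharded, atol=1e-5)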

In the model sharder:

@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
 from colossalai.cluster.process_group_manager import ProcessGroupManager

 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy
-from ..utils.utils import setattr_
+from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig

 __all__ = ['ModelSharder', 'shard_model']

@@ -90,9 +90,7 @@ class ModelSharder(object):
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        print(self.policy)
         module_descriptions = self.policy.module_policy(self.shard_config)
-        print(f"*******{module_descriptions}")
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
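
For context on the loop above, module_policy returns a mapping from an original layer class to a description of how to rewrite it. The dataclasses below are simplified stand-ins for ModulePolicyDescription and SubModuleReplacementDescription (not the colossalai definitions), just to make the shape of that mapping concrete; DummyBlock and the placeholder target are hypothetical:

# Illustrative only: the shape of the mapping that module_policy returns and
# that the sharder loop above walks over. Simplified stand-ins, not the
# colossalai definitions.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Type


@dataclass
class SubModuleReplacementDescription:
    suffix: str           # dotted path inside the matched layer, e.g. "attention.self.query"
    target_module: Any    # a class exposing a from_native_module(...) constructor


@dataclass
class ModulePolicyDescription:
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)


class DummyBlock:   # stands in for an original layer class such as a BERT encoder layer
    pass


# module_policy(shard_config) maps original layer class -> description, and the
# sharder iterates .items() over it (indexing the (class, description) tuple).
module_descriptions: Dict[Type, ModulePolicyDescription] = {
    DummyBlock: ModulePolicyDescription(sub_module_replacement=[
        SubModuleReplacementDescription(suffix="attention.self.query", target_module=object),
    ]),
}
for origin_layer_cls, description in module_descriptions.items():
    print(origin_layer_cls.__name__, description.attribute_replacement,
          len(description.sub_module_replacement))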

@@ -160,7 +158,7 @@ class ModelSharder(object):
     def _replace_sub_module(
         self,
         org_layer: nn.Module,
-        sub_module_replacement: List[Callable],
+        sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

@@ -177,7 +175,8 @@ class ModelSharder(object):

         assert target_module is not None, 'target_module should not be None'

-        # TODO: integrate with new layer
-        # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
-        replace_layer = None
+        # TODO: support different parallel mode
+        native_sub_module = getattr_(org_layer, suffix)
+        replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])

         setattr_(org_layer, suffix, replace_layer)
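
With the placeholder gone, _replace_sub_module now looks up the native submodule by its dotted suffix, converts it through the target module's from_native_module constructor, and writes it back with setattr_. A self-contained sketch of that mechanism in plain PyTorch follows; getattr_/setattr_ here are simplified stand-ins for the helpers imported from ..utils.utils, and ToyColumnLinear is a hypothetical target module, not colossalai's Linear1D_Col:

# Illustrative only: the suffix-based swap that _replace_sub_module performs.
import functools

import torch
import torch.nn as nn


def getattr_(module: nn.Module, suffix: str) -> nn.Module:
    # resolve a dotted suffix such as "output.dense"
    return functools.reduce(getattr, suffix.split('.'), module)


def setattr_(module: nn.Module, suffix: str, new_module: nn.Module) -> None:
    # replace the submodule addressed by a dotted suffix on its parent
    parent_path, _, name = suffix.rpartition('.')
    parent = getattr_(module, parent_path) if parent_path else module
    setattr(parent, name, new_module)


class ToyColumnLinear(nn.Linear):
    # keeps only this rank's slice of the output features

    @classmethod
    def from_native_module(cls, native: nn.Linear, rank: int, world_size: int) -> "ToyColumnLinear":
        out_per_rank = native.out_features // world_size
        shard = cls(native.in_features, out_per_rank, bias=native.bias is not None)
        rows = slice(rank * out_per_rank, (rank + 1) * out_per_rank)
        with torch.no_grad():
            shard.weight.copy_(native.weight[rows])
            if native.bias is not None:
                shard.bias.copy_(native.bias[rows])
        return shard


# a toy layer with a nested "output.dense" submodule, mimicking the BERT suffixes
class ToyOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.output = ToyOutput()


layer = ToyLayer()
native = getattr_(layer, "output.dense")
setattr_(layer, "output.dense", ToyColumnLinear.from_native_module(native, rank=0, world_size=2))
assert getattr_(layer, "output.dense").out_features == 4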

In the shardformer test for BERT:

@@ -17,7 +17,7 @@ from transformers import (

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0

-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)

-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

     return org_model, sharded_model
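
Putting the updated test fragments together, the new flow reads roughly as follows. This is a hedged consolidation that uses only the calls appearing in this diff (ShardConfig, ShardFormer, init_distributed, shard_model); the config setup mirrors the hunk above but BertConfig.from_pretrained is an assumption, and launching the distributed processes (the test imports spawn and rerun_if_address_is_in_use from colossalai.testing for that) is assumed to happen outside the snippet:

# A consolidated sketch of the new ShardFormer flow exercised by build_model.
# Process-group launch (e.g. via colossalai.testing.spawn or torchrun) is
# assumed to have happened before this runs.
import copy

from transformers import BertConfig, BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer


def build_sharded_bert():
    # Assumed config setup; the test builds its config above the hunk shown here.
    config = BertConfig.from_pretrained('bert-base-uncased')
    config.hidden_dropout_prob = 0             # disable dropout so the original and
    config.attention_probs_dropout_prob = 0    # sharded models can be compared exactly

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
    model_to_shard = copy.deepcopy(org_model).to('cuda')

    shard_config = ShardConfig(tensor_parallel_size=2,   # 1D tensor parallelism over 2 ranks
                               data_parallel_size=1,
                               pipeline_parallel_size=1,
                               tensor_parallel_mode='1d',
                               inference_only=True,      # forward-only sharding
                               gather_output=True)       # gather sharded outputs back to full size
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()                      # set up the process groups used for sharding
    sharded_model = shard_former.shard_model(model_to_shard).to('cuda')
    return org_model.to('cuda'), sharded_model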