|
|
|
@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
|
|
|
|
|
from colossalai.cluster.process_group_manager import ProcessGroupManager |
|
|
|
|
|
|
|
|
|
from ..policies.autopolicy import get_autopolicy |
|
|
|
|
from ..policies.basepolicy import Policy |
|
|
|
|
from ..utils.utils import setattr_ |
|
|
|
|
from ..policies.basepolicy import Policy, SubModuleReplacementDescription |
|
|
|
|
from ..utils.utils import getattr_, setattr_ |
|
|
|
|
from .shard_config import ShardConfig |
|
|
|
|
|
|
|
|
|
__all__ = ['ModelSharder', 'shard_model'] |
|
|
|
@ -90,9 +90,7 @@ class ModelSharder(object):
|
|
|
|
|
Args: |
|
|
|
|
model (:class:`torch.nn.Module`): The model to shard |
|
|
|
|
""" |
|
|
|
|
print(self.policy) |
|
|
|
|
module_descriptions = self.policy.module_policy(self.shard_config) |
|
|
|
|
print(f"*******{module_descriptions}") |
|
|
|
|
for module_description in module_descriptions.items(): |
|
|
|
|
origin_layer_cls = module_description[0] |
|
|
|
|
attr_replacement = module_description[1].attribute_replacement |
|
|
|
@ -160,7 +158,7 @@ class ModelSharder(object):
|
|
|
|
|
def _replace_sub_module( |
|
|
|
|
self, |
|
|
|
|
org_layer: nn.Module, |
|
|
|
|
sub_module_replacement: List[Callable], |
|
|
|
|
sub_module_replacement: List[SubModuleReplacementDescription], |
|
|
|
|
) -> None: |
|
|
|
|
r""" |
|
|
|
|
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict |
|
|
|
@ -177,7 +175,8 @@ class ModelSharder(object):
|
|
|
|
|
|
|
|
|
|
assert target_module is not None, 'target_module should not be None' |
|
|
|
|
|
|
|
|
|
# TODO: integrate with new layer |
|
|
|
|
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) |
|
|
|
|
replace_layer = None |
|
|
|
|
# TODO: support different parallel mode |
|
|
|
|
native_sub_module = getattr_(org_layer, suffix) |
|
|
|
|
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d']) |
|
|
|
|
|
|
|
|
|
setattr_(org_layer, suffix, replace_layer) |
|
|
|
|