2023-07-10 05:58:58 +00:00
|
|
|
from types import MethodType
|
2023-07-03 07:29:11 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Union
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
import torch.nn as nn
|
2023-07-05 06:16:55 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
2023-07-10 02:48:53 +00:00
|
|
|
from colossalai.lazy import LazyInitContext
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-06-21 06:30:06 +00:00
|
|
|
from .._utils import getattr_, setattr_
|
2023-07-05 07:13:00 +00:00
|
|
|
from ..policies.auto_policy import get_autopolicy
|
|
|
|
from ..policies.base_policy import Policy, SubModuleReplacementDescription
|
2023-05-24 08:01:26 +00:00
|
|
|
from .shard_config import ShardConfig
|
2023-07-05 06:16:55 +00:00
|
|
|
from .utils import set_tensors_to_none
|
2023-05-22 07:02:17 +00:00
|
|
|
|
2023-05-24 08:01:26 +00:00
|
|
|
# Public API of this module. Every name listed here must actually be defined in
# this module: `from <module> import *` raises AttributeError for a missing name.
__all__ = ['ModelSharder']
|
2023-05-22 07:02:17 +00:00
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
class ModelSharder(object):
    r"""
    Shard the original huggingface model according to the policy.

    Args:
        model (:class:`torch.nn.Module`): The model to shard.
        policy (:class:`Policy`): The policy used to shard the model. If ``None``,
            a policy is selected automatically with ``get_autopolicy(model)``.
        shard_config (:class:`ShardConfig`): The setting of the distributed model.
            Sub-module replacement reads ``shard_config.tensor_parallel_process_group``,
            so a config is required whenever the policy replaces sub-modules.
    """

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
        # Fall back to an automatically selected policy when the caller gives none.
        self.policy = get_autopolicy(self.model) if policy is None else policy
        self.shard_config = shard_config

    def shard(self) -> List[Dict[int, Tensor]]:
        r"""
        Shard the model according to the policy.

        The steps run in this fixed order:
          1. bind the model and shard config to the policy;
          2. let the policy preprocess the model;
          3. release unheld layers (only when a pipeline stage manager is configured);
          4. replace modules as described by ``policy.module_policy()``;
          5. materialize the model (relevant when lazy initialization is used);
          6. let the policy postprocess the model.

        Returns:
            List[Dict[int, Tensor]]: The value of ``self.policy.get_shared_params()``.
        """
        self.policy.set_model(self.model)
        self.policy.set_shard_config(self.shard_config)
        self._preprocess()
        self._release_unheld_layers()
        self._replace_module()
        self._materialize()
        self._postprocess()
        return self.policy.get_shared_params()

    def _preprocess(self) -> None:
        r"""Run the policy's preprocessing and keep the model it returns."""
        self.model = self.policy.preprocess()

    def _postprocess(self) -> None:
        r"""Run the policy's postprocessing and keep the model it returns."""
        self.model = self.policy.postprocess()

    def _replace_module(self) -> None:
        r"""
        Replace the modules of ``self.model`` according to the policy, one layer class at a time.

        For each ``(layer_cls, description)`` pair returned by ``self.policy.module_policy()``,
        every module in ``self.model`` matching ``layer_cls`` gets the description's attribute,
        parameter, method and sub-module replacements applied.
        """
        module_descriptions = self.policy.module_policy()
        for layer_cls, module_description in module_descriptions.items():
            self._recursive_replace_layer(self.model, layer_cls, module_description.attribute_replacement,
                                          module_description.param_replacement,
                                          module_description.method_replacement,
                                          module_description.sub_module_replacement)

    def _recursive_replace_layer(
        self,
        module: nn.Module,
        origin_cls: Union[str, nn.Module],
        attr_replacement: Dict[str, Any],
        param_replacement: List[Callable],
        method_replacement: Dict[str, Callable],
        sub_module_replacement: List[SubModuleReplacementDescription],
    ) -> None:
        r"""
        Recursively apply the replacement operations to every matching layer under ``module``.

        ``module`` matches when ``origin_cls`` is a string equal to its class name, or when
        ``origin_cls`` is exactly its class (subclasses do not match). Children are visited
        after the current module has been processed, so replaced sub-modules are visited too.

        Args:
            module (torch.nn.Module): The object of layer to shard
            origin_cls (Union[str, type]): The origin layer class or a string of layer class name
            attr_replacement (Dict[str, Any]): The attribute dict to modify (skipped if None)
            param_replacement (List[Callable]): The function list to get parameter shard information in policy
                (skipped if None)
            method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
                (skipped if None)
            sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module
                shard information in policy (skipped if None)
        """
        is_match = (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
            (module.__class__ == origin_cls)
        if is_match:
            if attr_replacement is not None:
                self._replace_attr(module, attr_replacement)

            if param_replacement is not None:
                self._replace_param(module, param_replacement)

            if method_replacement is not None:
                self._replace_method(module, method_replacement)

            if sub_module_replacement is not None:
                self._replace_sub_module(module, sub_module_replacement)

        # Only the child objects are needed; their attribute names are not used.
        for child in module.children():
            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
                                          sub_module_replacement)

    def _replace_attr(
        self,
        module: nn.Module,
        attr_replacement: Dict[str, Any],
    ) -> None:
        r"""
        Replace the attributes of the layer.

        Each key/value pair is written with ``setattr_(module, key, value, ignore=True)``.
        What ``ignore=True`` tolerates (e.g. missing paths) is defined by ``setattr_`` -- confirm there.

        Args:
            module (:class:`torch.nn.Module`): The object of layer to shard
            attr_replacement (Dict): The attribute dict to modify
        """
        for k, v in attr_replacement.items():
            setattr_(module, k, v, ignore=True)

    def _replace_param(
        self,
        module: nn.Module,
        param_replacement: List[Callable],
    ) -> None:
        r"""
        Replace the parameters of the layer by calling each policy function on it, in order.

        Args:
            module (:class:`torch.nn.Module`): The object of layer to shard
            param_replacement (List[Callable]): The function list to get parameter shard information in policy
        """
        for param_func in param_replacement:
            param_func(module)

    def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]) -> None:
        r"""
        Override methods on this module instance only.

        Each new function is bound to ``module`` with ``types.MethodType`` and set as an
        instance attribute, so the module's class itself is not modified.

        Args:
            module (:class:`torch.nn.Module`): The object of layer to patch
            method_replacement (Dict[str, Callable]): Key is the method name, value is the unbound
                function whose first parameter receives the module
        """
        for method_name, new_method in method_replacement.items():
            # bind the new method to the module
            bound_method = MethodType(new_method, module)
            setattr(module, method_name, bound_method)

    def _replace_sub_module(
        self,
        org_layer: nn.Module,
        sub_module_replacement: List[SubModuleReplacementDescription],
    ) -> None:
        r"""
        Shard one layer by replacing its sub-modules according to the policy.

        The layer should be an instance of a key of the dict returned by ``policy.module_policy()``.
        For each description, the sub-module at ``description.suffix`` is rebuilt with
        ``description.target_module.from_native_module(...)`` using the tensor-parallel process
        group from the shard config, and written back at the same suffix.

        Args:
            org_layer (torch.nn.Module): The origin layer object to shard
            sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list

        Raises:
            AssertionError: If ``target_module`` is None, or the sub-module is already an
                instance of ``target_module`` (the policy was applied twice).
            RuntimeError: If building the replacement fails; the original error is chained as the cause.
        """
        for description in sub_module_replacement:
            suffix = description.suffix
            target_module = description.target_module
            kwargs = {} if description.kwargs is None else description.kwargs

            assert target_module is not None, 'target_module should not be None'

            # TODO: support different parallel mode
            # With ignore=True, a missing sub-module presumably comes back as None (see the check below).
            native_sub_module = getattr_(org_layer, suffix, ignore=True)

            assert not isinstance(native_sub_module, target_module), \
                f"The module with suffix {suffix} has been replaced, please check the policy"

            # if it is None and we are allowed to ignore this module
            # just skip
            if description.ignore_if_not_exist and native_sub_module is None:
                continue

            try:
                replace_layer = target_module.from_native_module(native_sub_module,
                                                                 self.shard_config.tensor_parallel_process_group,
                                                                 **kwargs)
            except Exception as e:
                # Chain explicitly so the root cause's traceback is kept as __cause__.
                raise RuntimeError(
                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
                    f" with {target_module.__qualname__} with the exception: {e}. "
                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
                ) from e

            setattr_(org_layer, suffix, replace_layer)

    def _release_unheld_layers(self) -> None:
        r"""
        Release the unheld layers in the model.

        Only runs when a shard config with a pipeline stage manager is present. Tensors of
        every layer outside ``self.policy.get_held_layers()`` are set to None via
        ``set_tensors_to_none``.
        """
        if self.shard_config and self.shard_config.pipeline_stage_manager:
            held_layers = self.policy.get_held_layers()
            set_tensors_to_none(self.model, exclude=set(held_layers))

    def _materialize(self) -> None:
        r"""
        Materialize the model if lazy initialization is used.
        """
        LazyInitContext.materialize(self.model)
|