# ColossalAI/colossalai/shardformer/shard/sharder.py
#
# 201 lines
# 8.1 KiB
# Python

from types import MethodType
from typing import Any, Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from colossalai.lazy import LazyInitContext
from .._utils import getattr_, setattr_
from ..policies.auto_policy import get_autopolicy
from ..policies.base_policy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
from .utils import set_tensors_to_none
# Names exported by `from <this module> import *`.
# NOTE(review): 'shard_model' is not defined in the visible part of this file — confirm it
# exists elsewhere in the module; a star-import fails on a name listed here but not defined.
__all__ = ['ModelSharder', 'shard_model']
class ModelSharder:
    r"""
    Shard the original huggingface model according to the policy

    Args:
        model (:class:`torch.nn.Module`): The model to shard
        policy (:class:`Policy`): The policy to shard the model. If ``None``, the policy is
            obtained from ``get_autopolicy(model)``
        shard_config (:class:`ShardConfig`): The setting of distributed model
    """

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model) if policy is None else policy
        self.shard_config = shard_config

    def shard(self) -> List[Dict[int, Tensor]]:
        r"""
        Shard the model according to the policy

        The steps run in a fixed order: preprocess, collect shared params, release unheld
        layers, replace modules, materialize, postprocess.

        Returns:
            The value of ``policy.get_shared_params()``, collected before any unheld layer is released
        """
        self.policy.set_model(self.model)
        self.policy.set_shard_config(self.shard_config)
        self._preprocess()
        # Shared params must be collected before the unheld layers are released: once tensors
        # have been set to None, unrelated params would be misjudged as shared (None is None).
        shared_params = self.policy.get_shared_params()
        self._release_unheld_layers()
        self._replace_module()
        self._materialize()
        self._postprocess()
        return shared_params

    def _preprocess(self) -> None:
        r"""
        Replace ``self.model`` with the result of the policy's ``preprocess()``
        """
        self.model = self.policy.preprocess()

    def _postprocess(self) -> None:
        r"""
        Replace ``self.model`` with the result of the policy's ``postprocess()``
        """
        self.model = self.policy.postprocess()

    def _replace_module(self) -> None:
        r"""
        Replace the modules according to the policy, one layer class at a time

        For every ``(layer_cls, description)`` pair returned by ``policy.module_policy()``, the
        whole model is traversed and each layer matching ``layer_cls`` is rewritten.
        """
        module_descriptions = self.policy.module_policy()
        for layer_cls, module_description in module_descriptions.items():
            self._recursive_replace_layer(
                self.model,
                layer_cls,
                module_description.attribute_replacement,
                module_description.param_replacement,
                module_description.method_replacement,
                module_description.sub_module_replacement,
            )

    def _recursive_replace_layer(
        self,
        module: nn.Module,
        origin_cls: Union[str, nn.Module],
        attr_replacement: Dict[str, Any],
        param_replacement: List[Callable],
        method_replacement: Dict[str, Callable],
        sub_module_replacement: List[SubModuleReplacementDescription],
    ) -> None:
        r"""
        Recursively apply the replacement operations to every layer matching ``origin_cls``

        A layer matches when its class is exactly ``origin_cls`` or, if ``origin_cls`` is a
        string, when its class name equals that string. Replacements given as ``None`` are
        skipped. Children are always visited, whether or not the current module matched.

        Args:
            module (torch.nn.Module): The object of layer to shard
            origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name
            attr_replacement (Dict[str, Any]): The attribute dict to modify
            param_replacement (List[Callable]): The function list to get parameter shard information in policy
            method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
            sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
        """
        matches_by_name = isinstance(origin_cls, str) and origin_cls == module.__class__.__name__
        if matches_by_name or module.__class__ == origin_cls:
            if attr_replacement is not None:
                self._replace_attr(module, attr_replacement)

            if param_replacement is not None:
                self._replace_param(module, param_replacement)

            if method_replacement is not None:
                self._replace_method(module, method_replacement)

            if sub_module_replacement is not None:
                self._replace_sub_module(module, sub_module_replacement)

        for child in module.children():
            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
                                          method_replacement, sub_module_replacement)

    def _replace_attr(
        self,
        module: nn.Module,
        attr_replacement: Dict[str, Any],
    ) -> None:
        r"""
        Replace the attribute of the layer

        Args:
            module (:class:`torch.nn.Module`): The object of layer to shard
            attr_replacement (Dict): The attribute dict to modify, mapping attribute name to new value
        """
        for attr_name, attr_value in attr_replacement.items():
            setattr_(module, attr_name, attr_value, ignore=True)

    def _replace_param(
        self,
        module: nn.Module,
        param_replacement: List[Callable],
    ) -> None:
        r"""
        Replace the parameter of the layer

        Each callable is invoked with ``module`` as its only argument; return values are ignored.

        Args:
            module (:class:`torch.nn.Module`): The object of layer to shard
            param_replacement (List[Callable]): The function list to get parameter shard information in policy
        """
        for param_func in param_replacement:
            param_func(module)

    def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
        r"""
        Replace methods of the layer

        Args:
            module (:class:`torch.nn.Module`): The object of layer whose methods are replaced
            method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
        """
        for method_name, new_method in method_replacement.items():
            # bind the new method to this module instance so that it receives `module` as `self`
            setattr(module, method_name, MethodType(new_method, module))

    def _replace_sub_module(
        self,
        org_layer: nn.Module,
        sub_module_replacement: List[SubModuleReplacementDescription],
    ) -> None:
        r"""
        Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

        Args:
            org_layer (torch.nn.Module): The origin layer object to shard
            sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list

        Raises:
            AssertionError: If a description has no ``target_module`` or the sub module is
                already an instance of ``target_module``
            RuntimeError: If ``target_module.from_native_module`` fails; the original
                exception is attached as ``__cause__``
        """
        for description in sub_module_replacement:
            suffix = description.suffix
            target_module = description.target_module
            kwargs = {} if description.kwargs is None else description.kwargs

            assert target_module is not None, 'target_module should not be None'

            # TODO: support different parallel mode
            native_sub_module = getattr_(org_layer, suffix, ignore=True)

            assert not isinstance(native_sub_module, target_module), \
                f"The module with suffix {suffix} has been replaced, please check the policy"

            # if it is None and we are allowed to ignore this module
            # just skip
            if description.ignore_if_not_exist and native_sub_module is None:
                continue

            try:
                replace_layer = target_module.from_native_module(native_sub_module,
                                                                 self.shard_config.tensor_parallel_process_group,
                                                                 **kwargs)
            except Exception as e:
                # chain explicitly so the traceback marks the original failure as the direct cause
                raise RuntimeError(
                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
                    f" with {target_module.__qualname__} with the exception: {e}. "
                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
                ) from e

            setattr_(org_layer, suffix, replace_layer)

    def _release_unheld_layers(self) -> None:
        r"""
        Release the unheld layers in the model

        Does nothing unless a shard config with a pipeline stage manager is set. Otherwise the
        layers returned by ``policy.get_held_layers()`` are excluded and the model is passed to
        ``set_tensors_to_none``.
        """
        if self.shard_config and self.shard_config.pipeline_stage_manager:
            held_layers = self.policy.get_held_layers()
            set_tensors_to_none(self.model, exclude=set(held_layers))

    def _materialize(self) -> None:
        r"""
        Materialize the model if lazy initialization is used
        """
        LazyInitContext.materialize(self.model)