ColossalAI/colossalai/shardformer/shard/sharder.py

from typing import Any, Callable, Dict, List

import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
from ..utils.utils import getattr_, hasattr_, setattr_
from .shard_config import ShardConfig
from .slicer import Slicer

__all__ = ['ModelSharder', 'shard_model']


class ModelSharder(object):
    r"""
    Shard the original huggingface model according to the policy

    Args:
        model (:class:`torch.nn.Module`): The model to shard
        policy (:class:`Policy`): The policy to shard the model
        shard_config (:class:`ShardConfig`): The config describing the distributed setting
    """

    def __init__(
        self,
        model: nn.Module,
        policy: Policy,
        shard_config: ShardConfig = None,    # TODO
    ) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model) if policy is None else policy
        self.slicer = Slicer(shard_config)
        self.shard_config = shard_config
        self.model_config = self.model.config

    def shard(self) -> None:
        self.reshape_embedding()
        self.inject_model(self.model)
        self.replace_layer(self.model)
        self.bind_layer(self.model)

    def reshape_embedding(self) -> None:
        r"""
        Resize the token embedding so that the vocabulary size is divisible by world_size
        """
        vocab_size = self.model_config.vocab_size
        world_size = self.shard_config.world_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
            self.model_config = self.model.config
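
    # Worked example of the padding above (illustrative numbers, not taken from any model config):
    # with vocab_size = 50257 and world_size = 8, 50257 % 8 == 1, so the vocabulary is resized to
    # 50257 + 8 - 1 = 50264, which is divisible by 8 and can therefore be split evenly across ranks.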

    def inject_model(
        self,
        model: nn.Module,
    ) -> None:
        r"""
        Replace the model class with the class defined by the policy

        This mainly swaps in forward/backward methods that fit the distributed model, e.g.
        ::

            BertForMaskedLM.forward -> BertForMaskedLM_.forward
        """
        inject_policy = self.policy.inject_policy()
        if inject_policy is None:
            return

        org_model_cls = inject_policy[0]
        shard_model_cls = inject_policy[1]

        if model.__class__ == org_model_cls:
            for key in shard_model_cls.__dict__.keys():
                if hasattr(model.__class__, key):
                    setattr(
                        model.__class__,
                        key,
                        getattr(shard_model_cls, key),
                    )
        else:
            raise NotImplementedError(f"{model.__class__} is not implemented so far")

    def replace_layer(
        self,
        model: nn.Module,
    ) -> None:
        r"""
        Replace the layers according to the policy, one origin layer class at a time

        Args:
            model (:class:`torch.nn.Module`): The model to shard
        """
        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
        for argument_policy in argument_policies.items():
            origin_layer_cls = argument_policy[0]
            attr_dict = argument_policy[1].attr_dict
            param_funcs = argument_policy[1].param_funcs
            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
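
    # Sketch of what ``argument_policies`` is expected to look like (hypothetical entry; only the
    # presence of ``attr_dict`` and ``param_funcs`` on each value is implied by the code above):
    #
    #   {BertLayer: SomeArgument(attr_dict={"attention.self.num_attention_heads": 12 // world_size},
    #                            param_funcs=[policy.attn_in, policy.attn_out])}
    #
    # i.e. each origin layer class maps to the attributes to overwrite and the functions describing
    # how its parameters are sharded.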

    def traverse_replace_layer(
        self,
        layer: nn.Module,
        origin_cls: nn.Module,
        attr_dict: Dict[str, Any],
        param_funcs: List[Callable],
    ) -> nn.Module:
        r"""
        Recursively traverse the module tree and replace every sub-layer that matches ``origin_cls``

        Args:
            layer (:class:`torch.nn.Module`): The layer (or sub-module) to shard
            origin_cls (:class:`transformers.model`): The origin layer class to match
            attr_dict (Dict): The attribute dict to modify
            param_funcs (:class:`List[typing.Callable]`): The functions that provide the shard information
        """
        if layer.__class__ == origin_cls:
            for k, v in attr_dict.items():
                setattr_(layer, k, v, ignore=True)
            self.shard_one_layer(layer, param_funcs)
        for name, child in layer.named_children():
            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
        return layer

    def shard_one_layer(
        self,
        org_layer: nn.Module,
        param_funcs: List[Callable],
    ) -> None:
        r"""
        Shard one layer according to the policy; the layer must be of the same class as a key in the
        dict returned by the policy's ``argument_policy``

        Args:
            org_layer (:class:`torch.nn.Module`): The origin layer object to shard
            param_funcs (:class:`List[typing.Callable]`): The function list used to get the shard information from the policy class
        """
        for func in param_funcs:
            policy_layers = func()
            for policy_layer in policy_layers:
                suffix = policy_layer.suffix
                replace_layer_cls = policy_layer.replace_layer
                ignore = policy_layer.ignore
                reversed = policy_layer.reversed
                n_cast = policy_layer.n_cast

                assert replace_layer_cls is not None, 'replace_layer should not be None'

                # fetch the layer to replace; skip it if it is absent and marked as ignorable
                suffix_layer = getattr_(org_layer, suffix, ignore=True)
                assert suffix_layer is not None or ignore, \
                    f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
                if suffix_layer is None and ignore:
                    continue

                # Linear / Embedding
                if isinstance(policy_layer, (Col_Layer, Row_Layer)):
                    weight = None
                    bias = None
                    weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
                    bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None

                    if weight_attr is not None:
                        if hasattr_(org_layer, weight_attr):
                            weight = getattr_(org_layer, weight_attr)
                        else:
                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")

                    if bias_attr is not None:
                        if hasattr_(org_layer, bias_attr):
                            bias = getattr_(org_layer, bias_attr)
                        else:
                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")

                    # slice the weight and bias, then create the new parallel layer and copy the shards in
                    assert weight is not None or bias is not None
                    weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)

                    if replace_layer_cls.__name__ == "Linear1D_Row":
                        replace_layer = replace_layer_cls(weight.shape[1],
                                                          weight.shape[0],
                                                          bias=False if bias is None else True)
                    elif replace_layer_cls.__name__ == "Linear1D_Col":
                        gather_output = policy_layer.gather_output and self.shard_config.gather_output
                        replace_layer = replace_layer_cls(weight.shape[0],
                                                          weight.shape[1],
                                                          bias=False if bias is None else True,
                                                          gather_output=gather_output)
                    elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
                                                          getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
                    else:
                        raise NotImplementedError(
                            f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
                    self.set_param(replace_layer, weight, bias)
                # Dropout
                elif isinstance(policy_layer, Dropout_Layer):
                    p_attr = suffix + '.' + policy_layer.p
                    p = getattr_(org_layer, p_attr, ignore=True)
                    replace_layer = replace_layer_cls(p)
                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
                else:
                    raise NotImplementedError(
                        f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")

    def set_param(self,
                  layer: Any,
                  weight: torch.Tensor = None,
                  bias: torch.Tensor = None,
                  layer_attr: str = "") -> None:
        r"""
        Reset the weight and bias of the layer object

        Args:
            layer (:class:`torch.nn.Module`): The layer object
            weight (:class:`torch.Tensor`): The weight of the layer
            bias (:class:`torch.Tensor`): The bias of the layer
            layer_attr (str): The attribute name of the layer
        """
        assert weight is not None or bias is not None
        if weight is not None:
            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
            self.set_layer_size(layer, layer_attr, weight.shape)
        if bias is not None:
            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))

    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
        r"""
        Set the size attributes (``in_features``/``out_features``) of the layer

        Args:
            layer (:class:`torch.nn.Module`): The layer object
            layer_attr (str): The attribute name of the layer
            size (:class:`torch.Size`): The size of the tensor
        """
        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
        attrs = ["out_features", "in_features"]
        for i, attr in enumerate(attrs):
            if hasattr_(layer, f"{layer_attr}.{attr}"):
                setattr_(layer, f"{layer_attr}.{attr}", size[i])

    def bind_layer(self, model: nn.Module) -> None:
        r"""
        Bind layers according to the binding policy so that the mapped attributes share one parameter

        Args:
            model (:class:`torch.nn.Module`): The shard model
        """
        binding_map = self.policy.binding_policy()
        if binding_map is None:
            return
        for k, v in binding_map.items():
            param = getattr_(model, k)
            param = nn.Parameter(param)
            setattr_(model, k, param)
            setattr_(model, v, param)
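
    # Typical use of the binding policy (hypothetical mapping; the exact attribute paths come from the
    # concrete policy, not from this file): tying the input word embedding to the output projection,
    # e.g.
    #
    #   {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    #
    # After ``bind_layer`` both attributes point at the same ``nn.Parameter``.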


def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
    r"""
    Shard a PyTorch model in place and return it.

    Args:
        model (:class:`torch.nn.Module`): The origin huggingface model
        shard_config (:class:`ShardConfig`): The config carrying the distribution information
        policy (:class:`Policy`): The custom policy for sharding; if ``None``, a policy is picked automatically
    """
    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
    sharder.shard()
    return model
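

# Minimal usage sketch (illustrative, kept as a comment so the module stays import-safe).
# Assumptions not shown in this file: a tensor-parallel process group has already been initialised
# (e.g. via ``colossalai.launch``) and ``ShardConfig`` can be built with the world size of that group;
# its exact constructor signature lives in ``shard_config.py``.
#
#   import torch.distributed as dist
#   from transformers import BertForMaskedLM
#
#   model = BertForMaskedLM.from_pretrained("bert-base-uncased")
#   config = ShardConfig(rank=dist.get_rank(), world_size=dist.get_world_size())   # hypothetical fields
#   model = shard_model(model, shard_config=config)    # policy=None -> auto policy via get_autopolicy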