ColossalAI/colossalai/shardformer/shard/sharder.py

239 lines
9.6 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable
from .shardconfig import ShardConfig
from dataclasses import dataclass
from ..policies.basepolicy import Policy, Layer
from ..policies.autopolicy import get_autopolicy
from .slicer import Slicer
from ..utils.utils import hasattr_, setattr_, getattr_
import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger
import os
logger = get_dist_logger()
class ModelSharder(object):
"""
Shard the original huggingface model according to the policy
Args:
policy: The policy to shard the model
model: The model to shard
dist_setting: The setting of distributed model
"""
def __init__(
self,
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.slicer = Slicer(shard_config)
self.shard_config = shard_config
self.model_config = self.model.config
self.binding_map = {}
def shard(self) -> None:
self.inject_model(self.model)
self.replace_layer(self.model)
def inject_model(
self,
model: nn.Module,
) -> None:
"""
Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model
e.g.
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]
if model.__class__ == org_model_cls:
for key in shard_model_cls.__dict__.keys():
if hasattr(model.__class__, key):
setattr(
model.__class__,
key,
getattr(shard_model_cls,key),
)
else:
raise NotImplementedError(f"{model.__class__} is not implemented so far")
def replace_layer(
self,
model: nn.Module,
) -> None:
"""
Replace the layer according to the policy, and replace the layer one by one
Args:
layer: The layer to shard
"""
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
for argument_policy in argument_policies.items():
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
binding_layers = argument_policy[1].binding_layers
# if binding_layer is not None:
# self.binding_map[origin_layer_cls] = binding_layer
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers)
def reverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
attr_dict: Dict[str, Any],
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
Reverse the replace layer operation
Args:
layer: The object of layer to shard
origin_cls: The origin layer class
attr_dict: The attribute dict to modify
policy_cls: The policy class
"""
for name, child in layer.named_children():
if child.__class__ == origin_cls:
# replac_layer = child
for k, v in attr_dict.items():
setattr_(child, k, v, ignore=True)
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
self.shard_one_layer(child, param_funcs, binding_layers)
continue
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers)
return layer
def shard_one_layer(
self,
org_layer: nn.Module,
param_funcs: List[Callable],
binding_layers: List[nn.Module]
) -> None:
"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
org_layer: The origin layer object to shard
param_funcs: The function list to get shard information in policy class
"""
# print(org_layer)
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
weight = None
bias = None
weight_attr = policy_layer.weight
bias_attr = policy_layer.bias
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
print(gather_output)
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
weight = getattr_(org_layer, weight_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
if bias_attr is not None:
if hasattr_(org_layer, bias_attr):
bias = getattr_(org_layer, bias_attr)
elif not ignore:
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
# dont have the attribute in policy, and ignore is true
if weight is None and bias is None and ignore:
continue
# set the sliced weight and bias to the new nn_col layer
assert weight is not None or bias is not None
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
# save the binding information
for binding_layer in binding_layers:
self.binding_map[binding_layer] = dict(weight=weight, bias=bias)
# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True)
elif replace_layer_cls.__name__ == "Linear1D_Col":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output)
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
self.set_param(replace_layer, weight, bias)
else:
raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
# do not replace the layer object, just replace the weight and bias
else:
self.set_param(org_layer, layer_attr, weight, bias)
def set_param(
self,
layer: Any,
layer_attr: str = "",
weight: torch.Tensor = None,
bias: torch.Tensor = None
) -> None:
"""
Reset the weight and bias of the layer object
Args:
layer: The layer object
layer_attr: The attribute name of the layer
weight: The weight of the layer
bias: The bias of the layer
"""
assert weight is not None or bias is not None
if weight is not None:
setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight))
self.set_layer_size(layer, layer_attr, weight.shape)
if bias is not None:
setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias))
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
"""
Set the layer attribute
Args:
layer: The layer object
layer_attr: The attribute name of the layer
size: Torch.size
"""
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
attrs = ["out_features", "in_features"]
for i, attr in enumerate(attrs):
if hasattr_(layer, f"{layer_attr}.{attr}"):
setattr_(layer, f"{layer_attr}.{attr}", size[i])