# ColossalAI/colossalai/shardformer/shard/sharder.py

from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch.nn as nn
from torch import Tensor

from colossalai.lazy import LazyInitContext

from .._utils import getattr_, setattr_
from ..policies.auto_policy import get_autopolicy
from ..policies.base_policy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
from .utils import set_tensors_to_none

__all__ = ["ModelSharder", "shard_model"]


class ModelSharder(object):
    r"""
    Shard the original Hugging Face model according to the policy.

    Args:
        model (:class:`torch.nn.Module`): The model to shard
        policy (:class:`Policy`): The policy used to shard the model; if ``None``, a suitable policy is resolved automatically
        shard_config (:class:`ShardConfig`): The configuration describing how the model is distributed
    """

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
        self.shard_config = shard_config
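    # Hedged usage sketch: in practice this class is driven by the higher-level ``ShardFormer`` helper; the
    # snippet below is illustrative only, and the model/config choices are assumptions rather than a fixed API.
    #
    #     from transformers import BertModel
    #     from colossalai.shardformer import ShardConfig
    #
    #     model = BertModel.from_pretrained("bert-base-uncased")
    #     shard_config = ShardConfig(...)  # configured for your tensor/pipeline parallel setup
    #     sharder = ModelSharder(model=model, policy=None, shard_config=shard_config)  # None -> auto policy
    #     shared_params = sharder.shard()  # shards the model in place and returns the shared parameter groups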
    def shard(self) -> List[Dict[int, Tensor]]:
        r"""
        Shard the model according to the policy.
        """
        self.policy.set_model(self.model)
        self.policy.set_shard_config(self.shard_config)
        self._preprocess()
        # get shared params before releasing unheld layers; this avoids misjudging shared params (None is None)
        shared_params = self.policy.get_shared_params()
        held_layers = self._release_unheld_layers()
        self._replace_module(include=held_layers)
        self._materialize()
        self._postprocess()
        return shared_params

    def _preprocess(self) -> None:
        self.model = self.policy.preprocess()

    def _postprocess(self) -> None:
        self.model = self.policy.postprocess()

    def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None:
        r"""
        Replace the modules of the model according to the policy, one module description at a time.

        Args:
            include (Set[nn.Module], optional): The set of modules to keep on the current device when pipeline parallel is enabled. Defaults to None.
        """
        module_descriptions = self.policy.module_policy()
        for layer_cls, module_description in module_descriptions.items():
            attr_replacement = module_description.attribute_replacement
            param_replacement = module_description.param_replacement
            sub_module_replacement = module_description.sub_module_replacement
            method_replacement = module_description.method_replacement
            self._recursive_replace_layer(
                self.model,
                layer_cls,
                attr_replacement,
                param_replacement,
                method_replacement,
                sub_module_replacement,
                include=include,
            )
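    # Hedged sketch of the ``policy.module_policy()`` structure consumed above. The field names follow the
    # attribute accesses in ``_replace_module``; the concrete layer class, target module and values are
    # illustrative assumptions, not a prescription:
    #
    #     {
    #         BertLayer: ModulePolicyDescription(
    #             attribute_replacement={"attention.self.num_attention_heads": heads_per_rank},
    #             param_replacement=[],
    #             method_replacement={"forward": custom_forward_fn},
    #             sub_module_replacement=[
    #                 SubModuleReplacementDescription(suffix="attention.self.query", target_module=Linear1D_Col),
    #             ],
    #         ),
    #     }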
    def _recursive_replace_layer(
        self,
        module: nn.Module,
        origin_cls: Union[str, nn.Module],
        attr_replacement: Dict[str, Any],
        param_replacement: List[Callable],
        method_replacement: Dict[str, Callable],
        sub_module_replacement: List[SubModuleReplacementDescription],
        include: Optional[Set[nn.Module]] = None,
    ) -> None:
        r"""
        Recursively traverse the module tree and apply the replacements to every layer matching ``origin_cls``.

        Args:
            module (torch.nn.Module): The module (subtree root) to process
            origin_cls (Union[str, torch.nn.Module]): The original layer class, or the name of the layer class as a string
            attr_replacement (Dict[str, Any]): The attribute dict to modify
            param_replacement (List[Callable]): The list of functions that shard parameters, provided by the policy
            method_replacement (Dict[str, Callable]): Key is the method name, value is the replacement method
            sub_module_replacement (List[SubModuleReplacementDescription]): The list of sub-module replacement descriptions provided by the policy
            include (Set[nn.Module], optional): The set of modules to keep on the current device when pipeline parallel is enabled. Defaults to None.
        """
        if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or (
            module.__class__ == origin_cls
        ):
            if attr_replacement is not None:
                self._replace_attr(module, attr_replacement)
            if param_replacement is not None and (include is None or module in include):
                self._replace_param(module, param_replacement)
            if method_replacement is not None:
                self._replace_method(module, method_replacement)
            if sub_module_replacement is not None:
                self._replace_sub_module(module, sub_module_replacement, include)

        for name, child in module.named_children():
            self._recursive_replace_layer(
                child,
                origin_cls,
                attr_replacement,
                param_replacement,
                method_replacement,
                sub_module_replacement,
                include=include,
            )

    def _replace_attr(
        self,
        module: nn.Module,
        attr_replacement: Dict[str, Any],
    ) -> None:
        r"""
        Replace the attributes of the layer.

        Args:
            module (:class:`torch.nn.Module`): The layer object whose attributes are modified
            attr_replacement (Dict): The attribute dict to modify
        """
        for k, v in attr_replacement.items():
            setattr_(module, k, v, ignore=True)
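    # Illustrative (hypothetical) attribute replacement: a tensor-parallel policy may register something like
    # ``{"self_attn.num_heads": num_heads // tp_size}`` so that each rank sees its local head count; ``setattr_``
    # resolves the dotted suffix, and ``ignore=True`` presumably tolerates keys the concrete model does not define.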
    def _replace_param(
        self,
        module: nn.Module,
        param_replacement: List[Callable],
    ) -> None:
        r"""
        Replace the parameters of the layer.

        Args:
            module (:class:`torch.nn.Module`): The layer object whose parameters are replaced
            param_replacement (List[Callable]): The list of functions that shard parameters, provided by the policy
        """
        for param_func in param_replacement:
            param_func(module)

    def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
        for method_name, new_method in method_replacement.items():
            # bind the new method to the module
            bound_method = MethodType(new_method, module)
            setattr(module, method_name, bound_method)
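    # Binding sketch: ``MethodType(new_method, module)`` turns a plain function such as
    # ``def custom_forward(self, hidden_states): ...`` into a method bound to this particular module instance,
    # so after replacement ``module.forward(hidden_states)`` dispatches to it with ``module`` passed as ``self``.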
    def _replace_sub_module(
        self,
        org_layer: nn.Module,
        sub_module_replacement: List[SubModuleReplacementDescription],
        include: Optional[Set[nn.Module]] = None,
    ) -> None:
        r"""
        Shard one layer according to the policy. The layer should be of the same class as a key in the policy's ``module_policy`` return dict.

        Args:
            org_layer (torch.nn.Module): The original layer object to shard
            sub_module_replacement (List[SubModuleReplacementDescription]): The sub-module replacement description list
            include (Set[nn.Module], optional): The set of modules to keep on the current device when pipeline parallel is enabled. Defaults to None.
        """
        for description in sub_module_replacement:
            suffix = description.suffix
            target_module = description.target_module
            kwargs = {} if description.kwargs is None else description.kwargs
            assert target_module is not None, "target_module should not be None"
            native_sub_module = getattr_(org_layer, suffix, ignore=True)

            # Skip the replacement if the submodule is not kept on the current device when pipeline parallel is enabled.
            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
                continue

            assert not isinstance(
                native_sub_module, target_module
            ), f"The module with suffix {suffix} has already been replaced, please check the policy"

            # if the submodule does not exist and we are allowed to ignore it, just skip
            if description.ignore_if_not_exist and native_sub_module is None:
                continue

            try:
                replace_layer = target_module.from_native_module(
                    native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
                )
            except Exception as e:
                raise RuntimeError(
                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
                    f" with {target_module.__qualname__} due to the exception: {e}. "
                    "Please check your model configuration or sharding policy. You can also open an issue for us to help."
                )
            setattr_(org_layer, suffix, replace_layer)
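    # ``target_module.from_native_module(...)`` is expected to build the distributed replacement from the original
    # submodule, the tensor parallel process group, and any extra ``kwargs`` declared in the description; a hedged
    # example would be a column-parallel linear layer built from an ``nn.Linear`` with ``gather_output=True``.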
    def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
        def collect_sub_modules(module: nn.Module):
            if module is None:
                return
            recursive_held_layers.append(module)
            for name, child in module.named_children():
                collect_sub_modules(child)

        recursive_held_layers = []
        for module in held_layers:
            collect_sub_modules(module)
        return recursive_held_layers

    def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
        r"""
        Release the layers of the model that are not held by the current device.
        """
        if self.shard_config and self.shard_config.pipeline_stage_manager:
            held_layers = self.policy.get_held_layers()
            set_tensors_to_none(self.model, exclude=set(held_layers))
            return set(self._get_recursive_held_layers(held_layers))
        return None
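    # Hypothetical pipeline example: with 2 stages over a 4-block model, stage 0 holds blocks [0, 1] and stage 1
    # holds blocks [2, 3]; the tensors of the blocks a stage does not hold are set to None above, so each rank
    # only shards and later materializes its own slice of the model.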
    def _materialize(self) -> None:
        r"""
        Materialize the model if lazy initialization is used.
        """
        LazyInitContext.materialize(self.model)
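    # Hedged lazy-init sketch (the model/config names are assumptions; only ``LazyInitContext`` comes from the
    # import at the top of this file):
    #
    #     with LazyInitContext():
    #         model = SomeTransformer(config)   # parameters are created lazily, without allocating real storage
    #     sharder = ModelSharder(model, policy=None, shard_config=shard_config)
    #     sharder.shard()                       # _materialize() above allocates the tensors this rank still holds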