mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
8.3 KiB
207 lines
8.3 KiB
from functools import partial
|
|
from typing import List
|
|
|
|
import torch
|
|
from torch.nn import Module
|
|
from transformers.models.llama.modeling_llama import (
|
|
LlamaAttention,
|
|
LlamaDecoderLayer,
|
|
LlamaForCausalLM,
|
|
LlamaModel,
|
|
LlamaRMSNorm,
|
|
)
|
|
|
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
|
|
|
# import colossalai
|
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
|
|
|
from ..modeling._utils import init_to_get_rotary
|
|
from ..modeling.llama import LlamaInferenceForwards
|
|
|
|
try:
    from colossalai.kernel.triton import rmsnorm_forward

    HAS_TRITON_RMSNORM = True
except Exception:
    # The Triton RMSNorm kernel is optional: any ordinary failure while importing it
    # disables the kernel instead of aborting the import of this module. Catching
    # ``Exception`` (not a bare ``except``) still lets KeyboardInterrupt/SystemExit through.
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False
|
|
|
|
|
|
def get_triton_rmsnorm_forward():
    """Return a Triton-backed replacement for ``LlamaRMSNorm.forward``, or ``None``.

    ``None`` is returned when the optional ``rmsnorm_forward`` kernel could not be
    imported (``HAS_TRITON_RMSNORM`` is false), letting callers fall back to the
    stock forward.
    """
    if not HAS_TRITON_RMSNORM:
        return None

    def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
        # Delegate to the Triton kernel with the module's raw weight tensor and epsilon.
        return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

    return _triton_rmsnorm_forward
|
|
|
|
|
|
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
    """Shardformer policy that prepares a Llama model for inference.

    Extends ``LlamaForCausalLMPolicy`` with optional GPTQ / SmoothQuant layer
    replacement, the ``LlamaInferenceForwards`` forward implementations, an
    optional Triton RMSNorm kernel, and an inference-specific pipeline layout.
    """

    def __init__(self) -> None:
        super().__init__()

    def _get_llama_model(self) -> Module:
        """Return the Llama backbone of the wrapped model.

        The wrapped model is either a bare ``LlamaModel`` (returned as-is) or a
        wrapper such as ``LlamaForCausalLM`` that stores the backbone in ``.model``.
        """
        # Exact class-name match, not isinstance: only a bare LlamaModel is returned as-is.
        if self.model.__class__.__name__ == "LlamaModel":
            return self.model
        return self.model.model

    def module_policy(self):
        """Build the inference sharding policy.

        Starts from ``LlamaForCausalLMPolicy.module_policy()`` and then:

        * if ``shard_config.extra_kwargs["quant"]`` is ``"gptq"`` or
          ``"smoothquant"``, registers a decoder-layer policy that replaces the
          attention and MLP projections with quantized linears (``split_num=1``)
          and divides the attention size/head attributes by the tensor-parallel size;
        * calls ``shard_config._infer()``;
        * replaces the ``forward`` of ``LlamaModel``, ``LlamaDecoderLayer`` and
          ``LlamaAttention`` with the ``LlamaInferenceForwards`` versions and sets
          ``llama_causal_lm_forward`` as the ``LlamaForCausalLM`` pipeline forward;
        * replaces ``LlamaRMSNorm.forward`` with the Triton kernel when available.

        Returns:
            The policy mapping from module class to ``ModulePolicyDescription``.
        """
        policy = super().module_policy()

        tp_size = self.shard_config.tensor_parallel_size
        config = self.model.config
        # Per-rank attention attributes after tensor-parallel sharding; consumed by
        # the quantized decoder-layer policies below.
        decoder_attribute_replacement = {
            "self_attn.hidden_size": config.hidden_size // tp_size,
            "self_attn.num_heads": config.num_attention_heads // tp_size,
            "self_attn.num_key_value_heads": config.num_key_value_heads // tp_size,
        }

        def _quant_replacements(pairs):
            # Every quantized projection is swapped in with split_num=1 (fresh kwargs dict each).
            return [
                SubModuleReplacementDescription(suffix=suffix, target_module=target, kwargs={"split_num": 1})
                for suffix, target in pairs
            ]

        quant = self.shard_config.extra_kwargs.get("quant", None)
        if quant == "gptq":
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=_quant_replacements(
                    [
                        ("self_attn.q_proj", ColCaiQuantLinear),
                        ("self_attn.k_proj", ColCaiQuantLinear),
                        ("self_attn.v_proj", ColCaiQuantLinear),
                        ("self_attn.o_proj", RowCaiQuantLinear),
                        ("mlp.gate_proj", ColCaiQuantLinear),
                        ("mlp.up_proj", ColCaiQuantLinear),
                        ("mlp.down_proj", RowCaiQuantLinear),
                    ]
                ),
            )
        elif quant == "smoothquant":
            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                ColW8A8BFP32OFP32Linear,
                RowW8A8B8O8Linear,
                RowW8A8BFP32O32LinearSiLU,
                RowW8A8BFP32OFP32Linear,
            )

            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=_quant_replacements(
                    [
                        ("self_attn.q_proj", RowW8A8B8O8Linear),
                        ("self_attn.k_proj", RowW8A8B8O8Linear),
                        ("self_attn.v_proj", RowW8A8B8O8Linear),
                        ("self_attn.o_proj", ColW8A8BFP32OFP32Linear),
                        ("mlp.gate_proj", RowW8A8BFP32O32LinearSiLU),
                        ("mlp.up_proj", RowW8A8BFP32OFP32Linear),
                        ("mlp.down_proj", ColW8A8BFP32OFP32Linear),
                    ]
                ),
            )

        self.shard_config._infer()

        # Swap in the inference forwards, in this order.
        for target_cls, new_forward in (
            (LlamaModel, LlamaInferenceForwards.llama_model_forward),
            (LlamaDecoderLayer, LlamaInferenceForwards.llama_decoder_layer_forward),
            (LlamaAttention, LlamaInferenceForwards.llama_flash_attn_kvcache_forward),
        ):
            self.append_or_create_method_replacement(
                description={"forward": partial(new_forward)}, policy=policy, target_key=target_cls
            )

        # Set as the default: inference also uses the pipeline-style forward,
        # simply with a single stage when pipeline parallelism is not used.
        self.set_pipeline_forward(
            model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
        )

        # get_triton_rmsnorm_forward() already returns None when the Triton kernel is unavailable.
        triton_rmsnorm_forward = get_triton_rmsnorm_forward()
        if triton_rmsnorm_forward is not None:
            self.append_or_create_method_replacement(
                description={"forward": partial(triton_rmsnorm_forward)}, policy=policy, target_key=LlamaRMSNorm
            )

        return policy

    def postprocess(self):
        """Run ``init_to_get_rotary`` on the Llama backbone and return the wrapped model."""
        # Resolve the backbone like get_held_layers does, so a bare LlamaModel
        # (which has no ``.model`` attribute) is handled too.
        init_to_get_rotary(self._get_llama_model())
        return self.model

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage.

        Decoder layers are split across stages via ``distribute_layers`` and
        ``get_stage_index``. The first stage additionally holds the token
        embedding and, when the model has one, the LM head; the last stage holds
        the final norm.
        """
        assert self.pipeline_stage_manager is not None

        module = self._get_llama_model()
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
            # The wrapped model may not define ``lm_head`` (e.g. a bare LlamaModel, which
            # the backbone lookup supports); reading it unconditionally raised AttributeError.
            # NOTE(review): the head sits on the first stage next to the embedding in this
            # inference layout — presumably logits are computed there; confirm with the engine.
            lm_head = getattr(self.model, "lm_head", None)
            if lm_head is not None:
                held_layers.append(lm_head)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers
|