2023-08-24 07:50:02 +00:00
|
|
|
import warnings
|
2023-07-11 03:37:26 +00:00
|
|
|
from functools import partial
|
2023-07-21 02:46:39 +00:00
|
|
|
from typing import Callable, Dict, List, Union
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
import torch.nn as nn
|
2023-07-11 03:37:26 +00:00
|
|
|
from torch import Tensor
|
2023-07-21 02:46:39 +00:00
|
|
|
from torch.nn import Module
|
|
|
|
|
2023-11-03 05:32:43 +00:00
|
|
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
2023-06-13 06:44:40 +00:00
|
|
|
|
2023-12-22 02:44:00 +00:00
|
|
|
from ..modeling.llama import (
|
|
|
|
LlamaPipelineForwards,
|
|
|
|
get_llama_flash_attention_forward,
|
2024-03-27 03:19:32 +00:00
|
|
|
get_llama_model_forward_for_flash_attn,
|
2023-12-22 02:44:00 +00:00
|
|
|
get_lm_forward_with_dist_cross_entropy,
|
|
|
|
)
|
2023-07-05 07:13:00 +00:00
|
|
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
2023-06-13 06:44:40 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Public policy classes exported by this module (one per supported Llama head, plus the shared base).
__all__ = ["LlamaPolicy", "LlamaModelPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"]
|
2023-06-30 02:56:29 +00:00
|
|
|
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
class LlamaPolicy(Policy):
    """Shardformer policy shared by all Llama model variants.

    Describes how ``transformers`` Llama modules are rewritten for tensor
    parallelism, normalization replacement, flash attention and pipeline
    parallelism. Head-specific subclasses extend :meth:`module_policy`.
    """

    def config_sanity_check(self):
        """No extra configuration checks are performed for Llama."""
        pass

    def preprocess(self):
        """Pad the vocabulary to a multiple of the tensor-parallel world size.

        Only runs when tensor parallelism is enabled; otherwise the model is
        returned untouched.

        Returns:
            The (possibly resized) model.
        """
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding so it can be split evenly across tensor-parallel ranks.
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                # Round vocab_size up to the next multiple of world_size.
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module-replacement policy shared by every Llama variant.

        Returns:
            A dict mapping Hugging Face module classes to the
            ``ModulePolicyDescription`` applied to them.

        Raises:
            ValueError: if tensor parallelism is enabled and the number of
                attention heads (or key/value heads) is not divisible by the
                tensor-parallel size.
        """
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = FusedRMSNorm
        else:
            norm_cls = RMSNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

        if self.shard_config.enable_tensor_parallelism:
            config = self.model.config
            tp_size = self.shard_config.tensor_parallel_size

            # Integer division below would silently truncate a non-divisible head count,
            # leaving the per-rank head count inconsistent with the sharded projections.
            if config.num_attention_heads % tp_size != 0:
                raise ValueError(
                    f"num_attention_heads ({config.num_attention_heads}) must be divisible by "
                    f"tensor_parallel_size ({tp_size})."
                )

            decoder_attribute_replacement = {
                "self_attn.hidden_size": config.hidden_size // tp_size,
                "self_attn.num_heads": config.num_attention_heads // tp_size,
            }
            if getattr(config, "num_key_value_heads", False):
                if config.num_key_value_heads % tp_size != 0:
                    raise ValueError(
                        f"num_key_value_heads ({config.num_key_value_heads}) must be divisible by "
                        f"tensor_parallel_size ({tp_size})."
                    )
                decoder_attribute_replacement["self_attn.num_key_value_heads"] = config.num_key_value_heads // tp_size

            # q/k/v and gate/up projections are split column-wise; o/down projections row-wise.
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key=LlamaModel,
            )

        # optimization configuration: normalization layers are always replaced with norm_cls
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=LlamaDecoderLayer,
        )

        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=LlamaModel,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_llama_flash_attention_forward(self.shard_config),
                },
                policy=policy,
                target_key=LlamaAttention,
            )
            if self.pipeline_stage_manager is None:
                # replace llama model forward method (pipeline runs install their own forward instead)
                self.append_or_create_method_replacement(
                    description={
                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
                    },
                    policy=policy,
                    target_key=LlamaModel,
                )

        return policy

    def postprocess(self):
        """Return the model unchanged."""
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replace the original Hugging Face forward
        method of ``model_cls`` with ``new_forward`` and record the change in ``policy``.

        Args:
            model_cls: module class whose ``forward`` is replaced.
            new_forward: pipeline-aware forward. It is bound via ``functools.partial``
                with ``stage_manager`` and ``shard_config``; in the non-interleaved
                schedule it is also bound with ``stage_index``.
            policy: policy dict, updated in place.
        """
        if self.pipeline_stage_manager is None:
            return

        stage_manager = self.pipeline_stage_manager
        if self.model.__class__.__name__ == "LlamaModel":
            module = self.model
        else:
            module = self.model.model

        if stage_manager.is_interleave:
            layers_per_stage = self.distribute_layers(
                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
            )
            # Interleaved schedule: layer ranges for every model chunk are stored on the
            # stage manager instead of being bound into the forward.
            stage_manager.stage_indices = self.get_stage_index(
                layers_per_stage,
                stage_manager.stage,
                num_model_chunks=stage_manager.num_model_chunks,
                num_stages=stage_manager.num_stages,
            )
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
            }
        else:
            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {
                "forward": partial(
                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                )
            }

        # Register the replacement once, for both schedules.
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage.

        The first stage also holds ``embed_tokens`` and the last stage also holds
        the final ``norm``; every stage holds its slice of decoder layers.
        """
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "LlamaModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            layers_per_stage = self.distribute_layers(
                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
            )
            stage_indices = self.get_stage_index(
                layers_per_stage,
                stage_manager.stage,
                num_model_chunks=stage_manager.num_model_chunks,
                num_stages=stage_manager.num_stages,
            )
            if stage_manager.is_first_stage(ignore_chunk=True):
                held_layers.append(module.embed_tokens)
            # One (start, end) range per model chunk owned by this stage.
            for start_idx, end_idx in stage_indices:
                held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage(ignore_chunk=True):
                held_layers.append(module.norm)

        else:
            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
            if stage_manager.is_first_stage():
                held_layers.append(module.embed_tokens)
            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
            held_layers.extend(module.layers[start_idx:end_idx])
            if stage_manager.is_last_stage():
                held_layers.append(module.norm)

        return held_layers
|
|
|
|
|
|
|
|
|
|
|
|
class LlamaModelPolicy(LlamaPolicy):
    """Policy for the bare ``LlamaModel`` backbone (no task-specific head)."""

    def module_policy(self):
        """Return the shared Llama policy, adding the pipeline forward for ``LlamaModel`` when pipelining."""
        policy_map = super().module_policy()
        from transformers.models.llama.modeling_llama import LlamaModel

        if not self.pipeline_stage_manager:
            return policy_map

        self.set_pipeline_forward(
            model_cls=LlamaModel,
            new_forward=LlamaPipelineForwards.llama_model_forward,
            policy=policy_map,
        )
        return policy_map

    def get_held_layers(self) -> List[Module]:
        """Return the modules owned by the current pipeline stage (same as the base policy)."""
        return super().get_held_layers()

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The backbone alone has no parameters shared across pipeline stages."""
        return []
|
|
|
|
|
|
|
|
|
2023-06-19 05:53:17 +00:00
|
|
|
class LlamaForCausalLMPolicy(LlamaPolicy):
    """Policy for ``LlamaForCausalLM``: the shared Llama policy plus the ``lm_head``."""

    def module_policy(self):
        """Extend the base policy with the sharded ``lm_head`` and the pipeline forward.

        With tensor parallelism, ``lm_head`` becomes a column-parallel linear whose
        output is gathered unless ``shard_config.parallel_output`` is set; in that
        case the model forward is replaced by one built with
        ``get_lm_forward_with_dist_cross_entropy``.
        """
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                LlamaForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs={"gather_output": not self.shard_config.parallel_output},
                        )
                    ],
                )
            }
            if self.shard_config.parallel_output:
                new_item[LlamaForCausalLM].method_replacement = {
                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
                }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage; the last stage also holds ``lm_head``."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Report the embedding/``lm_head`` weight as shared when it is tied across stages.

        Returns:
            ``[{0: embed_weight, last_stage: lm_head_weight}]`` when pipelining over
            more than one stage and both modules hold the very same tensor object;
            otherwise ``[]``.
        """
        llama_model = self.model.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager and stage_manager.num_stages > 1:
            # tie weights: the first stage holds embed_tokens, the last stage holds lm_head
            if llama_model.embed_tokens.weight is self.model.lm_head.weight:
                return [
                    {
                        0: llama_model.embed_tokens.weight,
                        stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []
|
|
|
|
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
    """Policy for ``LlamaForSequenceClassification``: the shared Llama policy plus the ``score`` head."""

    def module_policy(self):
        """Extend the base policy with the sharded ``score`` head and the pipeline forward."""
        from transformers import LlamaForSequenceClassification

        policy_map = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # The score head becomes a column-parallel linear with gather_output=True.
            score_head = SubModuleReplacementDescription(
                suffix="score",
                target_module=Linear1D_Col,
                kwargs={"gather_output": True},
            )
            policy_map.update(
                {LlamaForSequenceClassification: ModulePolicyDescription(sub_module_replacement=[score_head])}
            )

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(
                model_cls=LlamaForSequenceClassification,
                new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
                policy=policy_map,
            )

        return policy_map

    def get_held_layers(self) -> List[Module]:
        """Return the modules owned by the current pipeline stage; the last stage also holds ``score``."""
        owned = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
            owned.append(self.model.score)
        return owned

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The sequence-classification model has no parameters shared across pipeline stages."""
        return []
|