ColossalAI/colossalai/shardformer/policies/mixtral.py
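
"""Shardformer policies for Mixtral models.

Defines :class:`MixtralPolicy` and its subclasses, which tell Shardformer how to replace
Mixtral submodules (expert-parallel MoE blocks, fused RMSNorm, pipeline-parallel forwards,
and the ``lm_head`` / ``score`` projections) for ``MixtralModel``, ``MixtralForCausalLM``,
and ``MixtralForSequenceClassification``.
"""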

import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)

__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]


class MixtralPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError
            # # Resize embedding
            # vocab_size = self.model.config.vocab_size
            # world_size = self.shard_config.tensor_parallel_size
            # if vocab_size % world_size != 0:
            #     new_vocab_size = vocab_size + world_size - vocab_size % world_size
            #     self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError("Mixtral doesn't support sequence parallelism yet.")

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError
            # assert (
            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
            # ), "The number of attention heads must be divisible by tensor parallel size."
            # assert (
            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
            # ), "The number of key_value heads must be divisible by tensor parallel size."
            # decoder_attribute_replacement = {
            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
            #     // self.shard_config.tensor_parallel_size,
            # }
            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
            #     attribute_replacement=decoder_attribute_replacement,
            #     sub_module_replacement=[
            #         SubModuleReplacementDescription(
            #             suffix="self_attn.q_proj",
            #             target_module=Linear1D_Col,
            #             kwargs={"process_group": self.shard_config.tensor_parallel_process_group},
            #         ),
            #         SubModuleReplacementDescription(
            #             suffix="self_attn.k_proj",
            #             target_module=Linear1D_Col,
            #             kwargs={"process_group": self.shard_config.tensor_parallel_process_group},
            #         ),
            #         SubModuleReplacementDescription(
            #             suffix="self_attn.v_proj",
            #             target_module=Linear1D_Col,
            #             kwargs={"process_group": self.shard_config.tensor_parallel_process_group},
            #         ),
            #         SubModuleReplacementDescription(
            #             suffix="self_attn.o_proj",
            #             target_module=Linear1D_Row,
            #             kwargs={"process_group": self.shard_config.tensor_parallel_process_group},
            #         ),
            #         # SubModuleReplacementDescription(
            #         #     suffix="mlp.gate_proj",
            #         #     target_module=Linear1D_Col,
            #         # ),
            #         # SubModuleReplacementDescription(
            #         #     suffix="mlp.up_proj",
            #         #     target_module=Linear1D_Col,
            #         # ),
            #         # SubModuleReplacementDescription(
            #         #     suffix="mlp.down_proj",
            #         #     target_module=Linear1D_Row,
            #         # ),
            #     ],
            # )

        if getattr(self.shard_config, "ep_group", None) is not None:
            # expert parallelism: shard the MoE experts across the expert-parallel group
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="block_sparse_moe",
                        target_module=EPMixtralSparseMoeBlock,
                        kwargs={"ep_group": self.shard_config.ep_group},
                    )
                ],
                policy=policy,
                target_key=MixtralDecoderLayer,
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key=MixtralDecoderLayer,
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=MixtralModel,
            )

        if self.shard_config.enable_flash_attention:
            warnings.warn("Flash attention is natively supported in transformers; the flag will be ignored.")
            self.shard_config.enable_flash_attention = False

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under the pipeline-parallel setting, replace the original forward method of the
        HuggingFace module with the customized pipeline forward and record this change in the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "MixtralModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {
                "forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)
            }
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "MixtralModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers
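
# For illustration (the even split below is an assumption, not taken from this file): with a
# 32-layer MixtralModel and 4 pipeline stages, stage_manager.distribute_layers(32) would yield
# roughly [8, 8, 8, 8]; get_stage_index(...) on stage 1 would then return (8, 16), so that stage
# holds module.layers[8:16], while embed_tokens lives on the first stage and norm on the last.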


class MixtralModelPolicy(MixtralPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralModel,
                new_forward=MixtralPipelineForwards.mixtral_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in mixtral model"""
        return []


class MixtralForCausalLMPolicy(MixtralPolicy):
    def module_policy(self):
        policy = super().module_policy()
        # TODO: assign pg mesh from plugin to all modules
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                MixtralForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralForCausalLM,
                new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        mixtral_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight):
                # tied weights: the first stage holds embed_tokens, the last stage holds lm_head
                return [
                    {
                        0: mixtral_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []


class MixtralForSequenceClassificationPolicy(MixtralPolicy):
    def module_policy(self):
        from transformers import MixtralForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {
                MixtralForSequenceClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            raise NotImplementedError

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(self.model.score)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in Mixtral for sequence classification model"""
        return []
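

# Usage sketch (illustrative, not part of this module): these policies are normally picked up by
# Shardformer's auto-policy machinery, but an instance can also be passed explicitly. The
# ShardConfig fields mirror the flags this file checks (enable_fused_normalization,
# enable_flash_attention, ep_group, pipeline stage manager); the exact ShardFormer/ShardConfig
# constructor signatures below are assumptions, so check them against your ColossalAI version.
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#   from transformers import MixtralForCausalLM
#
#   model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
#   shard_config = ShardConfig(enable_fused_normalization=True)
#   shard_former = ShardFormer(shard_config=shard_config)
#   model, shared_params = shard_former.optimize(model, policy=MixtralForCausalLMPolicy())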