"""Shardformer policies for the ChatGLM model family."""

from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch.nn as nn
from torch import Tensor
from transformers.modeling_outputs import BaseModelOutputWithPast

import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
    ChatGLMForConditionalGeneration,
    ChatGLMModel,
    GLMBlock,
)

from ..modeling.chatglm2 import (
    get_chatglm_sequence_parallel_forward_fn,
    get_flash_core_attention_forward,
    get_jit_fused_glm_block_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']


class ChatGLMPolicy(Policy):
    """Base policy describing how ChatGLM modules are replaced for sharding."""

    def config_sanity_check(self):
        """No configuration checks are performed for ChatGLM."""
        pass

    def preprocess(self):
        """Pad the token embedding so it divides evenly across tensor-parallel ranks.

        Only acts when tensor parallelism is enabled and
        ``config.padded_vocab_size`` is not already a multiple of
        ``shard_config.tensor_parallel_size``.

        Returns:
            The (possibly resized) ``self.model``.
        """
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.padded_vocab_size
            world_size = self.shard_config.tensor_parallel_size

            remainder = vocab_size % world_size
            if remainder != 0:
                # Round up to the next multiple of the tensor-parallel size.
                new_vocab_size = vocab_size + world_size - remainder
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module-replacement policy from the enabled shard options.

        Returns:
            A mapping from module class to the ``ModulePolicyDescription``
            that should be applied to instances of that class.
        """
        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock

        policy = {}

        use_sequence_parallel = self.shard_config.enable_sequence_parallelism

        if self.shard_config.enable_tensor_parallelism:
            config = self.model.config
            tp_size = self.shard_config.tensor_parallel_size

            # Per-partition sizes shared by several attribute replacements.
            heads_per_partition = config.num_attention_heads // tp_size
            projection_per_partition = (config.kv_channels * config.num_attention_heads) // tp_size
            qkv_per_partition = (config.kv_channels * config.num_attention_heads * 3) // tp_size

            seq_parallel_kwargs = {'seq_parallel': use_sequence_parallel, 'seq_parallel_dim': 0}

            policy[ChatGLMModel] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="embedding.word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    )
                ],
            )

            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.num_attention_heads_per_partition": heads_per_partition,
                    "self_attention.projection_size": projection_per_partition,
                    "self_attention.qkv_hidden_size": qkv_per_partition,
                    "self_attention.core_attention.num_attention_heads_per_partition": heads_per_partition,
                    "self_attention.core_attention.hidden_size_per_partition": projection_per_partition,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
                        kwargs=dict(seq_parallel_kwargs),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs=dict(seq_parallel_kwargs),
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.core_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ],
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # The two norm flavours differ only in the fused class that is swapped in.
            if self.model.config.rmsnorm:
                fused_norm_cls = col_nn.FusedRMSNorm
            else:
                fused_norm_cls = col_nn.FusedLayerNorm

            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=fused_norm_cls),
                    SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=fused_norm_cls),
                ],
                policy=policy,
                target_key=GLMBlock,
            )

            if self.model.config.post_layer_norm:
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix="encoder.final_layernorm",
                                                        target_module=fused_norm_cls)
                    ],
                    policy=policy,
                    target_key=ChatGLMModel,
                )

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={
                    'forward': get_flash_core_attention_forward(),
                },
                policy=policy,
                target_key=CoreAttention,
            )

        # use sequence parallel
        if use_sequence_parallel:
            self.append_or_create_method_replacement(
                description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
                target_key=ChatGLMModel,
            )

        # use jit fused operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    'forward': get_jit_fused_glm_block_forward(),
                    'dropout_add': get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=GLMBlock,
            )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model

    def _transformer_module(self) -> nn.Module:
        """Return the module that owns the embedding, encoder and rotary embedding.

        This is ``self.model`` itself when its class is named ``ChatGLMModel``,
        otherwise ``self.model.transformer``.
        """
        if self.model.__class__.__name__ == 'ChatGLMModel':
            return self.model
        return self.model.transformer

    def get_held_layers(self) -> List[nn.Module]:
        """Get pipeline layers for current stage.

        Returns:
            The modules held by the current pipeline stage: the embedding on
            the first stage, this stage's slice of encoder layers, the final
            layernorm on the last stage (when ``post_layer_norm`` is set), and
            the rotary position embedding on every stage.
        """
        assert self.pipeline_stage_manager is not None

        module = self._transformer_module()
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embedding)

        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.encoder.layers[start_idx:end_idx])

        if stage_manager.is_last_stage() and module.encoder.post_layer_norm:
            held_layers.append(module.encoder.final_layernorm)

        # rotary_pos_emb is needed for all stages
        held_layers.append(module.rotary_pos_emb)

        return held_layers

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """Register a pipeline-aware ``forward`` replacement for ``model_cls``.

        If under pipeline parallel setting, replacing the original forward
        method of huggingface to customized forward method, and add this
        changing to policy.

        Args:
            model_cls: Module class whose ``forward`` is replaced.
            new_forward: Replacement forward; it is bound with the stage
                manager, this stage's layer index range and the shard config.
            policy: Policy mapping updated in place.

        Raises:
            ValueError: If pipeline parallelism is not enabled.
        """
        if not self.pipeline_stage_manager:
            raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
        stage_manager = self.pipeline_stage_manager

        module = self._transformer_module()

        layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {
            'forward':
                partial(new_forward,
                        stage_manager=stage_manager,
                        stage_index=stage_index,
                        shard_config=self.shard_config)
        }
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)


class ChatGLMModelPolicy(ChatGLMPolicy):
    """Policy for the bare ``ChatGLMModel``."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Extend the base policy with the pipeline forward for ``ChatGLMModel``."""
        policy = super().module_policy()

        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(model_cls=ChatGLMModel,
                                      new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
                                      policy=policy)
        return policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the layers held by the current stage (same as the base policy)."""
        return super().get_held_layers()

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in ChatGLMModel."""
        return []


class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
    """Policy for ``ChatGLMForConditionalGeneration``."""

    def module_policy(self):
        """Extend the parent policy with the conditional-generation pipeline forward."""
        policy = super().module_policy()

        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration,
                                      new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
                                      policy=policy)
        return policy

    def get_held_layers(self) -> List[nn.Module]:
        """Return the parent's held layers, plus the output layer on the last stage."""
        held_layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            held_layers.append(self.model.transformer.output_layer)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in ChatGLMForConditionalGenerationModel."""
        return []