"""Shardformer policies and pipeline-parallel forward functions for OpenMoE.

This module defines:

* ``OpenMoePolicy`` - the base policy (fused RMSNorm replacement, pipeline
  layer assignment).
* ``OpenMoeModelPolicy`` / ``OpenMoeForCausalLMPolicy`` - policies for the bare
  decoder stack and for the causal-LM wrapper.
* ``OpenMoePipelineForwards`` - replacement ``forward`` implementations that run
  only the decoder layers owned by the current pipeline stage and carry the
  accumulated MoE router losses from stage to stage.
"""

from functools import partial
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor
from torch.nn import Module
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

from colossalai.moe.manager import MOE_MANAGER
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel

__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]


class OpenMoePolicy(Policy):
    """Base shardformer policy shared by all OpenMoE model variants."""

    def config_sanity_check(self):
        """No configuration constraints are checked for OpenMoE."""

    def preprocess(self):
        """Pad the vocabulary so it divides evenly across tensor-parallel ranks.

        Only acts when tensor parallelism is enabled in the shard config.

        Returns:
            The (possibly resized) model.
        """
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                # Round up to the next multiple of the tensor-parallel size.
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module replacement policy for OpenMoE.

        Returns:
            A mapping from target module class to its replacement description.
            Only fused-normalization replacements are registered here.

        Raises:
            NotImplementedError: If sequence parallelism, tensor parallelism or
                flash attention is enabled in the shard config.
        """
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            # NOTE(review): the message says the flag is ignored, but the code
            # clears the flag and then raises. Behavior is kept as-is.
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # Norm layers inside every decoder layer.
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="pre_extra_mlp_layernorm",
                        target_module=FusedRMSNorm,
                        ignore_if_not_exist=True,
                    ),
                ],
                policy=policy,
                target_key=OpenMoeDecoderLayer,
            )

            # Final norm of the decoder stack.
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=OpenMoeModel,
            )

        if self.shard_config.enable_flash_attention:
            raise NotImplementedError("Flash attention has already been replaced in openmoe.")

        return policy

    def postprocess(self):
        """Return the model unchanged."""
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """Register a pipeline-aware ``forward`` replacement for ``model_cls``.

        Does nothing when no pipeline stage manager is set.

        Args:
            model_cls: Module class whose ``forward`` is replaced.
            new_forward: Replacement forward; it is bound with the current
                ``stage_manager`` and ``stage_index`` via ``functools.partial``.
            policy: Policy dict that receives the method replacement.
        """
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            # The decoder stack is the model itself for OpenMoeModel, otherwise
            # it lives under the ``.model`` attribute of the wrapper.
            if self.model.__class__.__name__ == "OpenMoeModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
            stage_index = stage_manager.get_stage_index(layers_per_stage)
            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage.

        Returns:
            The token embedding (first stage only), the decoder layers assigned
            to this stage, and the final norm (last stage only).
        """
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "OpenMoeModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers

    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        """Divide layers into stages.

        Uses hand-picked, uneven splits for the (24, 4), (24, 2), (12, 4) and
        (12, 2) layer/stage combinations; any other combination prints a notice
        and falls back to the parent policy's split.

        NOTE(review): the other methods in this file call
        ``stage_manager.distribute_layers`` rather than this override - confirm
        that this method is actually reached.

        Args:
            num_layers: Total number of decoder layers.
            num_stages: Number of pipeline stages.

        Returns:
            Number of layers per stage, in stage order.
        """
        if num_layers == 24 and num_stages == 4:
            return [7, 7, 7, 3]
        elif num_layers == 24 and num_stages == 2:
            return [15, 9]
        elif num_layers == 12 and num_stages == 4:
            return [5, 5, 5, 1]
        elif num_layers == 12 and num_stages == 2:
            return [8, 4]
        else:
            print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
            return super().distribute_layers(num_layers, num_stages)


class OpenMoeModelPolicy(OpenMoePolicy):
    """Policy for the bare ``OpenMoeModel`` decoder stack."""

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Extend the base policy with the pipeline forward for ``OpenMoeModel``."""
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=OpenMoeModel,
                new_forward=OpenMoePipelineForwards.openmoe_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Return an empty list: this policy reports no shared parameters."""
        return []


class OpenMoeForCausalLMPolicy(OpenMoePolicy):
    """Policy for ``OpenMoeForCausalLM`` (decoder stack plus ``lm_head``)."""

    def module_policy(self):
        """Extend the base policy with ``lm_head`` sharding and the pipeline forward."""
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            # NOTE(review): the base ``module_policy`` raises when tensor
            # parallelism is enabled, so this branch looks unreachable today.
            new_item = {
                OpenMoeForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=OpenMoeForCausalLM,
                new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage, adding ``lm_head`` on the last stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """Report the tied embedding / ``lm_head`` weight, if any.

        Returns:
            A one-element list mapping stage ``0`` to the embedding weight and
            the last stage to the ``lm_head`` weight when more than one pipeline
            stage is used and both weights are the same object; otherwise an
            empty list.
        """
        llama_model = self.model.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager and stage_manager.num_stages > 1:
            embed_weight = llama_model.embed_tokens.weight
            head_weight = self.model.lm_head.weight
            if embed_weight is head_weight:
                # tie weights
                return [
                    {
                        0: embed_weight,
                        stage_manager.num_stages - 1: head_weight,
                    }
                ]
        return []


class OpenMoePipelineForwards:
    """
    This class serves as a micro library for forward function substitution of
    OpenMoE models under pipeline setting.
    """

    @staticmethod
    def openmoe_model_forward(
        self: OpenMoeModel,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        past_router_aux_loss: Optional[torch.FloatTensor] = None,
        past_router_z_loss: Optional[torch.FloatTensor] = None,
    ):
        """Run the decoder layers owned by the current pipeline stage.

        Args:
            input_ids: Token ids; used on the first stage only.
            attention_mask: Optional mask; an all-ones mask is built when omitted.
            position_ids: Optional positions; a range is built when omitted.
            past_key_values: Optional cached key/values, indexed by layer.
            inputs_embeds: Pre-computed embeddings; alternative to ``input_ids``
                on the first stage.
            use_cache, output_attentions, output_hidden_states: Not supported
                under pipeline parallelism; a warning is logged and they are
                forced to ``False``.
            return_dict: Resolved against the config but not otherwise used here.
            stage_manager: Pipeline stage manager (required despite the default).
            hidden_states: Output of the previous stage; used on non-first stages.
            stage_index: ``[start, end)`` indices of the layers for this stage.
            past_router_aux_loss, past_router_z_loss: Router losses accumulated
                by earlier stages; added to this stage's losses when both are given.

        Returns:
            On the last stage, a tuple ``(hidden_states, next_cache,
            all_hidden_states, all_self_attns, router_aux_loss, router_z_loss)``.
            On every other stage, a dict with keys ``hidden_states``,
            ``router_aux_loss`` and ``router_z_loss``.

        Raises:
            ValueError: On the first stage, if both or neither of ``input_ids``
                and ``inputs_embeds`` are given.
        """
        # reset moe loss for different data
        MOE_MANAGER.reset_loss()

        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            # Later stages receive the previous stage's hidden states.
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            hidden_states,
            past_key_values_length,
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        start_idx, end_idx = stage_index[0], stage_index[1]
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None

        # concat past losses with current ones
        router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
        if past_router_aux_loss is not None and past_router_z_loss is not None:
            router_aux_loss = past_router_aux_loss + router_aux_loss
            router_z_loss = past_router_z_loss + router_z_loss

        if stage_manager.is_last_stage():
            return tuple(
                [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    router_aux_loss,
                    router_z_loss,
                ]
            )
        # always return dict for intermediate stage
        return {
            "hidden_states": hidden_states,
            "router_aux_loss": router_aux_loss,
            "router_z_loss": router_z_loss,
        }

    @staticmethod
    def llama_for_causal_lm_forward(
        self: OpenMoeForCausalLM,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        chunk_head: Optional[bool] = True,
        past_router_aux_loss: Optional[torch.FloatTensor] = None,
        past_router_z_loss: Optional[torch.FloatTensor] = None,
    ):
        """Pipeline-stage forward for ``OpenMoeForCausalLM``.

        Runs ``openmoe_model_forward`` for this stage and, on the last stage,
        applies ``lm_head`` and (when ``labels`` is given) computes the loss as
        the language-modeling loss plus the router auxiliary and z losses.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. They are
                shifted by one position against the logits inside this function.
            chunk_head: When truthy and ``labels`` is given, the head and loss
                are evaluated one batch element at a time under activation
                checkpointing and the returned ``logits`` is ``None``.
            stage_manager: Pipeline stage manager (required despite the default).
            hidden_states: Output of the previous stage; used on non-first stages.
            stage_index: ``[start, end)`` indices of the layers for this stage.
            past_router_aux_loss, past_router_z_loss: Router losses accumulated
                by earlier stages.

            The remaining arguments are forwarded to ``openmoe_model_forward``.

        Returns:
            On the last stage: a ``CausalLMOutputWithPast`` when ``return_dict``
            is truthy, otherwise a tuple of ``logits`` followed by
            ``outputs[1:]`` of the model forward, prefixed with ``loss`` when a
            loss was computed. On every other stage: a dict with keys
            ``hidden_states``, ``past_router_aux_loss`` and
            ``past_router_z_loss`` to feed the next stage.
        """
        logger = logging.get_logger(__name__)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = OpenMoePipelineForwards.openmoe_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            past_router_aux_loss=past_router_aux_loss,
            past_router_z_loss=past_router_z_loss,
        )

        if not stage_manager.is_last_stage():
            # Hand the hidden states and accumulated router losses to the next
            # stage; the keys match this function's ``past_router_*`` parameters.
            return {
                "hidden_states": outputs["hidden_states"],
                "past_router_aux_loss": outputs["router_aux_loss"],
                "past_router_z_loss": outputs["router_z_loss"],
            }

        (
            hidden_states,
            past_key_values,
            all_hidden_states,
            attentions,
            router_aux_loss,
            router_z_loss,
        ) = outputs

        # NOTE: a former ``pretraining_tp > 1`` branch built full logits here
        # that every branch below overwrote; it was removed because it only
        # wasted compute and activation memory.

        loss = None
        if labels is None:
            # if no training, just do forward
            logits = self.lm_head(hidden_states)
            logits = logits.float()
        elif chunk_head:
            # the vocab size for openmoe is 30w+
            # which causes great activation memory in training, up to 20G for one sequence
            # so we use chunk and checkpoint to reduce memory
            def chunk_loss(hidden_chunk, label_chunk):
                chunk_logits = self.lm_head(hidden_chunk).float()
                # Shift so that tokens < n predict n
                shift_logits = chunk_logits[..., :-1, :].contiguous()
                shift_labels = label_chunk[..., 1:].contiguous()
                return self._calculate_loss(shift_logits, shift_labels)

            aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
            loss = aux_loss + z_loss
            for batch_idx in range(hidden_states.shape[0]):
                loss = loss + torch.utils.checkpoint.checkpoint(
                    chunk_loss,
                    hidden_states[batch_idx : batch_idx + 1, :],
                    labels[batch_idx : batch_idx + 1, :],
                )
            # Full logits are never materialized on the chunked path.
            logits = None
        else:
            logits = self.lm_head(hidden_states)
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
            loss = aux_loss + z_loss
            loss = loss + self._calculate_loss(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=attentions,
        )